# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Tests for sharding util ops (XlaSplitND, XlaConcatND).""" |
| |
| from typing import Any, List, Optional |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.client.session import Session |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework.ops import control_dependencies |
| from tensorflow.python.framework.ops import Tensor |
| from tensorflow.python.ops import gen_tpu_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| |
| |
def create_tensor_split_graph(
    sess: Session,
    input_value: Any,
    input_dtype: Any,
    num_outputs: int,
    num_splits: List[int],
    paddings: Optional[List[int]] = None) -> List[Tensor]:
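  """Builds an XlaSplitND graph that splits a constant input tensor."""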
  del sess  # Unused; kept so both split graph builders share one signature.

  const_input_op = constant_op.constant(input_value, dtype=input_dtype)
  return gen_tpu_ops.xla_split_nd(
      const_input_op, num_outputs, num_splits, paddings=paddings)


def create_resource_split_graph(
    sess: Session,
    input_value: Any,
    input_dtype: Any,
    num_outputs: int,
    num_splits: List[int],
    paddings: Optional[List[int]] = None) -> List[Tensor]:
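  """Builds a ReadVariableXlaSplitND graph over an initialized variable."""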
  variable = resource_variable_ops.ResourceVariable(
      initial_value=input_value, dtype=input_dtype)
  sess.run(variables.variables_initializer([variable]))
  return gen_tpu_ops.read_variable_xla_split_nd(
      variable.handle, input_dtype, num_outputs, num_splits, paddings=paddings)


class XlaSplitNDOpTest(xla_test.XLATestCase, parameterized.TestCase):
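  """Tests for XlaSplitND and ReadVariableXlaSplitND."""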

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testSplitDimensionZero(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[[0]]],
            input_dtype=dtype,
            num_outputs=1,
            num_splits=[1, 1, 0])
        with self.assertRaisesOpError('index 2 must be positive, but got 0'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testSplitDimensionNegative(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[[0]]],
            input_dtype=dtype,
            num_outputs=1,
            num_splits=[1, -1, 1])
        with self.assertRaisesOpError('index 1 must be positive, but got -1'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testNumOutputsMismatch(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[0, 1],
            input_dtype=dtype,
            num_outputs=1,
            num_splits=[2])
        with self.assertRaisesOpError('\'N\' must match number of slices 2'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testPaddingsLengthMismatch(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[0, 1], [2, 3]],
            input_dtype=dtype,
            num_outputs=4,
            num_splits=[2, 2],
            paddings=[0])
        with self.assertRaisesOpError('length 2, but got 1'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testPaddingsNegative(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[0, 1], [2, 3]],
            input_dtype=dtype,
            num_outputs=4,
            num_splits=[2, 2],
            paddings=[0, -1])
        with self.assertRaisesOpError('non-negative, but got -1 at index 1'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testInputRankSplitMismatch(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[0, 1], [2, 3]],
            input_dtype=dtype,
            num_outputs=8,
            num_splits=[2, 2, 2])
        with self.assertRaisesOpError(
            '\'num_splits\' length 3, but got rank 2'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testDimNotEvenlySplit(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[0, 1], [2, 3], [4, 5], [6, 7]],
            input_dtype=dtype,
            num_outputs=6,
            num_splits=[3, 2])
        with self.assertRaisesOpError('divisible by \'num_splits\' 3'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testDimWithPaddingNotEvenlySplit(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[0, 1], [2, 3], [4, 5], [6, 7]],
            input_dtype=dtype,
            num_outputs=4,
            num_splits=[2, 2],
            paddings=[0, 1])
        with self.assertRaisesOpError('divisible by \'num_splits\' 2'):
          sess.run(split)

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testNoSplits(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
            input_dtype=dtype,
            num_outputs=1,
            num_splits=[1, 1, 1])
        results = sess.run(split)
        self.assertLen(results, 1)
        self.assertAllClose(results[0], [[[0, 1], [2, 3]], [[4, 5], [6, 7]]])

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testNoSplitsWithPadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[[0]], [[1]]],
            input_dtype=dtype,
            num_outputs=1,
            num_splits=[1, 1, 1],
            paddings=[0, 1, 1])
        results = sess.run(split)
        self.assertLen(results, 1)
        self.assertAllClose(results[0], [[[0, 0], [0, 0]], [[1, 0], [0, 0]]])

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testSplitNoPadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[
                [0, 1, 2, 3],
                [4, 5, 6, 7],
                [8, 9, 10, 11],
                [12, 13, 14, 15],
            ],
            input_dtype=dtype,
            num_outputs=4,
            num_splits=[2, 2])
        results = sess.run(split)
        self.assertLen(results, 4)
        self.assertAllClose(results[0], [[0, 1], [4, 5]])
        self.assertAllClose(results[1], [[2, 3], [6, 7]])
        self.assertAllClose(results[2], [[8, 9], [12, 13]])
        self.assertAllClose(results[3], [[10, 11], [14, 15]])

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testSplitPartialPadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[
                [0, 1, 2],
                [3, 4, 5],
                [6, 7, 8],
            ],
            input_dtype=dtype,
            num_outputs=4,
            num_splits=[2, 2],
            paddings=[1, 1])
        results = sess.run(split)
        self.assertLen(results, 4)
        self.assertAllClose(results[0], [[0, 1], [3, 4]])
        self.assertAllClose(results[1], [[2, 0], [5, 0]])
        self.assertAllClose(results[2], [[6, 7], [0, 0]])
        self.assertAllClose(results[3], [[8, 0], [0, 0]])

  @parameterized.named_parameters(('Tensor', create_tensor_split_graph),
                                  ('Resource', create_resource_split_graph))
  def testSplitCompletePadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=[[0], [1]],
            input_dtype=dtype,
            num_outputs=4,
            num_splits=[2, 2],
            paddings=[2, 3])
        results = sess.run(split)
        self.assertLen(results, 4)
        self.assertAllClose(results[0], [[0, 0], [1, 0]])
        self.assertAllClose(results[1], [[0, 0], [0, 0]])
        self.assertAllClose(results[2], [[0, 0], [0, 0]])
        self.assertAllClose(results[3], [[0, 0], [0, 0]])

  @parameterized.named_parameters(
      ('1Tensor', create_tensor_split_graph, 1),
      ('2Tensor', create_tensor_split_graph, 2),
      ('3Tensor', create_tensor_split_graph, 3),
      ('4Tensor', create_tensor_split_graph, 4),
      ('5Tensor', create_tensor_split_graph, 5),
      ('6Tensor', create_tensor_split_graph, 6),
      ('7Tensor', create_tensor_split_graph, 7),
      ('8Tensor', create_tensor_split_graph, 8),
      ('1Resource', create_resource_split_graph, 1),
      ('2Resource', create_resource_split_graph, 2),
      ('3Resource', create_resource_split_graph, 3),
      ('4Resource', create_resource_split_graph, 4),
      ('5Resource', create_resource_split_graph, 5),
      ('6Resource', create_resource_split_graph, 6),
      ('7Resource', create_resource_split_graph, 7),
      ('8Resource', create_resource_split_graph, 8),
  )
  def testRanked(self, graph_fn, rank):
    num_splits = [2] * rank
    num_outputs = 2**rank  # One output slice per partition.
    input_value = np.reshape(np.arange(np.prod(num_splits)), num_splits)
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        split = graph_fn(
            sess,
            input_value=input_value,
            input_dtype=dtype,
            num_outputs=num_outputs,
            num_splits=num_splits)
        results = sess.run(split)
        self.assertLen(results, num_outputs)
        for i, result in enumerate(results):
          expected_output = np.reshape(i, [1] * rank).astype(dtype)
          self.assertAllClose(result, expected_output)


def create_tensor_concat_graph(
    sess: Session,
    input_values: List[Any],
    input_dtype: Any,
    num_concats: List[int],
    paddings: Optional[List[int]] = None,
    output_shape: Optional[List[int]] = None) -> Tensor:
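  """Builds an XlaConcatND graph over constant input tensors."""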
  del sess  # Unused; kept so both concat graph builders share one signature.
  del output_shape  # Only used by the resource variant.

  const_input_ops = [
      constant_op.constant(i, dtype=input_dtype) for i in input_values
  ]
  return gen_tpu_ops.xla_concat_nd(
      const_input_ops, num_concats, paddings=paddings)


def create_resource_concat_graph(
    sess: Session,
    input_values: List[Any],
    input_dtype: Any,
    num_concats: List[int],
    paddings: Optional[List[int]] = None,
    output_shape: Optional[List[int]] = None) -> Tensor:
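  """Builds an AssignVariableXlaConcatND graph and reads back the variable."""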
  variable_shape = [] if output_shape is None else output_shape
  variable = resource_variable_ops.ResourceVariable(
      initial_value=np.zeros(variable_shape, dtype=input_dtype),
      dtype=input_dtype)
  sess.run(variables.variables_initializer([variable]))
  const_input_ops = [
      constant_op.constant(i, dtype=input_dtype) for i in input_values
  ]
  concat = gen_tpu_ops.assign_variable_xla_concat_nd(
      variable.handle, const_input_ops, num_concats, paddings=paddings)
  # Read the variable only after the concat result has been assigned to it.
  with control_dependencies([concat]):
    return variable.read_value()


class XlaConcatNDOpTest(xla_test.XLATestCase, parameterized.TestCase):
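  """Tests for XlaConcatND and AssignVariableXlaConcatND."""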

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testConcatDimensionZero(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[[[0]]]],
            input_dtype=dtype,
            num_concats=[1, 1, 0])
        with self.assertRaisesOpError('index 2 must be positive, but got 0'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testConcatDimensionNegative(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[[[0]]]],
            input_dtype=dtype,
            num_concats=[1, -1, 1])
        with self.assertRaisesOpError('index 1 must be positive, but got -1'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testNumInputsMismatch(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess, input_values=[[0, 1]], input_dtype=dtype, num_concats=[2])
        with self.assertRaisesOpError('\'N\' must match number of slices 2'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testPaddingsLengthMismatch(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[[0, 1], [2, 3]]],
            input_dtype=dtype,
            num_concats=[1, 1],
            paddings=[0])
        with self.assertRaisesOpError('length 2, but got 1'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testPaddingsNegative(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[[0, 1], [2, 3]]],
            input_dtype=dtype,
            num_concats=[1, 1],
            paddings=[0, -1])
        with self.assertRaisesOpError('non-negative, but got -1 at index 1'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testInputRankConcatMismatch(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess, input_values=[[0]], input_dtype=dtype, num_concats=[1, 1])
        with self.assertRaisesOpError(
            '\'num_concats\' length 2, but got rank 1'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testDifferentShapedInputs(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[0], [1, 2]],
            input_dtype=dtype,
            num_concats=[2])
        with self.assertRaisesOpError(
            r'same expected shape \[1\], but got \[2\] at index 1'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testPaddingExceedsOutputDimSize(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[0]],
            input_dtype=dtype,
            num_concats=[1],
            paddings=[2])
        with self.assertRaisesOpError(
            'exceed expected output shape dimension 1 at index 0, but got 2'):
          sess.run(concat)

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testNoConcats(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[[[0, 1], [2, 3]], [[4, 5], [6, 7]]]],
            input_dtype=dtype,
            num_concats=[1, 1, 1],
            output_shape=[2, 2, 2])
        result = sess.run(concat)
        self.assertAllClose(result, [[[0, 1], [2, 3]], [[4, 5], [6, 7]]])

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testNoConcatsWithPadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[[[[0, 1], [2, 3]], [[4, 5], [6, 7]]]],
            input_dtype=dtype,
            num_concats=[1, 1, 1],
            output_shape=[1, 1, 1],
            paddings=[1, 1, 1])
        result = sess.run(concat)
        self.assertAllClose(result, [[[0]]])

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testConcatNoPadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[
                [[0, 1], [2, 3]],
                [[4, 5], [6, 7]],
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]],
            ],
            input_dtype=dtype,
            num_concats=[2, 2],
            output_shape=[4, 4])
        result = sess.run(concat)
        self.assertAllClose(
            result,
            [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]])

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testConcatPartialPadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[
                [[0, 1], [2, 3]],
                [[4, 5], [6, 7]],
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]],
            ],
            input_dtype=dtype,
            num_concats=[2, 2],
            output_shape=[3, 3],
            paddings=[1, 1])
        result = sess.run(concat)
        self.assertAllClose(result, [[0, 1, 4], [2, 3, 6], [8, 9, 12]])

  @parameterized.named_parameters(('Tensor', create_tensor_concat_graph),
                                  ('Resource', create_resource_concat_graph))
  def testConcatCompletePadding(self, graph_fn):
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=[
                [[0, 1], [2, 3]],
                [[4, 5], [6, 7]],
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]],
            ],
            input_dtype=dtype,
            num_concats=[2, 2],
            output_shape=[2, 2],
            paddings=[2, 2])
        result = sess.run(concat)
        self.assertAllClose(result, [[0, 1], [2, 3]])

  @parameterized.named_parameters(
      ('1Tensor', create_tensor_concat_graph, 1),
      ('2Tensor', create_tensor_concat_graph, 2),
      ('3Tensor', create_tensor_concat_graph, 3),
      ('4Tensor', create_tensor_concat_graph, 4),
      ('5Tensor', create_tensor_concat_graph, 5),
      ('6Tensor', create_tensor_concat_graph, 6),
      ('7Tensor', create_tensor_concat_graph, 7),
      ('8Tensor', create_tensor_concat_graph, 8),
      ('1Resource', create_resource_concat_graph, 1),
      ('2Resource', create_resource_concat_graph, 2),
      ('3Resource', create_resource_concat_graph, 3),
      ('4Resource', create_resource_concat_graph, 4),
      ('5Resource', create_resource_concat_graph, 5),
      ('6Resource', create_resource_concat_graph, 6),
      ('7Resource', create_resource_concat_graph, 7),
      ('8Resource', create_resource_concat_graph, 8),
  )
  def testRanked(self, graph_fn, rank):
    num_concats = [2] * rank
    num_inputs = 2**rank  # One input slice per partition.
    input_values = [np.reshape(i, [1] * rank) for i in range(num_inputs)]
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        concat = graph_fn(
            sess,
            input_values=input_values,
            input_dtype=dtype,
            num_concats=num_concats,
            output_shape=num_concats)
        result = sess.run(concat)
        expected_output = np.arange(
            0, num_inputs).reshape(num_concats).astype(dtype)
        self.assertAllClose(result, expected_output)


def create_tensor_roundtrip_graph(
    sess: Session,
    value: Any,
    dtype: Any,
    num_partitions: List[int],
    paddings: Optional[List[int]] = None) -> Tensor:
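  """Builds a split-then-concat roundtrip and compares it against the input."""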
  del sess  # Unused; kept for signature parity with the resource variant.

  const_input_op = constant_op.constant(value, dtype=dtype)
  split = gen_tpu_ops.xla_split_nd(
      const_input_op,
      np.prod(num_partitions),
      num_partitions,
      paddings=paddings)
  concat = gen_tpu_ops.xla_concat_nd(split, num_partitions, paddings=paddings)
  return math_ops.equal(const_input_op, concat)


def create_resource_roundtrip_graph(
    sess: Session,
    value: Any,
    dtype: Any,
    num_partitions: List[int],
    paddings: Optional[List[int]] = None) -> Tensor:
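  """Builds a variable-based split/concat roundtrip and checks the variable."""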
  variable = resource_variable_ops.ResourceVariable(
      initial_value=value, dtype=dtype)
  sess.run(variables.variables_initializer([variable]))
  split = gen_tpu_ops.read_variable_xla_split_nd(
      variable.handle,
      dtype,
      np.prod(num_partitions),
      num_partitions,
      paddings=paddings)
  concat = gen_tpu_ops.assign_variable_xla_concat_nd(
      variable.handle, split, num_partitions, paddings=paddings)
  # Read the variable only after the concat result has been assigned to it.
  with control_dependencies([concat]):
    return math_ops.equal(variable.read_value(),
                          constant_op.constant(value, dtype=dtype))


class XlaSplitConcatNDTest(xla_test.XLATestCase, parameterized.TestCase):
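  """Roundtrip tests: XlaConcatND reassembles the outputs of XlaSplitND."""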

  @parameterized.named_parameters(
      ('1Tensor', create_tensor_roundtrip_graph, 1),
      ('2Tensor', create_tensor_roundtrip_graph, 2),
      ('3Tensor', create_tensor_roundtrip_graph, 3),
      ('4Tensor', create_tensor_roundtrip_graph, 4),
      ('5Tensor', create_tensor_roundtrip_graph, 5),
      ('6Tensor', create_tensor_roundtrip_graph, 6),
      ('7Tensor', create_tensor_roundtrip_graph, 7),
      ('8Tensor', create_tensor_roundtrip_graph, 8),
      ('1Resource', create_resource_roundtrip_graph, 1),
      ('2Resource', create_resource_roundtrip_graph, 2),
      ('3Resource', create_resource_roundtrip_graph, 3),
      ('4Resource', create_resource_roundtrip_graph, 4),
      ('5Resource', create_resource_roundtrip_graph, 5),
      ('6Resource', create_resource_roundtrip_graph, 6),
      ('7Resource', create_resource_roundtrip_graph, 7),
      ('8Resource', create_resource_roundtrip_graph, 8),
  )
  def testNoPadding(self, graph_fn, rank):
    num_partitions = [2] * rank
    shape = [4] * rank
    value = np.arange(0, np.prod(shape)).reshape(shape)
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        validate = graph_fn(sess, value, dtype, num_partitions)
        result = sess.run(validate)
        self.assertAllEqual(result, np.broadcast_to(True, shape))

  @parameterized.named_parameters(
      ('1Tensor', create_tensor_roundtrip_graph, 1),
      ('2Tensor', create_tensor_roundtrip_graph, 2),
      ('3Tensor', create_tensor_roundtrip_graph, 3),
      ('4Tensor', create_tensor_roundtrip_graph, 4),
      ('5Tensor', create_tensor_roundtrip_graph, 5),
      ('6Tensor', create_tensor_roundtrip_graph, 6),
      ('7Tensor', create_tensor_roundtrip_graph, 7),
      ('8Tensor', create_tensor_roundtrip_graph, 8),
      ('1Resource', create_resource_roundtrip_graph, 1),
      ('2Resource', create_resource_roundtrip_graph, 2),
      ('3Resource', create_resource_roundtrip_graph, 3),
      ('4Resource', create_resource_roundtrip_graph, 4),
      ('5Resource', create_resource_roundtrip_graph, 5),
      ('6Resource', create_resource_roundtrip_graph, 6),
      ('7Resource', create_resource_roundtrip_graph, 7),
      ('8Resource', create_resource_roundtrip_graph, 8),
  )
  def testPartialPadding(self, graph_fn, rank):
    num_partitions = [2] * rank
    shape = [4] * rank
    value = np.arange(0, np.prod(shape)).reshape(shape)
    paddings = [2] * rank
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        validate = graph_fn(sess, value, dtype, num_partitions, paddings)
        result = sess.run(validate)
        self.assertAllEqual(result, np.broadcast_to(True, shape))

  @parameterized.named_parameters(
      ('1Tensor', create_tensor_roundtrip_graph, 1),
      ('2Tensor', create_tensor_roundtrip_graph, 2),
      ('3Tensor', create_tensor_roundtrip_graph, 3),
      ('4Tensor', create_tensor_roundtrip_graph, 4),
      ('5Tensor', create_tensor_roundtrip_graph, 5),
      ('6Tensor', create_tensor_roundtrip_graph, 6),
      ('7Tensor', create_tensor_roundtrip_graph, 7),
      ('8Tensor', create_tensor_roundtrip_graph, 8),
      ('1Resource', create_resource_roundtrip_graph, 1),
      ('2Resource', create_resource_roundtrip_graph, 2),
      ('3Resource', create_resource_roundtrip_graph, 3),
      ('4Resource', create_resource_roundtrip_graph, 4),
      ('5Resource', create_resource_roundtrip_graph, 5),
      ('6Resource', create_resource_roundtrip_graph, 6),
      ('7Resource', create_resource_roundtrip_graph, 7),
      ('8Resource', create_resource_roundtrip_graph, 8),
  )
  def testCompletePadding(self, graph_fn, rank):
    num_partitions = [2] * rank
    shape = [4] * rank
    value = np.arange(0, np.prod(shape)).reshape(shape)
    paddings = [4] * rank
    for dtype in self.numeric_types:
      with self.session() as sess, self.device_scope():
        validate = graph_fn(sess, value, dtype, num_partitions, paddings)
        result = sess.run(validate)
        self.assertAllEqual(result, np.broadcast_to(True, shape))


if __name__ == '__main__':
  test.main()