blob: 2cc6279ed7ae59d8dc3ea8ac9c95039c1ec548ed [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for operators with > 3 or arbitrary numbers of arguments."""
import unittest
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class NAryOpsTest(xla_test.XLATestCase):
  """XLA tests for ops taking more than three, or a variable number of, inputs.

  Covers `add_n`, `identity_n`, `concat`, `one_hot`, `split` (SplitV),
  `strided_slice` and `strided_slice_grad`.
  """

  def _testNAry(self, op, args, expected, equality_fn=None):
    """Runs `op` on placeholders fed with `args` and checks the result.

    Args:
      op: Callable taking a list of placeholder tensors (one per entry of
        `args`, in the same order) and returning the tensor(s) to evaluate.
      args: Sequence of numpy arrays/scalars. Each entry defines the dtype and
        shape of its placeholder and is fed as that placeholder's value.
      expected: Value the evaluated output is compared against.
      equality_fn: Optional comparison callable, invoked as
        `equality_fn(result, expected, rtol=1e-3)`. Defaults to
        `self.assertAllClose`.
    """
    with self.session() as session:
      with self.test_scope():
        placeholders = [
            array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
            for arg in args
        ]
        # Each placeholder is fed the argument it was built from.
        feeds = dict(zip(placeholders, args))
        output = op(placeholders)
      result = session.run(output, feeds)
      if equality_fn is None:
        equality_fn = self.assertAllClose
      equality_fn(result, expected, rtol=1e-3)

  def _nAryListCheck(self, results, expected, **kwargs):
    """Asserts that two lists of values match element-wise.

    Args:
      results: List of evaluated outputs.
      expected: List of expected values; must have the same length.
      **kwargs: Forwarded to `assertAllClose` for every element (e.g. `rtol`).
    """
    self.assertEqual(len(results), len(expected))
    for result, want in zip(results, expected):
      self.assertAllClose(result, want, **kwargs)

  def _testNAryLists(self, op, args, expected):
    """Like `_testNAry`, but for ops whose output is a list of tensors."""
    self._testNAry(op, args, expected, equality_fn=self._nAryListCheck)

  def testFloat(self):
    """Checks `add_n` on float32 inputs with one, two and three operands."""
    self._testNAry(math_ops.add_n,
                   [np.array([[1, 2, 3]], dtype=np.float32)],
                   expected=np.array([[1, 2, 3]], dtype=np.float32))

    self._testNAry(math_ops.add_n,
                   [np.array([1, 2], dtype=np.float32),
                    np.array([10, 20], dtype=np.float32)],
                   expected=np.array([11, 22], dtype=np.float32))

    self._testNAry(math_ops.add_n,
                   [np.array([-4], dtype=np.float32),
                    np.array([10], dtype=np.float32),
                    np.array([42], dtype=np.float32)],
                   expected=np.array([48], dtype=np.float32))

  def testComplex(self):
    """Checks `add_n` with one, two and three operands per complex dtype."""
    for dtype in self.complex_types:
      self._testNAry(
          math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
          expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))

      self._testNAry(
          math_ops.add_n, [
              np.array([1 + 2j, 2 - 3j], dtype=dtype),
              np.array([10j, 20], dtype=dtype)
          ],
          expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))

      self._testNAry(
          math_ops.add_n, [
              np.array([-4, 5j], dtype=dtype),
              np.array([2 + 10j, -2], dtype=dtype),
              np.array([42j, 3 + 3j], dtype=dtype)
          ],
          expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))

  @unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
  def testIdentityN(self):
    """Checks that `identity_n` returns its inputs, including mixed dtypes."""
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2, 3]], dtype=np.float32)],
                        expected=[np.array([[1, 2, 3]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2], [3, 4]], dtype=np.float32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1, 2], [3, 4]], dtype=np.float32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1], [2], [3], [4]], dtype=np.int32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1], [2], [3], [4]], dtype=np.int32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])

  def testConcat(self):
    """Checks `concat` of two 2x3 matrices along axis 0 and axis 1."""
    self._testNAry(
        lambda x: array_ops.concat(x, 0), [
            np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
            np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
        ],
        expected=np.array(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))

    self._testNAry(
        lambda x: array_ops.concat(x, 1), [
            np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
            np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
        ],
        expected=np.array(
            [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))

  def testOneHot(self):
    """Checks `one_hot` with float/int on-off values and a non-default axis."""
    with self.session() as session, self.test_scope():
      indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.float32(7), off_value=np.float32(3))
      output = session.run(op)
      expected = np.array([[[3, 3, 7, 3], [3, 3, 3, 7]],
                           [[7, 3, 3, 3], [3, 7, 3, 3]]],
                          dtype=np.float32)
      self.assertAllEqual(output, expected)

      # With axis=1 the depth-4 one-hot dimension is inserted second.
      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.int32(2), off_value=np.int32(1),
                             axis=1)
      output = session.run(op)
      expected = np.array([[[1, 1], [1, 1], [2, 1], [1, 2]],
                           [[2, 1], [1, 2], [1, 1], [1, 1]]],
                          dtype=np.int32)
      self.assertAllEqual(output, expected)

  def testSplitV(self):
    """Checks `split` with explicit sizes [2, 2] along axis 1."""
    with self.session() as session:
      with self.test_scope():
        output = session.run(
            array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
                                     dtype=np.float32),
                            [2, 2], 1))
        expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
                    np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
        self.assertAllEqual(output, expected)

  def testSplitVNegativeSizes(self):
    """Checks that `split` rejects a negative split size."""
    with self.session() as session:
      with self.test_scope():
        # The error may surface as a ValueError or as an InvalidArgumentError
        # at run time, so both are accepted. `assertRaisesRegex` replaces the
        # deprecated `assertRaisesRegexp` alias, which Python 3.12 removed.
        with self.assertRaisesRegex(
            (ValueError, errors.InvalidArgumentError),
            "Split size at index 1 must be >= .*. Got: -2"):
          _ = session.run(
              array_ops.split(np.array([1, 2, 3], dtype=np.float32), [-1, -2],
                              axis=0))

  def testStridedSlice(self):
    """Checks `strided_slice` with explicit args and via Python indexing."""
    # Slicing an input that has a zero-sized dimension.
    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[], [], []], dtype=np.float32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[], []], dtype=np.float32))

    # The same slice with int64 indices, only when int64 is in int_types.
    if np.int64 in self.int_types:
      self._testNAry(
          lambda x: array_ops.strided_slice(*x), [
              np.array([[], [], []], dtype=np.float32),
              np.array([1, 0], dtype=np.int64),
              np.array([3, 0], dtype=np.int64),
              np.array([1, 1], dtype=np.int64)
          ],
          expected=np.array([[], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[5, 6], [8, 9]], dtype=np.float32))

    # A negative stride on the second dimension reverses the columns.
    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32)],
                   expected=np.array([[3, 2], [6, 5]], dtype=np.float32))

    # Python indexing syntax, including newaxis and a negative step.
    self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
                                     dtype=np.float32))

    self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[4], [5], [6]], dtype=np.float32))

  def testStridedSliceGrad(self):
    """Checks `strided_slice_grad`.

    The arguments are passed in the order (shape, begin, end, strides, dy).
    In each expected value, `dy` is placed at the sliced positions of a tensor
    of `shape`, and every other position is zero.
    """
    # Tests cases where input shape is empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.float32(0.5)],
                   expected=np.array(np.float32(0.5), dtype=np.float32))

    # Tests case where input shape is non-empty, but gradients are empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([1], dtype=np.int32),
                    np.array([], dtype=np.float32)],
                   expected=np.array([0, 0, 0], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 0], dtype=np.int32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[], []], dtype=np.float32)],
                   expected=np.array([[], [], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[5, 6], [8, 9]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
                                     dtype=np.float32))

    def ssg_test(x):
      # new_axis_mask sets bit 0 and shrink_axis_mask sets bit 2 of the
      # slice spec.
      return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
                                          new_axis_mask=0x1)

    self._testNAry(ssg_test,
                   [np.array([3, 1, 3], dtype=np.int32),
                    np.array([0, 0, 0, 2], dtype=np.int32),
                    np.array([0, 3, 1, -4], dtype=np.int32),
                    np.array([1, 2, 1, -3], dtype=np.int32),
                    np.array([[[1], [2]]], dtype=np.float32)],
                   expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
                                     dtype=np.float32))

    # new_axis_mask=0x15 (0b10101) sets bits 0, 2 and 4 of the slice spec.
    ssg_test2 = lambda x: array_ops.strided_slice_grad(*x, new_axis_mask=0x15)
    self._testNAry(ssg_test2,
                   [np.array([4, 4], dtype=np.int32),
                    np.array([0, 0, 0, 1, 0], dtype=np.int32),
                    np.array([0, 3, 0, 4, 0], dtype=np.int32),
                    np.array([1, 2, 1, 2, 1], dtype=np.int32),
                    np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
                   expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
                                      [0, 0, 0, 0]], dtype=np.float32))

    # Negative stride on the second dimension.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32),
                    np.array([[1, 2], [3, 4]], dtype=np.float32)],
                   expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
                                     dtype=np.float32))

    # Negative strides on both dimensions.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([2, 2], dtype=np.int32),
                    np.array([0, 1], dtype=np.int32),
                    np.array([-1, -2], dtype=np.int32),
                    np.array([[1], [2]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
                                     dtype=np.float32))
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's googletest runner.
  googletest.main()