blob: 8cc8a61353e29b2dc4657ab690d452410aee1426 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.MatrixTriangularSolve."""
import itertools
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test
def MakePlaceholder(x, dtype=None):
  """Creates a placeholder with the same shape as `x`.

  Args:
    x: Array-like value exposing `.dtype` and `.shape` (e.g. a NumPy array).
    dtype: Optional dtype for the placeholder. When omitted, the dtype is
      derived from `x.dtype` via `dtypes.as_dtype`.

  Returns:
    A placeholder tensor of shape `x.shape`.
  """
  if dtype is None:
    dtype = dtypes.as_dtype(x.dtype)
  return array_ops.placeholder(dtype, shape=x.shape)
class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
  """Tests `linalg_ops.matrix_triangular_solve` inside the XLA test scope.

  Solutions are checked indirectly. The computed `x` is multiplied back by the
  triangular part of the coefficient matrix (adjointed when requested), and
  the product is compared against the right-hand side `b`.
  """

  # MatrixTriangularSolve defined for float64, float32, complex64, complex128
  # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
  @property
  def float_types(self):
    """Returns the backend's float types restricted to those the op supports."""
    return set(super(MatrixTriangularSolveOpTest,
                     self).float_types).intersection(
                         (np.float64, np.float32, np.complex64, np.complex128))

  def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
                                 placeholder_b, a, clean_a, b, verification,
                                 atol):
    """Evaluates `verification` and asserts that it reproduces `b`.

    Args:
      sess: Session used to evaluate `verification`.
      placeholder_a: Placeholder fed with the raw coefficient matrix `a`.
      placeholder_ca: Placeholder fed with `clean_a`.
      placeholder_b: Placeholder fed with the right-hand side `b`.
      a: Raw coefficient matrix, possibly with entries outside the triangle.
      clean_a: `a` with the triangle the solver should ignore zeroed out.
      b: Right-hand side.
      verification: Tensor multiplying `clean_a` (adjointed if requested) by
        the solver's output.
      atol: Absolute tolerance for the comparison.
    """
    feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
    verification_np = sess.run(verification, feed_dict)
    # Give `b` the batch dimensions of `a`. The addition lets NumPy broadcast
    # them against any batch dimensions `b` already has.
    broadcasted_shape = a.shape[:-2] + (b.shape[-2], b.shape[-1])
    broadcasted_b = b + np.zeros(shape=broadcasted_shape, dtype=b.dtype)
    self.assertAllClose(broadcasted_b, verification_np, atol=atol)

  def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol, dtype=None):
    """Solves with one (lower, adjoint) setting and verifies the result.

    Args:
      a: Coefficient matrix; leading batch dimensions are allowed.
      b: Right-hand side.
      lower: If True use the lower triangle of `a`, otherwise the upper one.
      adjoint: Whether to solve with the adjoint of the triangular matrix.
      atol: Absolute tolerance for the final comparison.
      dtype: Optional placeholder dtype; defaults to each input's own dtype.
    """
    # Zero out the triangle the solver must ignore, so that the check
    # multiplies by exactly the matrix the solver is meant to use.
    clean_a = np.tril(a) if lower else np.triu(a)
    with self.session() as sess:
      placeholder_a = MakePlaceholder(a, dtype)
      placeholder_ca = MakePlaceholder(clean_a, dtype)
      placeholder_b = MakePlaceholder(b, dtype)
      with self.test_scope():
        x = linalg_ops.matrix_triangular_solve(
            placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
      # The check matmul avoids TF32 so reduced matmul precision does not
      # inflate the observed error.
      verification = test_util.matmul_without_tf32(
          placeholder_ca, x, adjoint_a=adjoint)
      self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
                                      placeholder_b, a, clean_a, b,
                                      verification, atol)

  def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4, dtype=None):
    """Runs `_VerifyTriangularSolve` for all four (lower, adjoint) settings.

    Lower solves use `a` as given. Upper solves use `a` with its last two axes
    swapped, which turns a lower-triangular `a` into an upper-triangular one.
    """
    for lower, adjoint in itertools.product([True, False], repeat=2):
      coefficients = a if lower else np.swapaxes(a, -1, -2)
      self._VerifyTriangularSolve(
          coefficients, b, lower, adjoint, atol, dtype=dtype)

  def _VerifyTriangularSolveBatches(self, shapes):
    """Checks each `(a_shape, b_shape)` pair in `shapes` for every float type.

    Each coefficient matrix is the identity plus a small random
    lower-triangular perturbation, which keeps it well conditioned.
    """
    rng = np.random.RandomState(0)
    for dtype, (a_shape, b_shape) in itertools.product(self.float_types,
                                                       shapes):
      n = a_shape[-1]
      a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
      b = rng.randn(*b_shape)
      self._VerifyTriangularSolveCombo(
          a.astype(dtype), b.astype(dtype), atol=1e-3)

  def testBasic(self):
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5))
    b = rng.randn(5, 7)
    for dtype in self.float_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBfloat16(self):
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5))
    b = rng.randn(5, 7)
    self._VerifyTriangularSolveCombo(a, b, atol=5e-2, dtype=dtypes.bfloat16)

  def testBasicNotActuallyTriangular(self):
    rng = np.random.RandomState(0)
    a = rng.randn(5, 5)  # the `a` matrix is not lower-triangular
    b = rng.randn(5, 7)
    for dtype in self.float_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBasicComplexDtypes(self):
    if test.is_built_with_rocm():
      # The following subtest invokes "BlasTrsm", which is currently not
      # supported for complex types on the ROCm platform.
      self.skipTest("BlasTrsm op for complex types is not supported in ROCm")
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
    b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
    for dtype in self.complex_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBatch(self):
    self._VerifyTriangularSolveBatches(
        [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)),
         ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))])

  def testBatchBroadcast(self):
    self._VerifyTriangularSolveBatches(
        [((3, 3), (4, 3, 5)), ((1, 2, 2), (3, 2, 1)), ((1, 1), (1, 1, 2)),
         ((1, 3, 4, 4), (2, 1, 4, 1))])

  def testLarge(self):
    n = 1024
    rng = np.random.RandomState(0)
    a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n)
    b = rng.randn(n, n)
    self._VerifyTriangularSolve(
        a.astype(np.float32), b.astype(np.float32),
        lower=True, adjoint=False, atol=1e-4)

  @test_util.disable_mlir_bridge("Error handling")
  def testNonSquareCoefficientMatrix(self):
    rng = np.random.RandomState(0)
    for dtype in self.float_types:
      a = rng.randn(3, 4).astype(dtype)
      b = rng.randn(4, 4).astype(dtype)
      with self.test_scope():
        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
          linalg_ops.matrix_triangular_solve(a, b)

  @test_util.run_v2_only  # Different error types
  @test_util.disable_mlir_bridge("Error handling")
  def testWrongDimensionsV2(self):
    """Under TF2 behavior, a row-count mismatch raises InvalidArgumentError."""
    randn = np.random.RandomState(0).randn
    for dtype in self.float_types:
      lhs = constant_op.constant(randn(3, 3), dtype=dtype)
      rhs = constant_op.constant(randn(4, 3), dtype=dtype)
      with self.assertRaises(errors.InvalidArgumentError):
        linalg_ops.matrix_triangular_solve(lhs, rhs)
      # `lhs` is square, so the upper/adjoint variant has the same row
      # mismatch and must be rejected too.
      with self.assertRaises(errors.InvalidArgumentError):
        linalg_ops.matrix_triangular_solve(
            lhs, rhs, lower=False, adjoint=True)

  @test_util.run_v1_only("Different error types")
  @test_util.disable_mlir_bridge("Error handling")
  def testWrongDimensionsV1(self):
    """Under TF1 behavior, a row-count mismatch raises ValueError."""
    randn = np.random.RandomState(0).randn
    for dtype in self.float_types:
      lhs = constant_op.constant(randn(3, 3), dtype=dtype)
      rhs = constant_op.constant(randn(4, 3), dtype=dtype)
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(lhs, rhs)
      # `lhs` is square, so the upper/adjoint variant has the same row
      # mismatch and must be rejected too.
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(
            lhs, rhs, lower=False, adjoint=True)
if __name__ == "__main__":
  # Script entry point: hand off to TensorFlow's test runner for this module.
  test.main()