| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for XLA implementation of tf.linalg.solve.""" |
| |
| import os |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.ops import linalg_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.platform import googletest |
| |
| |
| class MatrixSolveOpTest(xla_test.XLATestCase, parameterized.TestCase): |
| |
| def _verifySolve(self, x, y, adjoint): |
| for np_type in self.float_types & {np.float32, np.float64}: |
| tol = 1e-4 if np_type == np.float32 else 1e-12 |
| a = x.astype(np_type) |
| b = y.astype(np_type) |
| np_ans = np.linalg.solve(np.swapaxes(a, -2, -1) if adjoint else a, b) |
| with self.session() as sess: |
| with self.test_scope(): |
| tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint) |
| out = sess.run(tf_ans) |
| self.assertEqual(tf_ans.shape, out.shape) |
| self.assertEqual(np_ans.shape, out.shape) |
| self.assertAllClose(np_ans, out, atol=tol, rtol=tol) |
| |
| @parameterized.named_parameters( |
| ("Scalar", 1, 1, [], [], False), |
| ("Vector", 5, 1, [], [], False), |
| ("MultipleRHS", 5, 4, [], [], False), |
| ("Adjoint", 5, 4, [], [], True), |
| ("BatchedScalar", 1, 4, [2], [2], False), |
| ("BatchedVector", 5, 4, [2], [2], False), |
| ("BatchedRank2", 5, 4, [7, 4], [7, 4], False), |
| ("BatchedAdjoint", 5, 4, [7, 4], [7, 4], True), |
| ) |
| def testSolve(self, n, nrhs, batch_dims, rhs_batch_dims, adjoint): |
| matrix = np.random.normal(-5.0, 5.0, batch_dims + [n, n]) |
| rhs = np.random.normal(-5.0, 5.0, rhs_batch_dims + [n, nrhs]) |
| self._verifySolve(matrix, rhs, adjoint=adjoint) |
| |
| @parameterized.named_parameters( |
| ("Simple", False), |
| ("Adjoint", True), |
| ) |
| def testConcurrent(self, adjoint): |
| with self.session() as sess: |
| lhs1 = random_ops.random_normal([3, 3], seed=42) |
| lhs2 = random_ops.random_normal([3, 3], seed=42) |
| rhs1 = random_ops.random_normal([3, 3], seed=42) |
| rhs2 = random_ops.random_normal([3, 3], seed=42) |
| with self.test_scope(): |
| s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint) |
| s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint) |
| self.assertAllEqual(*sess.run([s1, s2])) |
| |
| |
| if __name__ == "__main__": |
| os.environ["XLA_FLAGS"] = ("--xla_gpu_enable_cublaslt=true " + |
| os.environ.get("XLA_FLAGS", "")) |
| googletest.main() |