| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tridiagonal solve ops.""" |
| |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors_impl |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gradients as gradient_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops.linalg import linalg_impl |
| from tensorflow.python.platform import test |
| |
# Compact-format diagonals: rows are [superdiagonal, main diagonal,
# subdiagonal]. The last superdiagonal entry and the first subdiagonal entry
# are padding and do not enter the matrix.
_sample_diags = np.array(
    [
        [2, 1, 4, 0],
        [1, 3, 2, 2],
        [0, 1, -1, 1],
    ],
    dtype=np.float32,
)
# Right-hand side b and the exact solution x of A @ x = b for the matrix
# described by `_sample_diags`.
_sample_rhs = np.array([1, 2, 3, 4], dtype=np.float32)
_sample_result = np.array([-9, 5, -4, 4], dtype=np.float32)
| |
| |
def _tfconst(array):
  """Wraps `array` in a float32 TensorFlow constant."""
  return constant_op.constant(value=array, dtype=dtypes.float32)
| |
| |
def _tf_ones(shape):
  """Returns a float64 tensor of ones with the given `shape`."""
  return array_ops.ones(shape=shape, dtype=dtypes.float64)
| |
| |
class TridiagonalSolveOpsTest(xla_test.XLATestCase):
  """Test for tri-diagonal matrix related ops."""

  def _assertSolvesCompactSystem(self, diagonals_np, x_np, rhs_np):
    """Asserts that `x_np` solves the compact-format system A @ x = rhs.

    The product A @ x is computed in NumPy, independently of the op under
    test, and compared against `rhs_np`.

    Args:
      diagonals_np: array of shape (batch, 3, num_dims) whose entries along
        axis 1 are the superdiagonal, main diagonal and subdiagonal. The last
        superdiagonal entry and the first subdiagonal entry are not used.
      x_np: candidate solution of shape (batch, num_dims, num_rhs).
      rhs_np: right-hand sides of shape (batch, num_dims, num_rhs).
    """
    # Trailing axis of size 1 so each diagonal broadcasts over all rhs columns.
    superdiag = diagonals_np[:, 0, :, np.newaxis]
    diag = diagonals_np[:, 1, :, np.newaxis]
    subdiag = diagonals_np[:, 2, :, np.newaxis]

    # Row i of A @ x is
    #   subdiag[i] * x[i - 1] + diag[i] * x[i] + superdiag[i] * x[i + 1],
    # with the out-of-range neighbour terms dropped in the first/last rows.
    # The slicing also handles num_dims == 1, where the old per-row loop
    # indexed past the end.
    y = diag * x_np
    y[:, :-1] += superdiag[:, :-1] * x_np[:, 1:]
    y[:, 1:] += subdiag[:, 1:] * x_np[:, :-1]

    self.assertAllClose(y, rhs_np, rtol=1e-4, atol=1e-4)

  def _testSolvesRandomSystem(self, num_rhs, batch_size=8, num_dims=11):
    """Solves a seeded random batched system and verifies the residual.

    Args:
      num_rhs: number of right-hand-side columns per batch element.
      batch_size: number of independent systems.
      num_dims: size of each tridiagonal matrix.
    """
    # Fixed seed so the random systems are reproducible across runs.
    np.random.seed(19)
    diagonals_np = np.random.normal(
        size=(batch_size, 3, num_dims)).astype(np.float32)
    rhs_np = np.random.normal(
        size=(batch_size, num_dims, num_rhs)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
      x_np = sess.run(
          linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
          feed_dict={
              diags: diagonals_np,
              rhs: rhs_np
          })

    self._assertSolvesCompactSystem(diagonals_np, x_np, rhs_np)

  def testTridiagonalSolverSolves1Rhs(self):
    self._testSolvesRandomSystem(num_rhs=1)

  def testTridiagonalSolverSolvesKRhs(self):
    self._testSolvesRandomSystem(num_rhs=5)

  # All the following is adapted from tridiagonal_solve_op_test.py
  def _test(self,
            diags,
            rhs,
            expected,
            diags_format="compact",
            transpose_rhs=False):
    """Solves the system given by `diags` (in `diags_format`) and checks it.

    Args:
      diags: diagonals (or full matrix) convertible to a float32 constant.
      rhs: right-hand side(s), convertible to a float32 constant.
      expected: expected solution, compared as float32.
      diags_format: passed through to `tridiagonal_solve`.
      transpose_rhs: passed through to `tridiagonal_solve`.
    """
    with self.session() as sess, self.test_scope():
      self.assertAllClose(
          sess.run(
              linalg_impl.tridiagonal_solve(
                  _tfconst(diags),
                  _tfconst(rhs),
                  diags_format,
                  transpose_rhs,
                  conjugate_rhs=False,
                  partial_pivoting=False)),
          np.asarray(expected, dtype=np.float32))

  def _testWithDiagonalLists(self,
                             diags,
                             rhs,
                             expected,
                             diags_format="compact",
                             transpose_rhs=False):
    """Like `_test`, but `diags` is a sequence of separately-given diagonals.

    Each element of `diags` is converted to its own float32 constant before
    the solve. The expected value is compared in NumPy, the same way `_test`
    does it.
    """
    with self.session() as sess, self.test_scope():
      self.assertAllClose(
          sess.run(
              linalg_impl.tridiagonal_solve([_tfconst(x) for x in diags],
                                            _tfconst(rhs),
                                            diags_format,
                                            transpose_rhs,
                                            conjugate_rhs=False,
                                            partial_pivoting=False)),
          np.asarray(expected, dtype=np.float32))

  def testReal(self):
    self._test(diags=_sample_diags, rhs=_sample_rhs, expected=_sample_result)

  # testComplex is skipped as complex type is not yet supported.

  def test3x3(self):
    self._test(
        diags=[[2.0, -1.0, 0.0], [1.0, 3.0, 1.0], [0.0, -1.0, -2.0]],
        rhs=[1.0, 2.0, 3.0],
        expected=[-3.0, 2.0, 7.0])

  def test2x2(self):
    self._test(
        diags=[[2.0, 0.0], [1.0, 3.0], [0.0, 1.0]],
        rhs=[1.0, 4.0],
        expected=[-5.0, 3.0])

  def test1x1(self):
    self._test(diags=[[0], [3], [0]], rhs=[6], expected=[2])

  def test0x0(self):
    self._test(
        diags=np.zeros(shape=(3, 0), dtype=np.float32),
        rhs=np.zeros(shape=(0, 1), dtype=np.float32),
        expected=np.zeros(shape=(0, 1), dtype=np.float32))

  def test2x2WithMultipleRhs(self):
    self._test(
        diags=[[2, 0], [1, 3], [0, 1]],
        rhs=[[1, 2, 3], [4, 8, 12]],
        expected=[[-5, -10, -15], [3, 6, 9]])

  def test1x1WithMultipleRhs(self):
    self._test(diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])

  # test1x1NotInvertible is skipped as runtime error not raised for now.

  # test2x2NotInvertible is skipped as runtime error not raised for now.

  @test_util.disable_mlir_bridge("Error messages differ")
  def testPartialPivotingRaises(self):
    np.random.seed(0)
    batch_size = 8
    num_dims = 11
    num_rhs = 5

    diagonals_np = np.random.normal(size=(batch_size, 3,
                                          num_dims)).astype(np.float32)
    rhs_np = np.random.normal(size=(batch_size, num_dims,
                                    num_rhs)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
      # Only the pivoting solve is expected to fail, so keep the assertion
      # scope limited to it rather than to the placeholder setup.
      with self.assertRaisesRegex(
          errors_impl.UnimplementedError,
          "Current implementation does not yet support pivoting."):
        sess.run(
            linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=True),
            feed_dict={
                diags: diagonals_np,
                rhs: rhs_np
            })

  # testCaseRequiringPivotingLastRows is skipped as pivoting is not supported
  # for now.

  # testNotInvertible is skipped as runtime error not raised for now.

  def testDiagonal(self):
    self._test(
        diags=[[0, 0, 0, 0], [1, 2, -1, -2], [0, 0, 0, 0]],
        rhs=[1, 2, 3, 4],
        expected=[1, 1, -3, -2])

  def testUpperTriangular(self):
    self._test(
        diags=[[2, 4, -1, 0], [1, 3, 1, 2], [0, 0, 0, 0]],
        rhs=[1, 6, 4, 4],
        expected=[13, -6, 6, 2])

  def testLowerTriangular(self):
    self._test(
        diags=[[0, 0, 0, 0], [2, -1, 3, 1], [0, 1, 4, 2]],
        rhs=[4, 5, 6, 1],
        expected=[2, -3, 6, -11])

  def testWithTwoRightHandSides(self):
    self._test(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs, 2 * _sample_rhs]),
        expected=np.transpose([_sample_result, 2 * _sample_result]))

  def testBatching(self):
    self._test(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]))

  def testWithTwoBatchingDimensions(self):
    self._test(
        diags=np.array([[_sample_diags, -_sample_diags, _sample_diags],
                        [-_sample_diags, _sample_diags, -_sample_diags]]),
        rhs=np.array([[_sample_rhs, 2 * _sample_rhs, 3 * _sample_rhs],
                      [4 * _sample_rhs, 5 * _sample_rhs, 6 * _sample_rhs]]),
        expected=np.array(
            [[_sample_result, -2 * _sample_result, 3 * _sample_result],
             [-4 * _sample_result, 5 * _sample_result, -6 * _sample_result]]))

  def testBatchingAndTwoRightHandSides(self):
    rhs = np.transpose([_sample_rhs, 2 * _sample_rhs])
    expected_result = np.transpose([_sample_result, 2 * _sample_result])
    self._test(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([rhs, 2 * rhs]),
        expected=np.array([expected_result, -2 * expected_result]))

  def testSequenceFormat(self):
    self._testWithDiagonalLists(
        diags=[[2, 1, 4], [1, 3, 2, 2], [1, -1, 1]],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="sequence")

  def testSequenceFormatWithDummyElements(self):
    dummy = 20  # Should be ignored by the solver.
    self._testWithDiagonalLists(
        diags=[
            [2, 1, 4, dummy],
            [1, 3, 2, 2],
            [dummy, 1, -1, 1],
        ],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="sequence")

  def testSequenceFormatWithBatching(self):
    self._testWithDiagonalLists(
        diags=[[[2, 1, 4], [-2, -1, -4]], [[1, 3, 2, 2], [-1, -3, -2, -2]],
               [[1, -1, 1], [-1, 1, -1]]],
        rhs=[[1, 2, 3, 4], [1, 2, 3, 4]],
        expected=[[-9, 5, -4, 4], [9, -5, 4, -4]],
        diags_format="sequence")

  def testMatrixFormat(self):
    self._test(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="matrix")

  def testMatrixFormatWithMultipleRightHandSides(self):
    self._test(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[[1, -1], [2, -2], [3, -3], [4, -4]],
        expected=[[-9, 9], [5, -5], [-4, 4], [4, -4]],
        diags_format="matrix")

  def testMatrixFormatWithBatching(self):
    self._test(
        diags=[[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
               [[-1, -2, 0, 0], [-1, -3, -1, 0], [0, 1, -2, -4],
                [0, 0, -1, -2]]],
        rhs=[[1, 2, 3, 4], [1, 2, 3, 4]],
        expected=[[-9, 5, -4, 4], [9, -5, 4, -4]],
        diags_format="matrix")

  def testRightHandSideAsColumn(self):
    self._test(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs]),
        expected=np.transpose([_sample_result]),
        diags_format="compact")

  def testTransposeRhs(self):
    self._test(
        diags=_sample_diags,
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, 2 * _sample_result]).T,
        transpose_rhs=True)

  # testConjugateRhs is skipped as complex type is not yet supported.

  # testAdjointRhs is skipped as complex type is not yet supported.

  def testTransposeRhsWithRhsAsVector(self):
    self._test(
        diags=_sample_diags,
        rhs=_sample_rhs,
        expected=_sample_result,
        transpose_rhs=True)

  # testConjugateRhsWithRhsAsVector is skipped as complex type is not yet
  # supported.

  def testTransposeRhsWithRhsAsVectorAndBatching(self):
    self._test(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]),
        transpose_rhs=True)

  # Gradient tests

  def _gradientTest(
      self,
      diags,
      rhs,
      y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
      expected_grad_diags,  # expected gradient of output w.r.t. diags
      expected_grad_rhs,  # expected gradient of output w.r.t. rhs
      diags_format="compact",
      transpose_rhs=False,
      feed_dict=None):
    """Checks gradients of sum(y * solve(diags, rhs)) w.r.t. diags and rhs."""
    expected_grad_diags = np.array(expected_grad_diags).astype(np.float32)
    expected_grad_rhs = np.array(expected_grad_rhs).astype(np.float32)
    with self.session() as sess, self.test_scope():
      diags = _tfconst(diags)
      rhs = _tfconst(rhs)
      y = _tfconst(y)

      x = linalg_impl.tridiagonal_solve(
          diags,
          rhs,
          diagonals_format=diags_format,
          transpose_rhs=transpose_rhs,
          conjugate_rhs=False,
          partial_pivoting=False)

      res = math_ops.reduce_sum(x * y)
      # Build both gradients at once and fetch them in a single run.
      grad_diags, grad_rhs = gradient_ops.gradients(res, [diags, rhs])
      actual_grad_diags, actual_grad_rhs = sess.run(
          [grad_diags, grad_rhs], feed_dict=feed_dict)
      self.assertAllClose(expected_grad_diags, actual_grad_diags)
      self.assertAllClose(expected_grad_rhs, actual_grad_rhs)

  def testGradientSimple(self):
    self._gradientTest(
        diags=_sample_diags,
        rhs=_sample_rhs,
        y=[1, 3, 2, 4],
        expected_grad_diags=[[-5, 0, 4, 0], [9, 0, -4, -16], [0, 0, 5, 16]],
        expected_grad_rhs=[1, 0, -1, 4])

  def testGradientWithMultipleRhs(self):
    self._gradientTest(
        diags=_sample_diags,
        rhs=[[1, 2], [2, 4], [3, 6], [4, 8]],
        y=[[1, 5], [2, 6], [3, 7], [4, 8]],
        expected_grad_diags=[[-20, 28, -60, 0], [36, -35, 60, 80],
                             [0, 63, -75, -80]],
        expected_grad_rhs=[[0, 2], [1, 3], [1, 7], [0, -10]])

  def _makeDataForGradientWithBatching(self):
    """Builds batched inputs/expected gradients by scaling the simple case."""
    y = np.array([1, 3, 2, 4]).astype(np.float32)
    grad_diags = np.array([[-5, 0, 4, 0], [9, 0, -4, -16],
                           [0, 0, 5, 16]]).astype(np.float32)
    grad_rhs = np.array([1, 0, -1, 4]).astype(np.float32)

    diags_batched = np.array(
        [[_sample_diags, 2 * _sample_diags, 3 * _sample_diags],
         [4 * _sample_diags, 5 * _sample_diags,
          6 * _sample_diags]]).astype(np.float32)
    rhs_batched = np.array([[_sample_rhs, -_sample_rhs, _sample_rhs],
                            [-_sample_rhs, _sample_rhs,
                             -_sample_rhs]]).astype(np.float32)
    y_batched = np.array([[y, y, y], [y, y, y]]).astype(np.float32)
    expected_grad_diags_batched = np.array(
        [[grad_diags, -grad_diags / 4, grad_diags / 9],
         [-grad_diags / 16, grad_diags / 25,
          -grad_diags / 36]]).astype(np.float32)
    expected_grad_rhs_batched = np.array(
        [[grad_rhs, grad_rhs / 2, grad_rhs / 3],
         [grad_rhs / 4, grad_rhs / 5, grad_rhs / 6]]).astype(np.float32)

    return (y_batched, diags_batched, rhs_batched, expected_grad_diags_batched,
            expected_grad_rhs_batched)

  def testGradientWithBatchDims(self):
    y, diags, rhs, expected_grad_diags, expected_grad_rhs = (
        self._makeDataForGradientWithBatching())

    self._gradientTest(
        diags=diags,
        rhs=rhs,
        y=y,
        expected_grad_diags=expected_grad_diags,
        expected_grad_rhs=expected_grad_rhs)

  # testGradientWithUnknownShapes is skipped as shapes should be fully known.

  def _assertRaises(self, diags, rhs, diags_format="compact"):
    """Asserts that building the solve op raises ValueError."""
    with self.assertRaises(ValueError):
      linalg_impl.tridiagonal_solve(diags, rhs, diags_format)

  # Invalid input shapes
  def testInvalidShapesCompactFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")

    test_raises((5, 4, 4), (5, 4))
    test_raises((5, 3, 4), (4, 5))
    # `(5,)` rather than `(5)`: the latter is just the int 5, not a shape tuple.
    test_raises((5, 3, 4), (5,))
    test_raises((5,), (5, 4))

  def testInvalidShapesSequenceFormat(self):

    def test_raises(diags_tuple_shapes, rhs_shape):
      diagonals = tuple(_tf_ones(shape) for shape in diags_tuple_shapes)
      self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")

    test_raises(((5, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 3), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 6), (5, 4), (5, 3)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (3, 4))

  def testInvalidShapesMatrixFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")

    test_raises((5, 4, 7), (5, 4))
    test_raises((5, 4, 4), (3, 4))
    test_raises((5, 4, 4), (5, 3))

  # Tests involving placeholder with an unknown dimension are all skipped.
  # Dimensions have to be all known statically.
| |
| |
if __name__ == "__main__":
  # Entry point when the module is executed directly: hand control to
  # TensorFlow's test runner.
  test.main()