# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tridiagonal_matmul_ops."""
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops as random
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import googletest
class TridiagonalMatMulOpsTest(xla_test.XLATestCase):
  """XLA tests for tridiagonal_matmul across all diagonal formats.

  Each test compares the XLA-jitted op against the non-jitted op (or a
  dense-matmul reference) for the 'sequence', 'compact' and 'matrix'
  diagonal formats, on single and batched inputs.
  """

  @def_function.function(jit_compile=True)
  def _jit_tridiagonal_matmul(self, diagonals, rhs, diagonals_format):
    """XLA-compiled wrapper around tridiagonal_matmul."""
    return linalg_impl.tridiagonal_matmul(
        diagonals, rhs, diagonals_format=diagonals_format)

  @def_function.function(jit_compile=False)
  def _tridiagonal_matmul(self, diagonals, rhs, diagonals_format):
    """Non-jitted reference wrapper around tridiagonal_matmul."""
    return linalg_impl.tridiagonal_matmul(
        diagonals, rhs, diagonals_format=diagonals_format)

  def _testAllFormats(self,
                      superdiag,
                      maindiag,
                      subdiag,
                      rhs,
                      dtype=dtypes.float64):
    """Checks the jitted op against the non-jitted op in every format.

    Args:
      superdiag: list of length n-1, the superdiagonal entries.
      maindiag: list of length n, the main-diagonal entries.
      subdiag: list of length n-1, the subdiagonal entries.
      rhs: n-row right-hand side to multiply against.
      dtype: TensorFlow dtype used for all operands.
    """
    # The compact layout pads the super/sub diagonals to length n: the
    # superdiagonal gets a trailing zero, the subdiagonal a leading zero.
    padded_super = np.pad(superdiag, [0, 1], 'constant')
    padded_sub = np.pad(subdiag, [1, 0], 'constant')
    compact = np.stack([padded_super, maindiag, padded_sub])
    dense = (np.diag(superdiag, 1) + np.diag(maindiag, 0) +
             np.diag(subdiag, -1))
    with self.test_scope():
      sequence_op = tuple(
          constant_op.constant(d, dtype)
          for d in (padded_super, maindiag, padded_sub))
      compact_op = constant_op.constant(compact, dtype)
      dense_op = constant_op.constant(dense, dtype)
      rhs_op = constant_op.constant(rhs, dtype)
      # Batch of two: the second element has both operands doubled.
      rhs_batch = array_ops_stack.stack([rhs_op, 2 * rhs_op])
      compact_batch = array_ops_stack.stack([compact_op, 2 * compact_op])
      dense_batch = array_ops_stack.stack([dense_op, 2 * dense_op])
      sequence_batch = [
          array_ops_stack.stack([d, 2 * d]) for d in sequence_op
      ]

      expected = self._tridiagonal_matmul(
          sequence_op, rhs_op, diagonals_format='sequence')
      # Doubling both the diagonals and the rhs scales the product by 4.
      expected_batch = np.stack([expected, 4 * expected])

      single_cases = (('sequence', sequence_op), ('compact', compact_op),
                      ('matrix', dense_op))
      for fmt, diags in single_cases:
        self.assertAllClose(
            self._jit_tridiagonal_matmul(diags, rhs_op, diagonals_format=fmt),
            expected)

      batch_cases = (('sequence', sequence_batch), ('compact', compact_batch),
                     ('matrix', dense_batch))
      for fmt, diags in batch_cases:
        self.assertAllClose(
            self._jit_tridiagonal_matmul(
                diags, rhs_batch, diagonals_format=fmt), expected_batch)

  def _makeTridiagonalMatrix(self, superdiag, maindiag, subdiag):
    """Assembles a dense (batched) tridiagonal matrix from its diagonals."""
    super_part = array_ops.pad(
        array_ops.matrix_diag(superdiag), [[0, 0], [0, 1], [1, 0]])
    sub_part = array_ops.pad(
        array_ops.matrix_diag(subdiag), [[0, 0], [1, 0], [0, 1]])
    return super_part + array_ops.matrix_diag(maindiag) + sub_part

  def _gradientTest(self, diags, rhs, dtype=dtypes.float64):
    """Compares tridiagonal_matmul gradients against a dense-matmul reference.

    Args:
      diags: compact-format diagonals, shape [..., 3, n].
      rhs: right-hand side, shape [..., n, k].
      dtype: TensorFlow dtype used for both operands.
    """

    def dense_reference(diags, rhs):
      # Strip the compact-layout padding entries before densifying.
      matrix = self._makeTridiagonalMatrix(diags[..., 0, :-1],
                                           diags[..., 1, :],
                                           diags[..., 2, 1:])
      return math_ops.matmul(matrix, rhs)

    with self.session(), self.test_scope():
      diags = constant_op.constant(diags, dtype=dtype)
      rhs = constant_op.constant(rhs, dtype=dtype)
      grad_reference, _ = gradient_checker_v2.compute_gradient(
          dense_reference, [diags, rhs])
      grad_theoretical, grad_numerical = gradient_checker_v2.compute_gradient(
          linalg_impl.tridiagonal_matmul, [diags, rhs])
      self.assertAllClose(grad_theoretical, grad_numerical)
      self.assertAllClose(grad_theoretical, grad_reference)

  def test2x2(self):
    self._testAllFormats(
        superdiag=[1], maindiag=[2, 3], subdiag=[4], rhs=[[2, 1], [4, 3]])

  def test3x3(self):
    for dtype in (dtypes.float32, dtypes.float64):
      self._testAllFormats(
          superdiag=[1, 2],
          maindiag=[1, 2, 1],
          subdiag=[2, 1],
          rhs=[[1, 1], [2, 2], [3, 3]],
          dtype=dtype)

  def testNDimensional(self):
    """Checks jit vs. non-jit agreement on random 5-D batched inputs."""
    with self.test_scope():
      shape = [77, 10, 3, 5, 7]
      diag_shape = shape[:-1]
      superdiag = random.stateless_random_uniform(
          shape=diag_shape, dtype=dtypes.float32, seed=[5, 10])
      maindiag = random.stateless_random_uniform(
          shape=diag_shape, dtype=dtypes.float32, seed=[5, 11])
      subdiag = random.stateless_random_uniform(
          shape=diag_shape, dtype=dtypes.float32, seed=[5, 12])
      rhs = random.stateless_random_uniform(
          shape=shape, dtype=dtypes.float32, seed=[5, 13])
      diags = (superdiag, maindiag, subdiag)
      expected = self._tridiagonal_matmul(
          diags, rhs, diagonals_format='sequence')
      real = self._jit_tridiagonal_matmul(
          diags, rhs, diagonals_format='sequence')
      self.assertAllClose(expected, real)

  def testGradientSmall(self):
    self._gradientTest([[[1, 2, 0], [1, 2, 3], [0, 1, 2]]],
                       [[[1, 2], [3, 4], [5, 6]]],
                       dtype=dtypes.float64)

  # testComplex is skipped as complex type is not yet supported.
# Eager mode is required so the `def_function.function` test helpers are
# traced and executed on call rather than building a v1 graph.
if __name__ == '__main__':
  ops.enable_eager_execution()
  googletest.main()