blob: adb89e50cd5f1d56286573397a55bc1efa4bac23 [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.ops.special_math_ops."""
from absl.testing import parameterized
import numpy as np
import opt_einsum
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
class LBetaTest(test.TestCase):
  """Tests for `special_math_ops.lbeta`.

  `lbeta` reduces over the last axis of its argument: the result has the
  argument's shape with that axis removed, and an empty last axis yields -inf.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_one_dimensional_arg(self):
    # Should evaluate to 1 and 1/2.
    x_one = [1, 1.]
    x_one_half = [2, 1.]
    with self.session():
      self.assertAllClose(
          1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one))))
      self.assertAllClose(
          0.5, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
      self.assertEqual([], special_math_ops.lbeta(x_one).get_shape())

  @test_util.run_deprecated_v1
  def test_one_dimensional_arg_dynamic(self):
    # Should evaluate to 1 and 1/2.
    x_one = [1, 1.]
    x_one_half = [2, 1.]
    with self.session():
      ph = array_ops.placeholder(dtypes.float32)
      beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
      self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one}))
      self.assertAllClose(0.5,
                          beta_ph.eval(feed_dict={ph: x_one_half}))

  @test_util.run_deprecated_v1
  def test_four_dimensional_arg_with_partial_shape_dynamic(self):
    x_ = np.ones((3, 2, 3, 4))
    # Gamma(1) = 0! = 1
    # Gamma(1 + 1 + 1 + 1) = Gamma(4) = 3! = 6
    # ==> Beta([1, 1, 1, 1])
    #     = Gamma(1) * Gamma(1) * Gamma(1) * Gamma(1) / Gamma(1 + 1 + 1 + 1)
    #     = 1 / 6
    expected_beta_x = 1 / 6 * np.ones((3, 2, 3))
    with self.session():
      x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None])
      beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph))
      self.assertAllClose(expected_beta_x,
                          beta_ph.eval(feed_dict={x_ph: x_}))

  @test_util.run_in_graph_and_eager_modes
  def test_two_dimensional_arg(self):
    # Should evaluate to 1/2.
    x_one_half = [[2, 1.], [2, 1.]]
    with self.session():
      self.assertAllClose(
          [0.5, 0.5],
          self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
      self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape())

  @test_util.run_deprecated_v1
  def test_two_dimensional_arg_dynamic(self):
    # Should evaluate to 1/2.
    x_one_half = [[2, 1.], [2, 1.]]
    with self.session():
      ph = array_ops.placeholder(dtypes.float32)
      beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
      self.assertAllClose([0.5, 0.5],
                          beta_ph.eval(feed_dict={ph: x_one_half}))

  @test_util.run_in_graph_and_eager_modes
  def test_two_dimensional_proper_shape(self):
    # Should evaluate to 1/2.
    x_one_half = [[2, 1.], [2, 1.]]
    with self.session():
      self.assertAllClose(
          [0.5, 0.5],
          self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
      # The evaluated shape is an array, so compare it elementwise (as
      # test_complicated_shape does). assertEqual would depend on the
      # truthiness of a NumPy comparison result, which is ambiguous for more
      # than one element.
      self.assertAllEqual(
          (2,),
          self.evaluate(array_ops.shape(special_math_ops.lbeta(x_one_half))))
      self.assertEqual(
          tensor_shape.TensorShape([2]),
          special_math_ops.lbeta(x_one_half).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_complicated_shape(self):
    with self.session():
      x = ops.convert_to_tensor(np.random.rand(3, 2, 2))
      self.assertAllEqual(
          (3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x))))
      self.assertEqual(
          tensor_shape.TensorShape([3, 2]),
          special_math_ops.lbeta(x).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_length_1_last_dimension_results_in_one(self):
    # If there is only one coefficient, the formula still works, and we get one
    # as the answer, always.
    x_a = [5.5]
    x_b = [0.1]
    with self.session():
      self.assertAllClose(
          1,
          self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a))),
          rtol=3e-6)
      self.assertAllClose(
          1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_b))))
      self.assertEqual((), special_math_ops.lbeta(x_a).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_empty_rank1_returns_negative_infinity(self):
    with self.session():
      x = constant_op.constant([], shape=[0])
      lbeta_x = special_math_ops.lbeta(x)
      expected_result = constant_op.constant(-np.inf, shape=())
      self.assertAllEqual(self.evaluate(expected_result),
                          self.evaluate(lbeta_x))
      self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self):
    with self.session():
      event_size = 0
      for batch_size in [0, 1, 2]:
        x = constant_op.constant([], shape=[batch_size, event_size])
        lbeta_x = special_math_ops.lbeta(x)
        expected_result = constant_op.constant(-np.inf, shape=[batch_size])
        self.assertAllEqual(self.evaluate(expected_result),
                            self.evaluate(lbeta_x))
        self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_empty_rank2_with_zero_batch_dim_returns_empty(self):
    with self.session():
      batch_size = 0
      for event_size in [0, 1, 2]:
        x = constant_op.constant([], shape=[batch_size, event_size])
        lbeta_x = special_math_ops.lbeta(x)
        expected_result = constant_op.constant([], shape=[batch_size])
        self.assertAllEqual(self.evaluate(expected_result),
                            self.evaluate(lbeta_x))
        self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
@test_util.run_all_in_graph_and_eager_modes
class DawsnTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.dawsn`, checked against `scipy.special`.

  The scipy comparisons are skipped (with a warning) when SciPy is missing.
  Only the import is guarded, so an ImportError raised by the code under test
  is not mistaken for a missing SciPy.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_dawsn_boundary(self):
    self.assertAllClose(0., special_math_ops.dawsn(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.dawsn(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_dawsn_odd(self, dtype):
    # dawsn(x) == -dawsn(-x).
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.dawsn(x)),
        self.evaluate(-special_math_ops.dawsn(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_dawsn_small(self, dtype):
    x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.dawsn(x), self.evaluate(special_math_ops.dawsn(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_dawsn_larger(self, dtype):
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.dawsn(x), self.evaluate(special_math_ops.dawsn(x)))

  def test_dawsn_gradient(self):
    inputs = [np.random.uniform(-50., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.dawsn, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
@test_util.run_all_in_graph_and_eager_modes
class ExpintTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.expint`, checked against `scipy.special.expi`.

  The scipy comparisons are skipped (with a warning) when SciPy is missing.
  Only the import is guarded, so an ImportError raised by the code under test
  still surfaces.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_expint_boundary(self):
    self.assertAllClose(-np.inf, special_math_ops.expint(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.expint(np.nan))))
    # Check that the domain of definition is [0, inf)
    self.assertTrue(
        np.all(
            np.isnan(
                self.evaluate(
                    special_math_ops.expint(
                        np.random.uniform(-20., -1., size=int(1e3)))))))

  @parameterized.parameters(np.float32, np.float64)
  def test_expint_small(self, dtype):
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.expi(x), self.evaluate(special_math_ops.expint(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_expint_larger(self, dtype):
    x = np.random.uniform(1., 50., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.expi(x), self.evaluate(special_math_ops.expint(x)))

  def test_expint_gradient(self):
    inputs = [np.random.uniform(1., 10., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.expint, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 5e-3)
@test_util.run_all_in_graph_and_eager_modes
class FresnelCosTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.fresnel_cos`.

  Values are compared with `scipy.special.fresnel(x)[1]`. Those comparisons are
  skipped (with a warning) when SciPy is missing, and only the import itself is
  guarded.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_fresnel_cos_boundary(self):
    self.assertAllClose(0., special_math_ops.fresnel_cos(0.))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.fresnel_cos(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_cos_odd(self, dtype):
    # fresnel_cos(x) == -fresnel_cos(-x).
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.fresnel_cos(x)),
        self.evaluate(-special_math_ops.fresnel_cos(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_cos_small(self, dtype):
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.fresnel(x)[1], self.evaluate(special_math_ops.fresnel_cos(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_cos_larger(self, dtype):
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.fresnel(x)[1],
        self.evaluate(special_math_ops.fresnel_cos(x)),
        rtol=1e-5)

  def test_fresnel_cos_gradient(self):
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.fresnel_cos, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 5e-3)
@test_util.run_all_in_graph_and_eager_modes
class FresnelSinTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.fresnel_sin`.

  Values are compared with `scipy.special.fresnel(x)[0]`. Those comparisons are
  skipped (with a warning) when SciPy is missing, and only the import itself is
  guarded.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_fresnel_sin_boundary(self):
    self.assertAllClose(0., special_math_ops.fresnel_sin(0.))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.fresnel_sin(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_sin_odd(self, dtype):
    # fresnel_sin(x) == -fresnel_sin(-x).
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.fresnel_sin(x)),
        self.evaluate(-special_math_ops.fresnel_sin(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_sin_small(self, dtype):
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.fresnel(x)[0], self.evaluate(special_math_ops.fresnel_sin(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_sin_larger(self, dtype):
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.fresnel(x)[0],
        self.evaluate(special_math_ops.fresnel_sin(x)),
        rtol=1e-5)

  def test_fresnel_sin_gradient(self):
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.fresnel_sin, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 5e-3)
@test_util.run_all_in_graph_and_eager_modes
class SpenceTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.spence`, checked against `scipy.special`.

  The scipy comparisons are skipped (with a warning) when SciPy is missing.
  Only the import is guarded, so an ImportError raised by the code under test
  still surfaces.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_spence_boundary(self):
    self.assertAllClose(np.pi**2 / 6., special_math_ops.spence(0.))
    self.assertAllClose(0., special_math_ops.spence(1.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.spence(np.nan))))
    # Check that the domain of definition is [0, inf)
    self.assertTrue(
        np.all(
            np.isnan(
                self.evaluate(
                    special_math_ops.spence(
                        np.random.uniform(-20., -1., size=int(1e3)))))))

  @parameterized.parameters(np.float32, np.float64)
  def test_spence_small(self, dtype):
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.spence(x), self.evaluate(special_math_ops.spence(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_spence_larger(self, dtype):
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.spence(x), self.evaluate(special_math_ops.spence(x)))

  def test_spence_gradient(self):
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.spence, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_spence_gradient_at_one(self):
    analytical, _ = gradient_checker_v2.compute_gradient(
        special_math_ops.spence, [1.])
    self.assertAllClose([[[-1.]]], analytical)
@test_util.run_all_in_graph_and_eager_modes
class BesselTest(test.TestCase, parameterized.TestCase):
  """Tests for the Bessel functions in `special_math_ops`.

  Covers the i0/i1/i0e/i1e, j0/j1, k0/k1/k0e/k1e and y0/y1 variants. The
  scipy comparisons are skipped (with a warning) when SciPy is missing. Only
  the import is guarded, so an ImportError raised by the code under test
  still surfaces.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_besseli_boundary(self):
    self.assertAllClose(1., special_math_ops.bessel_i0(0.))
    self.assertAllClose(1., special_math_ops.bessel_i0e(0.))
    self.assertAllClose(0., special_math_ops.bessel_i1(0.))
    self.assertAllClose(0., special_math_ops.bessel_i1e(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_i0(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_i0e(np.nan))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_i1(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_i1e(np.nan))))

  @test_util.run_in_graph_and_eager_modes
  def test_besselj_boundary(self):
    self.assertAllClose(1., special_math_ops.bessel_j0(0.))
    self.assertAllClose(0., special_math_ops.bessel_j1(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_j0(np.nan))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_j1(np.nan))))

  @test_util.run_in_graph_and_eager_modes
  def test_besselk_boundary(self):
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k0(0.))))
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k0e(0.))))
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k1(0.))))
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k1e(0.))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_k0(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_k0e(np.nan))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_k1(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_k1e(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_i0j0_even(self, dtype):
    # i0, i0e and j0 satisfy f(x) == f(-x).
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i0(x)),
        self.evaluate(special_math_ops.bessel_i0(-x)))
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i0e(x)),
        self.evaluate(special_math_ops.bessel_i0e(-x)))
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_j0(x)),
        self.evaluate(special_math_ops.bessel_j0(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_i1j1_odd(self, dtype):
    # i1, i1e and j1 satisfy f(x) == -f(-x).
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i1(x)),
        self.evaluate(-special_math_ops.bessel_i1(-x)))
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i1e(x)),
        self.evaluate(-special_math_ops.bessel_i1e(-x)))
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_j1(x)),
        self.evaluate(-special_math_ops.bessel_j1(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besseli_small(self, dtype):
    x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.i0(x), self.evaluate(special_math_ops.bessel_i0(x)))
    self.assertAllClose(
        special.i1(x), self.evaluate(special_math_ops.bessel_i1(x)))
    self.assertAllClose(
        special.i0e(x), self.evaluate(special_math_ops.bessel_i0e(x)))
    self.assertAllClose(
        special.i1e(x), self.evaluate(special_math_ops.bessel_i1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselj_small(self, dtype):
    x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.j0(x), self.evaluate(special_math_ops.bessel_j0(x)))
    self.assertAllClose(
        special.j1(x), self.evaluate(special_math_ops.bessel_j1(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselk_small(self, dtype):
    # Start at eps: the k functions are infinite at 0 (see the boundary test).
    x = np.random.uniform(np.finfo(dtype).eps, 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.k0(x), self.evaluate(special_math_ops.bessel_k0(x)))
    self.assertAllClose(
        special.k0e(x), self.evaluate(special_math_ops.bessel_k0e(x)))
    self.assertAllClose(
        special.k1(x), self.evaluate(special_math_ops.bessel_k1(x)))
    self.assertAllClose(
        special.k1e(x), self.evaluate(special_math_ops.bessel_k1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_bessely_small(self, dtype):
    x = np.random.uniform(np.finfo(dtype).eps, 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.y0(x), self.evaluate(special_math_ops.bessel_y0(x)))
    self.assertAllClose(
        special.y1(x), self.evaluate(special_math_ops.bessel_y1(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besseli_larger(self, dtype):
    x = np.random.uniform(1., 20., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.i0e(x), self.evaluate(special_math_ops.bessel_i0e(x)))
    self.assertAllClose(
        special.i1e(x), self.evaluate(special_math_ops.bessel_i1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselj_larger(self, dtype):
    x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.j0(x), self.evaluate(special_math_ops.bessel_j0(x)))
    self.assertAllClose(
        special.j1(x), self.evaluate(special_math_ops.bessel_j1(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselk_larger(self, dtype):
    x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.k0(x), self.evaluate(special_math_ops.bessel_k0(x)))
    self.assertAllClose(
        special.k0e(x), self.evaluate(special_math_ops.bessel_k0e(x)))
    self.assertAllClose(
        special.k1(x), self.evaluate(special_math_ops.bessel_k1(x)))
    self.assertAllClose(
        special.k1e(x), self.evaluate(special_math_ops.bessel_k1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_bessely_larger(self, dtype):
    x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning('Cannot test special functions: %s', e)
      return
    self.assertAllClose(
        special.y0(x), self.evaluate(special_math_ops.bessel_y0(x)))
    self.assertAllClose(
        special.y1(x), self.evaluate(special_math_ops.bessel_y1(x)))

  def test_besseli_gradient(self):
    inputs = [np.random.uniform(-10., 10., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-3)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i0e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-3)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i1e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_besselj_gradient(self):
    inputs = [np.random.uniform(-50., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_j0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_j1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_besselk_gradient(self):
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k0e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k1e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_bessely_gradient(self):
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_y0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_y1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)
@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32(
'Tests einsum, which sometimes does a matmul with cuBLAS')
class EinsumTest(test.TestCase):
def _check(self, s, *input_shapes, dtype=np.float32):
  """Compares `special_math_ops.einsum` with `np.einsum` on random operands.

  Args:
    s: The einsum equation string.
    *input_shapes: One shape tuple per operand in `s`.
    dtype: NumPy dtype of the generated operands. Complex dtypes also get a
      random imaginary part. This is keyword-only, so a misspelled keyword
      raises TypeError instead of being silently ignored.
  """
  r = np.random.RandomState(0)
  is_complex = np.issubdtype(dtype, np.complexfloating)
  inputs = []
  for shape in input_shapes:
    arr = np.array(r.randn(*shape)).astype(dtype)
    if is_complex:
      arr += 1j * np.array(r.randn(*shape)).astype(dtype)
    inputs.append(arr)
  input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
  a = np.einsum(s, *inputs)
  b = self.evaluate(special_math_ops.einsum(s, *input_tensors))
  self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
def test_invalid_keyword_arguments(self):
  """Unrecognized keyword arguments to einsum must raise TypeError."""
  rng = np.random.RandomState(0)
  lhs = array_ops.placeholder_with_default(rng.randn(2, 3), shape=(2, 3))
  rhs = array_ops.placeholder_with_default(rng.randn(3, 4), shape=(3, 4))
  unknown_kwargs = {'invalid1': 'value1', 'invalid2': 'value2'}
  with self.assertRaises(TypeError):
    special_math_ops.einsum(
        'ij,jk->ik', lhs, rhs, name='name', **unknown_kwargs)
def test_unary(self):
  """Single-operand equations: reductions, transposes and diagonals."""
  cases = [
      ('a', (3,)),
      ('aa', (3, 3)),
      ('ab->', (3, 3)),
      ('ab->ab', (3, 3)),
      ('abc->b', (3, 4, 5)),
      ('abc->ca', (3, 4, 5)),
      ('abc->cab', (3, 4, 5)),
      # Empty cases.
      ('', ()),
      ('->', ()),
      # Repeated indices cases.
      ('aa->', (3, 3)),
      ('aa->a', (3, 3)),
      ('aaa->', (3, 3, 3)),
      ('aaa->a', (3, 3, 3)),
      ('aab->a', (3, 3, 4)),
      ('aabcc->a', (3, 3, 5, 4, 4)),
      ('aabcc->ac', (3, 3, 5, 4, 4)),
      ('aabcd->ad', (3, 3, 5, 4, 4)),
  ]
  for equation, shape in cases:
    self._check(equation, shape)
def test_unary_ellipsis(self):
  """Single-operand equations that use `...` for the leading dimensions."""
  cases = [
      ('...->', ()),
      ('...ijk->...ki', (3, 4, 5)),
      ('...ijk->...ki', (1, 3, 4, 5)),
      ('...ijk->...ki', (2, 2, 3, 4, 5)),
      ('...ij->...ji', (5, 2, 3)),  # batch matrix transpose
      ('...ij->...', (5, 2, 3)),  # batch sum
      ('...->...', ()),
      ('->...', ()),
      # Repeated indices.
      ('i...ii->...i', (3, 2, 3, 3)),
      ('i...i->i...', (2, 2)),
      ('i...i->', (2, 2)),
      ('i...i->...', (2, 5, 1, 2)),
      ('i...i->i...', (2, 1, 2)),
      ('i...i->i...', (2, 3, 4, 5, 2)),
  ]
  for equation, shape in cases:
    self._check(equation, shape)
def test_binary_simple(self):
  """Two-operand equations without ellipsis or repeated labels."""
  # Binary cases in XLA mode must have either (a) each index appearing exactly
  # once in both the inputs (batch or contraction index), or (b) appearing
  # exactly once in an input and in the output (free index).
  cases = [
      (',->', (), ()),
      ('a,a->', (3,), (3,)),
      ('a,a->a', (3,), (3,)),
      ('ab,b->a', (3, 4), (4,)),
      ('ab,ab->', (3, 4), (3, 4)),
      ('ab,bc->ac', (3, 4), (4, 5)),
      ('nij,jk->nik', (5, 2, 3), (3, 4)),
      ('abc,bad->abcd', (1, 2, 3), (2, 1, 4)),
      # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
      ('sa,shb->shab', (2, 1), (2, 3, 4)),
      # Infer the output subscripts.
      ('ab,b', (3, 4), (4,)),
      ('cab,b', (1, 3, 4), (4,)),
  ]
  for equation, lhs_shape, rhs_shape in cases:
    self._check(equation, lhs_shape, rhs_shape)
def test_reduced_indices(self):
  """Two-operand equations whose output keeps none of the input labels."""
  cases = (
      ('ba,b->', (3, 2), (3,)),
      ('ab,ab->', (3, 4), (3, 4)),
  )
  for equation, lhs_shape, rhs_shape in cases:
    self._check(equation, lhs_shape, rhs_shape)
def test_repeated_indices(self):
  """Two-operand equations where a label repeats within one operand."""
  cases = (
      ('ijj,k->ik', (2, 3, 3), (4,)),
      ('aba,a->b', (3, 4, 3), (3,)),
      # From https://github.com/dask/dask/pull/3412#discussion_r182413444
      ('aab,bc->ac', (2, 2, 3), (3, 4)),
      ('aab,bcc->ac', (2, 2, 3), (3, 4, 4)),
  )
  for equation, lhs_shape, rhs_shape in cases:
    self._check(equation, lhs_shape, rhs_shape)
def test_binary_ellipsis(self):
  """Two-operand equations using `...` without broadcasting across operands."""
  cases = [
      # Batch matmul with ellipsis but without broadcasting.
      ('...mk,...kn->...mn', (5, 1, 2, 3), (5, 1, 3, 4)),
      # Empty batch dimensions.
      ('...mk,...kn->...mn', (2, 3), (3, 4)),
      # Tensor contraction with transpose.
      ('...ija,aijb...->ba...ij', (1, 2, 2, 3, 1), (1, 2, 3, 4, 1, 2)),
      # Output subscripts may omit ellipsis when batch shape is empty.
      ('...mk,...kn->mn', (2, 3), (3, 4)),
      ('...mk,kn->mn', (2, 3), (3, 4)),
      ('mk,...kn->mn', (2, 3), (3, 4)),
      ('...,...->...', (2, 3), (2, 3)),  # hadamard product
      ('...i,...j->...ij', (5, 2), (5, 3)),  # outer product
  ]
  for equation, lhs_shape, rhs_shape in cases:
    self._check(equation, lhs_shape, rhs_shape)
def test_broadcasting(self):
  """Two-operand equations whose `...` dimensions broadcast against each other."""
  cases = [
      # Batch matmul with broadcasting.
      ('...ij,...jk->...ik', (1, 2, 3), (3, 5)),
      ('...ij,...jk->...ik', (2, 3), (1, 3, 5)),
      ('...ij,...jk->...ik', (5, 2, 3), (3, 5)),
      ('...ij,...jk->...ik', (2, 3), (5, 3, 5)),
      ('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5)),
      ('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5)),
      # Broadcasting with repeated indices.
      ('ij,jk...k->i...', (3, 2), (2, 4, 1, 4)),
      ('ij,jk...k->...i', (3, 2), (2, 4, 5, 4)),
      ('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4)),
      ('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4)),
      # Following 2 from https://stackoverflow.com/a/19203475/1611416
      ('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6)),
      ('ab...,b->ab...', (2, 3, 1, 1, 5), (3,)),
  ]
  for equation, lhs_shape, rhs_shape in cases:
    self._check(equation, lhs_shape, rhs_shape)
def test_dtypes(self):
  """Checks 2x2 matmul-style equations across float and complex dtypes."""
  # Named `np_dtypes` rather than `dtypes` so the local does not shadow the
  # module-level `tensorflow.python.framework.dtypes` import.
  np_dtypes = [np.float64, np.float32, np.complex64, np.complex128]
  equations = ['ij,jk->ik', 'ji,jk->ik', 'ji,kj->ik', 'ij,jk->ki', 'ji,kj->ki']
  for dtype in np_dtypes:
    for equation in equations:
      self._check(equation, (2, 2), (2, 2), dtype=dtype)
def test_multiple_inputs(self):
  """Equations with three or more operands."""
  cases = [
      ('ijk,ijl,ikl->i', (1, 2, 3), (1, 2, 4), (1, 3, 4)),
      ('i,ijk,j->k', (1,), (1, 2, 4), (2,)),
      ('ij,ij,jk,kl->il', (1, 2), (1, 2), (2, 3), (3, 4)),
      # Tests from dask.
      ('a,b,c', (5,), (7,), (9,)),
      ('ab,ab,c->c', (5, 6), (5, 6), (2,)),
  ]
  for equation, *shapes in cases:
    self._check(equation, *shapes)
@test_util.disable_xla('b/131919749')
def test_placeholder(self):
  """Runs einsum on operands whose static shapes are partial or unknown."""

  def check(equation, *input_and_placeholder_shapes):
    # Each argument pairs the concrete array shape with the static shape given
    # to its placeholder (partially known, or None for unknown rank).
    rng = np.random.RandomState(0)
    arrays = [
        np.array(rng.randn(*actual_shape))
        for actual_shape, _ in input_and_placeholder_shapes
    ]
    placeholders = [
        array_ops.placeholder_with_default(arr, static_shape)
        for arr, (_, static_shape) in zip(arrays, input_and_placeholder_shapes)
    ]
    expected = np.einsum(equation, *arrays)
    actual = self.evaluate(special_math_ops.einsum(equation, *placeholders))
    self.assertAllClose(expected, actual, atol=1e-4, rtol=1e-4)

  check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
        ((9, 3, 4, 7), (None, None, 4, None)))
  check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
        ((4, 3), (None, 3)))
  # Ellipsis with unknown rank.
  check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
  check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
def test_numpy_input(self):
  """einsum accepts raw NumPy arrays as operands, not just Tensors."""
  rng = np.random.RandomState(0)
  equation = 'ijk,ijl,ikl->i'
  operands = [rng.randn(*shape) for shape in ((1, 2, 3), (1, 2, 4), (1, 3, 4))]
  expected = np.einsum(equation, *operands)
  actual = self.evaluate(special_math_ops.einsum(equation, *operands))
  self.assertAllClose(expected, actual, atol=1e-4, rtol=1e-4)
def test_long_cases(self):
  """Many-operand equations, including whitespace and repeated equations."""
  cases = [
      'efc,dbc,acf,fd->abe',
      'ea,fb,gc,hd,abcd->efgh',
      'abhe,hidj,jgba,hiab,gab->ed',
      # Cases with whitespace.
      'efc, dbc, acf, fd -> abe',
      'abhe, hidj, jgba, hiab, gab',
      # Repeated equations for cache hit on the opt_einsum call.
      'ea,fb,abcd,gc,hd->efgh',
      'ea,fb,abcd,gc,hd->efgh',
  ]
  # Each label's dimension is its 1-based alphabet position (a=1, b=2, ...).
  dimension_map = {c: ord(c) - ord('a') + 1 for c in 'abcdefghij'}
  for equation in cases:
    lhs = equation.split('->')[0].replace(' ', '')
    shapes = [tuple(dimension_map[c] for c in term) for term in lhs.split(',')]
    self._check(equation, *shapes)
def test_opt_einsum_cached(self):
  """Checks that opt_einsum.contract_path is only called once per distinct input.

  The call counts observed through the patched contract_path are only
  reflected in eager mode, so when not executing eagerly the test returns
  early.
  """
  if not context.executing_eagerly():
    return
  first_case = ('ijk,ijl,ikl->i', (1, 2, 3), (1, 2, 4), (1, 3, 4))
  second_case = ('ij,ij,jk,kl->il', (1, 2), (1, 2), (2, 3), (3, 4))
  with test.mock.patch.object(
      opt_einsum, 'contract_path',
      wraps=opt_einsum.contract_path) as contract_path_spy:
    # Other tests in this file invoke
    # special_math_ops._get_opt_einsum_contract_path with these same
    # arguments. If they ran first, the lru_cache would already hold the
    # entries and the spy's call_count would never increase, so start from
    # an empty cache.
    special_math_ops._get_opt_einsum_contract_path.cache_clear()
    self.assertEqual(contract_path_spy.call_count, 0)

    # First time an input is seen: one real contract_path call.
    self._check(*first_case)
    self.assertEqual(contract_path_spy.call_count, 1)

    # Same input again: served from the cache, no extra call.
    self._check(*first_case)
    self.assertEqual(contract_path_spy.call_count, 1)

    # A new input: exactly one more call.
    self._check(*second_case)
    self.assertEqual(contract_path_spy.call_count, 2)

    # Any mix of already-seen inputs adds no further calls.
    for case in (first_case, second_case, first_case):
      self._check(*case)
    self.assertEqual(contract_path_spy.call_count, 2)
@test_util.disable_xla('b/131919749')
def test_long_cases_with_repeated_labels(self):
  """Checks long equations in which a label repeats within one operand.

  Shapes come from each label's position in the alphabet ('a'->1, ...,
  'j'->10), using only the input side of the equation.
  """
  cases = [
      # Tests from dask.
      'fdf,cdd,ccd,afe->ae',
      'fff,fae,bef,def->abd',
  ]
  dimension_map = {c: ord(c) - ord('a') + 1 for c in 'abcdefghij'}
  for equation in cases:
    inputs = equation.split('->')[0].replace(' ', '')
    input_shapes = [
        tuple(dimension_map[c] for c in input_str)
        for input_str in inputs.split(',')
    ]
    self._check(equation, *input_shapes)
@test_util.disable_xla('b/131919749')
@test_util.run_in_graph_and_eager_modes
def test_invalid_equation(self):
  """Checks that malformed equations or incompatible operands are rejected.

  Each case is tried twice: directly with the numpy operands, and again with
  the operands wrapped in placeholders created with shape=None. Either a
  ValueError or an InvalidArgumentError counts as a rejection.
  """
  rng = np.random.RandomState(0)
  bad_cases = [
      # invalid equation format.
      ('a0->a', rng.randn(5, 3)),
      ('a->a,a', rng.randn(5)),
      ('a->a->a', rng.randn(5)),
      ('ijk ijk', rng.randn(1, 2, 3), rng.randn(1, 2, 3)),
      ('ij.jk->ik', rng.randn(2, 3), rng.randn(3, 4)),
      # output label not present in input.
      ('a->b', rng.randn(5)),
      ('ij,jk->im', rng.randn(2, 3), rng.randn(3, 4)),
      # wrong shape.
      ('ij,jk->ik', rng.randn(1, 2, 3), rng.randn(3, 4)),
      # inconsistent dimensions.
      ('ij,jk->ik', rng.randn(2, 3), rng.randn(4, 4)),
      # output has repeated subscripts.
      ('ij,jk->iik', rng.randn(2, 3), rng.randn(3, 4)),
      # too many ellipses
      ('...ij...,jk...->ik...', rng.randn(2, 3), rng.randn(3, 4)),
      ('...ij,jk...->...ik...', rng.randn(2, 3), rng.randn(3, 4)),
      # invalid broadcast dimensions.
      ('...ij,...jk->...ik', rng.randn(5, 2, 3), rng.randn(7, 3, 4)),
      # output should have ellipsis when broadcasting shape is non-empty.
      ('...ij,...jk->ik', rng.randn(2, 2, 3), rng.randn(3, 4)),
  ]
  expected_errors = (ValueError, errors.InvalidArgumentError)
  for equation, *operands in bad_cases:
    with self.assertRaises(expected_errors):
      _ = special_math_ops.einsum(equation, *operands)
    shapeless_operands = [
        array_ops.placeholder_with_default(x, shape=None) for x in operands
    ]
    with self.assertRaises(expected_errors):
      _ = self.evaluate(
          special_math_ops.einsum(equation, *shapeless_operands))
@test_util.disable_xla('b/131919749')
def test_empty(self):
  """Checks einsum when operands contain zero-sized dimensions.

  Every case here yields an all-zeros output, so the expected value is
  written out directly instead of calling np.einsum. np.einsum also does not
  support the generalized diagonals needed for EinsumOp gradients.
  """
  cases = [
      # Contractions along zero-sized dimensions.
      ('ab,bc->ac', [(0, 10), (10, 10)], (0, 10)),
      # From transformer xl.
      ('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10)),
      # Generalized traces with zero-sized dimensions.
      ('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10)),
      ('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,)),
  ]
  for equation, input_shapes, output_shape in cases:
    # Re-seed per case so each case sees the same random stream.
    rng = np.random.RandomState(0)
    arrays = [np.array(rng.randn(*shape)) for shape in input_shapes]
    operands = [constant_op.constant(x, shape=x.shape) for x in arrays]
    actual = self.evaluate(special_math_ops.einsum(equation, *operands))
    self.assertAllClose(
        actual, np.zeros(output_shape), atol=1e-4, rtol=1e-4)
@test_util.run_all_in_graph_and_eager_modes
class EinsumGradTest(test.TestCase):
  """Gradient checks for special_math_ops.einsum.

  Each test compares einsum's analytical gradient against a numerical
  estimate from gradient_checker_v2, and requires the max error to be below
  1e-4.
  """

  # Label -> dimension size for the long-equation tests. Sizes cycle through
  # 1, 2, 3 ('a'->1, 'b'->2, 'c'->3, 'd'->1, ...), presumably to keep the
  # numerical gradient computation small for many-operand equations.
  _LONG_CASE_DIMENSIONS = {
      c: ((ord(c) - ord('a')) % 3) + 1 for c in 'abcdefghij'
  }

  def _check_gradient(self, s, *input_shapes):
    """Checks the gradient of einsum(`s`, ...) on random normal inputs.

    Args:
      s: einsum equation string.
      *input_shapes: one shape tuple per operand of `s`.
    """
    with self.cached_session():
      r = np.random.RandomState(0)
      inputs = [np.array(r.randn(*shape)) for shape in input_shapes]
      input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
      analytical, numerical = gradient_checker_v2.compute_gradient(
          lambda *xs: special_math_ops.einsum(s, *xs), input_tensors)
      self.assertLess(
          gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def _check_gradient_from_labels(self, equation):
    """Runs _check_gradient with shapes derived from `equation`'s labels.

    Whitespace is stripped and only the input side of the equation is used;
    each label's size comes from _LONG_CASE_DIMENSIONS.
    """
    lhs = equation.split('->')[0].replace(' ', '')
    input_shapes = [
        tuple(self._LONG_CASE_DIMENSIONS[c] for c in term)
        for term in lhs.split(',')
    ]
    self._check_gradient(equation, *input_shapes)

  @test_util.disable_xla('b/131919749')
  def test_unary(self):
    self._check_gradient('->', ())
    self._check_gradient('aaa->a', (3, 3, 3))
    self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
    self._check_gradient('abcd->da', (3, 5, 4, 2))

  @test_util.disable_xla('b/131919749')
  def test_unary_ellipsis(self):
    self._check_gradient('...->...', ())
    self._check_gradient('...->', ())
    self._check_gradient('->...', ())
    # Tests from dask
    self._check_gradient('a...a->a...', (2, 2))
    self._check_gradient('a...a->', (2, 2))
    self._check_gradient('a...a->...', (2, 5, 1, 2))
    self._check_gradient('a...a->a...', (2, 1, 2))
    self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))
    self._check_gradient('...ijk->...ki', (3, 4, 5))
    self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
    self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
    self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))

  def test_binary_simple(self):
    # Binary cases in XLA mode must have either (a) each index appearing
    # exactly once in both the inputs (batch or contraction index), or
    # (b) appearing exactly once in an input and in the output (free index).
    self._check_gradient(',->', (), ())
    self._check_gradient('a,a->', (3,), (3,))
    self._check_gradient('a,a->a', (3,), (3,))
    self._check_gradient('ab,b->a', (3, 4), (4,))
    self._check_gradient('ab,ab->', (3, 4), (3, 4))
    self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
    self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
    self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
    self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))

  def test_empty(self):
    # From Transformer XL.
    self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))

  @test_util.disable_xla('b/131919749')
  def test_reduced_indices(self):
    self._check_gradient('ba,b->', (3, 2), (3,))
    self._check_gradient('ab,ab->', (3, 4), (3, 4))
    self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))

  @test_util.disable_xla('b/131919749')
  def test_repeated_indices(self):
    # Repeated indices.
    self._check_gradient('aba,a->b', (3, 4, 3), (3,))
    self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
    self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
    self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
    self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))

  @test_util.disable_xla('b/131919749')
  def test_empty_with_repeated_indices(self):
    self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
    self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
    self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))

  @test_util.disable_xla('b/131919749')
  def test_broadcasting(self):
    self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
    self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
    self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
    self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
    self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
    # Tests from dask.
    self._check_gradient('...i,...j,...k->...ijk', (1, 4, 1, 2), (5, 1, 1, 3),
                         (1, 1, 1, 1, 9))
    self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))

  def test_long_cases(self):
    cases = [
        'abhe,hidj,jgba,hiab,gab->ed',
        # Tests from dask.
        'ea,fb,abcd,gc,hd->efgh',
    ]
    for equation in cases:
      self._check_gradient_from_labels(equation)

  @test_util.disable_xla('b/131919749')
  def test_long_cases_with_repeated_labels(self):
    cases = [
        # Tests from dask.
        'fdf,cdd,ccd,afe->ae',
        'fff,fae,bef,def->abd',
    ]
    for equation in cases:
      self._check_gradient_from_labels(equation)
class EinsumBenchmark(test.Benchmark):
  """CPU benchmarks for special_math_ops.einsum.

  Each entry of `cases` is [equation, dim]. Every operand is a float32
  variable in which every axis has length `dim`.
  """

  cases = [
      # Unary cases.
      ['ijk->i', 100],
      ['ijk->kji', 100],
      # Regular matmul or batch matmul.
      ['ij,jk->ik', 500],
      ['ji,kj->ik', 500],
      ['bij,bjk->bik', 100],
      ['bji,bjk->bki', 100],
      ['ikl,kji->kl', 100],
      ['klj,lki->ij', 100],
      ['ijk,ilj->kli', 100],
      ['ijk,jklm->il', 50],
      # Larger binary contractions.
      ['efabc,eabcd->efd', 20],
      ['fabec,abcde->fde', 20],
      ['efabc,edabc->efd', 20],
      ['eadbf,dfebc->ecfad', 20],
      ['abcdef,bcdfg->abcdeg', 20],
      # Chain matmul.
      ['ij,jk,kl->il', 1000],
      # Long cases. Path optimization should kick in.
      ['ea,fb,abcd,gc,hd->efgh', 10],
      ['bca,cdb,dbf,afc->', 10],
      ['efc,dbc,acf,fd->abe', 10],
      ['abhe,hidj,jgba,hiab,gab->ed', 10],
  ]

  def benchmark_einsum(self):
    for equation, dim in self.cases:
      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device('/cpu:0'):
        rng = np.random.RandomState(0)
        operand_subscripts = equation.split('->')[0].split(',')
        operands = [
            variables.Variable(
                np.array(rng.randn(*((dim,) * len(subscript))), np.float32))
            for subscript in operand_subscripts
        ]
        self.evaluate(variables.global_variables_initializer())
        # Equations with more than two operands are timed once per optimize
        # strategy; the rest are timed once with einsum's default settings.
        if len(operands) > 2:
          for optimize in ['greedy', 'auto']:
            self.run_op_benchmark(
                sess,
                special_math_ops.einsum(
                    equation, *operands, optimize=optimize),
                min_iters=50,
                name='einsum_cpu_({})_{}_{}'.format(equation, optimize, dim))
        else:
          self.run_op_benchmark(
              sess,
              special_math_ops.einsum(equation, *operands),
              min_iters=50,
              name='einsum_cpu_({})_{}'.format(equation, dim))
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner.
  test.main()