| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Arithmetic Operations that don't fit into math_ops due to dependencies. |
| |
| To avoid circular dependencies, some math_ops should go here. |
| """ |
| |
| import collections |
| import functools |
| import re |
| import string |
| |
| import numpy as np |
| import opt_einsum |
| |
| |
| from tensorflow.compiler.tf2xla.ops import gen_xla_ops |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gen_linalg_ops |
| from tensorflow.python.ops import gen_special_math_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import dispatch |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
| # TODO(b/27419586) Change docstring for required dtype of x once int allowed |
| @tf_export('math.lbeta', v1=['math.lbeta', 'lbeta']) |
| @dispatch.add_dispatch_support |
| @deprecation.deprecated_endpoints('lbeta') |
| def lbeta(x, name=None): |
| r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. |
| |
| Given one-dimensional $z = [z_1,...,z_K]$, we define |
| |
| $$Beta(z) = \frac{\prod_j \Gamma(z_j)}{\Gamma(\sum_j z_j)},$$ |
| |
| where $\Gamma$ is the gamma function. |
| |
  For $(n + 1)$-dimensional $x$ with shape $[N_1, ..., N_n, K]$, we define
| |
| $$lbeta(x)[i_1, ..., i_n] = \log{|Beta(x[i_1, ..., i_n, :])|}.$$ |
| |
| In other words, the last dimension is treated as the $z$ vector. |
| |
| Note that if $z = [u, v]$, then |
| |
| $$Beta(z) = \frac{\Gamma(u)\Gamma(v)}{\Gamma(u + v)} |
| = \int_0^1 t^{u-1} (1 - t)^{v-1} \mathrm{d}t,$$ |
| |
| which defines the traditional bivariate beta function. |
| |
| If the last dimension is empty, we follow the convention that the sum over |
| the empty set is zero, and the product is one. |
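
  For example, since $Beta([1, 2]) = 1/2$ and $Beta([3, 4]) = 1/60$:

  >>> tf.math.lbeta([[1., 2.], [3., 4.]]).numpy()
  array([-0.6931472, -4.0943446], dtype=float32)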
| |
| Args: |
| x: A rank `n + 1` `Tensor`, `n >= 0` with type `float`, or `double`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| The logarithm of \\(|Beta(x)|\\) reducing along the last dimension. |
| """ |
| # In the event that the last dimension has zero entries, we return -inf. |
  # This is consistent with the convention that the sum over the empty set is
  # 0 and the product is 1.
| # This is standard. See https://en.wikipedia.org/wiki/Empty_set. |
| with ops.name_scope(name, 'lbeta', [x]): |
| x = ops.convert_to_tensor(x, name='x') |
| |
| # Note reduce_sum([]) = 0. |
| log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1]) |
| |
| # Note lgamma(0) = infinity, so if x = [] |
| # log_gamma_sum_x = lgamma(0) = infinity, and |
    # log_prod_gamma_x = reduce_sum([]) = 0 (the empty product is 1),
| # so result = -infinity |
| sum_x = math_ops.reduce_sum(x, axis=[-1]) |
| log_gamma_sum_x = math_ops.lgamma(sum_x) |
| result = log_prod_gamma_x - log_gamma_sum_x |
| |
| return result |
| |
| |
| @tf_export('math.special.dawsn') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def dawsn(x, name=None): |
| """Computes Dawson's integral of `x` element-wise. |
| |
| Dawson's integral is defined as `exp(-x**2)` times the integral of |
| `exp(t**2)` from `0` to `x`, with the domain of definition all real numbers. |
| |
  Dawson's function is odd.

  >>> tf.math.special.dawsn([-1., -0.5, 0.5, 1.]).numpy()
| array([-0.5380795, -0.4244364, 0.4244364, 0.5380795], dtype=float32) |
| |
  This implementation is based on the Cephes math library.
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.dawsn |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'dawsn', [x]): |
| return gen_special_math_ops.dawsn(x) |
| |
| |
| @tf_export('math.special.expint') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def expint(x, name=None): |
| """Computes the Exponential integral of `x` element-wise. |
| |
| The Exponential integral is defined as the integral of `exp(t) / t` from |
| `-inf` to `x`, with the domain of definition all positive real numbers. |
| |
| >>> tf.math.special.expint([1., 1.1, 2.1, 4.1]).numpy() |
| array([ 1.8951179, 2.1673784, 5.3332353, 21.048464], dtype=float32) |
| |
  This implementation is based on the Cephes math library.
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.expi |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'expint', [x]): |
| return gen_special_math_ops.expint(x) |
| |
| |
| @tf_export('math.special.fresnel_cos') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def fresnel_cos(x, name=None): |
| """Computes Fresnel's cosine integral of `x` element-wise. |
| |
| The Fresnel cosine integral is defined as the integral of `cos(t^2)` from |
| `0` to `x`, with the domain of definition all real numbers. |
| |
  The Fresnel cosine integral is odd.

  >>> tf.math.special.fresnel_cos([-1., -0.1, 0.1, 1.]).numpy()
| array([-0.7798934 , -0.09999753, 0.09999753, 0.7798934 ], dtype=float32) |
| |
  This implementation is based on the Cephes math library.
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.fresnel second output. |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'fresnel_cos', [x]): |
| return gen_special_math_ops.fresnel_cos(x) |
| |
| |
| @tf_export('math.special.fresnel_sin') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def fresnel_sin(x, name=None): |
| """Computes Fresnel's sine integral of `x` element-wise. |
| |
| The Fresnel sine integral is defined as the integral of `sin(t^2)` from |
| `0` to `x`, with the domain of definition all real numbers. |
| |
| >>> tf.math.special.fresnel_sin([-1., -0.1, 0.1, 1.]).numpy() |
| array([-0.43825912, -0.00052359, 0.00052359, 0.43825912], dtype=float32) |
| |
  This implementation is based on the Cephes math library.
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.fresnel first output. |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'fresnel_sin', [x]): |
| return gen_special_math_ops.fresnel_sin(x) |
| |
| |
| @tf_export('math.special.spence') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def spence(x, name=None): |
| """Computes Spence's integral of `x` element-wise. |
| |
| Spence's integral is defined as the integral of `log(t) / (1 - t)` from |
| `1` to `x`, with the domain of definition all non-negative real numbers. |
| |
| >>> tf.math.special.spence([0.5, 1., 2., 3.]).numpy() |
| array([ 0.58224034, 0. , -0.82246685, -1.4367464], dtype=float32) |
| |
  This implementation is based on the Cephes math library.
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.spence |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'spence', [x]): |
| return gen_special_math_ops.spence(x) |
| |
| |
| @tf_export('math.bessel_i0', 'math.special.bessel_i0') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_i0(x, name=None): |
| """Computes the Bessel i0 function of `x` element-wise. |
| |
  Modified Bessel function of the first kind of order 0.

  It is preferable to use the more numerically stable function `i0e(x)`
  instead.
| |
| >>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy() |
| array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.i0 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_i0', [x]): |
| return gen_special_math_ops.bessel_i0(x) |
| |
| |
| @tf_export('math.bessel_i0e', 'math.special.bessel_i0e') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_i0e(x, name=None): |
| """Computes the Bessel i0e function of `x` element-wise. |
| |
  Exponentially scaled modified Bessel function of the first kind of order 0,
  defined as `bessel_i0e(x) = exp(-abs(x)) * bessel_i0(x)`.
| |
| >>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy() |
| array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.i0e |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_i0e', [x]): |
| return gen_special_math_ops.bessel_i0e(x) |
| |
| |
| @tf_export('math.bessel_i1', 'math.special.bessel_i1') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_i1(x, name=None): |
| """Computes the Bessel i1 function of `x` element-wise. |
| |
  Modified Bessel function of the first kind of order 1.

  It is preferable to use the more numerically stable function `i1e(x)`
  instead.
| |
| >>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy() |
| array([-0.5651591 , -0.25789431, 0.25789431, 0.5651591 ], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.i1 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_i1', [x]): |
| return gen_special_math_ops.bessel_i1(x) |
| |
| |
| @tf_export('math.bessel_i1e', 'math.special.bessel_i1e') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_i1e(x, name=None): |
| """Computes the Bessel i1e function of `x` element-wise. |
| |
  Exponentially scaled modified Bessel function of the first kind of order 1,
  defined as `bessel_i1e(x) = exp(-abs(x)) * bessel_i1(x)`.
| |
| >>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy() |
| array([-0.20791042, -0.15642083, 0.15642083, 0.20791042], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.i1e |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_i1e', [x]): |
| return gen_special_math_ops.bessel_i1e(x) |
| |
| |
| @tf_export('math.special.bessel_k0') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_k0(x, name=None): |
| """Computes the Bessel k0 function of `x` element-wise. |
| |
  Modified Bessel function of the second kind of order 0.

  It is preferable to use the more numerically stable function `k0e(x)`
  instead.
| |
| >>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy() |
| array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.k0 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_k0', [x]): |
| return gen_special_math_ops.bessel_k0(x) |
| |
| |
| @tf_export('math.special.bessel_k0e') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_k0e(x, name=None): |
| """Computes the Bessel k0e function of `x` element-wise. |
| |
  Exponentially scaled modified Bessel function of the second kind of order 0,
  defined as `bessel_k0e(x) = exp(x) * bessel_k0(x)`.
| |
| >>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy() |
| array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.k0e |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_k0e', [x]): |
| return gen_special_math_ops.bessel_k0e(x) |
| |
| |
| @tf_export('math.special.bessel_k1') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_k1(x, name=None): |
| """Computes the Bessel k1 function of `x` element-wise. |
| |
  Modified Bessel function of the second kind of order 1.

  It is preferable to use the more numerically stable function `k1e(x)`
  instead.
| |
| >>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy() |
| array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.k1 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_k1', [x]): |
| return gen_special_math_ops.bessel_k1(x) |
| |
| |
| @tf_export('math.special.bessel_k1e') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_k1e(x, name=None): |
| """Computes the Bessel k1e function of `x` element-wise. |
| |
  Exponentially scaled modified Bessel function of the second kind of order 1,
  defined as `bessel_k1e(x) = exp(x) * bessel_k1(x)`.
| |
| >>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy() |
| array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.k1e |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_k1e', [x]): |
| return gen_special_math_ops.bessel_k1e(x) |
| |
| |
| @tf_export('math.special.bessel_j0') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_j0(x, name=None): |
| """Computes the Bessel j0 function of `x` element-wise. |
| |
  Bessel function of the first kind of order 0.
| |
| >>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy() |
| array([ 0.93846981, 0.76519769, 0.22389078, -0.39714981], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.j0 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_j0', [x]): |
| return gen_special_math_ops.bessel_j0(x) |
| |
| |
| @tf_export('math.special.bessel_j1') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_j1(x, name=None): |
| """Computes the Bessel j1 function of `x` element-wise. |
| |
  Bessel function of the first kind of order 1.
| |
| >>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy() |
| array([ 0.24226846, 0.44005059, 0.57672481, -0.06604333], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.j1 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_j1', [x]): |
| return gen_special_math_ops.bessel_j1(x) |
| |
| |
| @tf_export('math.special.bessel_y0') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_y0(x, name=None): |
| """Computes the Bessel y0 function of `x` element-wise. |
| |
  Bessel function of the second kind of order 0.
| |
| >>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy() |
| array([-0.44451873, 0.08825696, 0.51037567, -0.01694074], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.y0 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_y0', [x]): |
| return gen_special_math_ops.bessel_y0(x) |
| |
| |
| @tf_export('math.special.bessel_y1') |
| @dispatch.register_unary_elementwise_api |
| @dispatch.add_dispatch_support |
| def bessel_y1(x, name=None): |
| """Computes the Bessel y1 function of `x` element-wise. |
| |
  Bessel function of the second kind of order 1.
| |
| >>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy() |
| array([-1.47147239, -0.78121282, -0.10703243, 0.39792571], dtype=float32) |
| |
| Args: |
| x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, |
| `float32`, `float64`. |
| name: A name for the operation (optional). |
| |
| Returns: |
| A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. |
| |
| @compatibility(scipy) |
| Equivalent to scipy.special.y1 |
| @end_compatibility |
| """ |
| with ops.name_scope(name, 'bessel_y1', [x]): |
| return gen_special_math_ops.bessel_y1(x) |
| |
| |
| @ops.RegisterGradient('XlaEinsum') |
| def _einsum_grad(op, grad): |
| equation = op.get_attr('equation') |
| if isinstance(equation, bytes): |
| equation = equation.decode() |
| |
| inputs, output = equation.split('->') |
| left, right = inputs.split(',') |
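  # For example, for equation 'ij,jk->ik' the gradients are themselves
  # einsums: the gradient w.r.t. op.inputs[0] is einsum('ik,jk->ij', grad,
  # op.inputs[1]), and the gradient w.r.t. op.inputs[1] is
  # einsum('ik,ij->jk', grad, op.inputs[0]).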
| |
| return [ |
| gen_xla_ops.xla_einsum( |
| grad, |
| op.inputs[1], |
| equation='{},{}->{}'.format(output, right, left), |
| name=None), |
| gen_xla_ops.xla_einsum( |
| grad, |
| op.inputs[0], |
| equation='{},{}->{}'.format(output, left, right), |
| name=None) |
| ] |
| |
| |
| def _enclosing_tpu_context(): |
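  """Returns the nearest enclosing XLAControlFlowContext, or None."""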
| # pylint: disable=protected-access |
| context = ops.get_default_graph()._get_control_flow_context() |
| # pylint: enable=protected-access |
| while context is not None and not isinstance( |
| context, control_flow_ops.XLAControlFlowContext): |
| context = context.outer_context |
| return context |
| |
| |
| @tf_export('einsum', 'linalg.einsum') |
| @dispatch.add_dispatch_support |
| def einsum(equation, *inputs, **kwargs): |
| r"""Tensor contraction over specified indices and outer product. |
| |
  Einsum allows defining Tensors by specifying their element-wise computation.
| This computation is defined by `equation`, a shorthand form based on Einstein |
| summation. As an example, consider multiplying two matrices A and B to form a |
| matrix C. The elements of C are given by: |
| |
| $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$ |
| |
| or |
| |
| ``` |
| C[i,k] = sum_j A[i,j] * B[j,k] |
| ``` |
| |
| The corresponding einsum `equation` is: |
| |
| ``` |
| ij,jk->ik |
| ``` |
| |
| In general, to convert the element-wise equation into the `equation` string, |
| use the following procedure (intermediate strings for matrix multiplication |
| example provided in parentheses): |
| |
| 1. remove variable names, brackets, and commas, (`ik = sum_j ij * jk`) |
| 2. replace "*" with ",", (`ik = sum_j ij , jk`) |
| 3. drop summation signs, and (`ik = ij, jk`) |
| 4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`) |
| |
  Note: If the output indices are not specified, repeated indices are summed.
| So `ij,jk->ik` can be simplified to `ij,jk`. |
| |
| Many common operations can be expressed in this way. For example: |
| |
| **Matrix multiplication** |
| |
| >>> m0 = tf.random.normal(shape=[2, 3]) |
| >>> m1 = tf.random.normal(shape=[3, 5]) |
| >>> e = tf.einsum('ij,jk->ik', m0, m1) |
| >>> # output[i,k] = sum_j m0[i,j] * m1[j, k] |
| >>> print(e.shape) |
| (2, 5) |
| |
| Repeated indices are summed if the output indices are not specified. |
| |
| >>> e = tf.einsum('ij,jk', m0, m1) # output[i,k] = sum_j m0[i,j] * m1[j, k] |
| >>> print(e.shape) |
| (2, 5) |
| |
| |
| **Dot product** |
| |
| >>> u = tf.random.normal(shape=[5]) |
| >>> v = tf.random.normal(shape=[5]) |
| >>> e = tf.einsum('i,i->', u, v) # output = sum_i u[i]*v[i] |
| >>> print(e.shape) |
| () |
| |
| **Outer product** |
| |
| >>> u = tf.random.normal(shape=[3]) |
| >>> v = tf.random.normal(shape=[5]) |
| >>> e = tf.einsum('i,j->ij', u, v) # output[i,j] = u[i]*v[j] |
| >>> print(e.shape) |
| (3, 5) |
| |
| **Transpose** |
| |
  >>> m = tf.ones([2, 3])
  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
| >>> print(e.shape) |
| (3, 2) |
| |
| **Diag** |
| |
| >>> m = tf.reshape(tf.range(9), [3,3]) |
| >>> diag = tf.einsum('ii->i', m) |
| >>> print(diag.shape) |
| (3,) |
| |
| **Trace** |
| |
| >>> # Repeated indices are summed. |
  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
| >>> assert trace == sum(diag) |
| >>> print(trace.shape) |
| () |
| |
| **Batch matrix multiplication** |
| |
| >>> s = tf.random.normal(shape=[7,5,3]) |
| >>> t = tf.random.normal(shape=[7,3,2]) |
| >>> e = tf.einsum('bij,bjk->bik', s, t) |
| >>> # output[a,i,k] = sum_j s[a,i,j] * t[a, j, k] |
| >>> print(e.shape) |
| (7, 5, 2) |
| |
  This method does not support broadcasting on named axes. All axes with
| matching labels should have the same length. If you have length-1 axes, |
| use `tf.squeeze` or `tf.reshape` to eliminate them. |
| |
| To write code that is agnostic to the number of indices in the input |
| use an ellipsis. The ellipsis is a placeholder for "whatever other indices |
| fit here". |
| |
  For example, to perform a NumPy-style batched matrix multiplication with
  broadcasting, where the matrix multiply acts on the last two axes of the
  inputs, use:
| |
| >>> s = tf.random.normal(shape=[11, 7, 5, 3]) |
| >>> t = tf.random.normal(shape=[11, 7, 3, 2]) |
| >>> e = tf.einsum('...ij,...jk->...ik', s, t) |
| >>> print(e.shape) |
| (11, 7, 5, 2) |
| |
| Einsum **will** broadcast over axes covered by the ellipsis. |
| |
| >>> s = tf.random.normal(shape=[11, 1, 5, 3]) |
| >>> t = tf.random.normal(shape=[1, 7, 3, 2]) |
| >>> e = tf.einsum('...ij,...jk->...ik', s, t) |
| >>> print(e.shape) |
| (11, 7, 5, 2) |
| |
| Args: |
| equation: a `str` describing the contraction, in the same format as |
| `numpy.einsum`. |
| *inputs: the inputs to contract (each one a `Tensor`), whose shapes should |
| be consistent with `equation`. |
| **kwargs: |
| - optimize: Optimization strategy to use to find contraction path using |
| opt_einsum. Must be 'greedy', 'optimal', 'branch-2', 'branch-all' or |
| 'auto'. (optional, default: 'greedy'). |
| - name: A name for the operation (optional). |
| |
| Returns: |
| The contracted `Tensor`, with shape determined by `equation`. |
| |
| Raises: |
| ValueError: If |
| - the format of `equation` is incorrect, |
      - the number of inputs is wrong, or their shapes are inconsistent with
        `equation`.
| """ |
| return _einsum_v2(equation, *inputs, **kwargs) |
| |
| |
| def _einsum_v1(equation, *inputs, **kwargs): |
| """Legacy implementation of einsum without using EinsumOp.""" |
| name = kwargs.pop('name', None) |
| if kwargs: |
| raise TypeError( |
| f'Invalid keyword arguments for this function: ' |
| f'{", ".join([format(key) for key in sorted(list(kwargs.keys()))])}.' |
| f' Expected: name.') |
| with ops.name_scope(name, 'einsum', [equation, inputs]) as name: |
| inputs = list(inputs) |
| input_shapes = [x.shape for x in inputs] |
| input_axis_labels, output_axis_labels = ( |
| _einsum_v1_parse_and_resolve_equation(equation, input_shapes)) |
| |
| axis_labels = set(''.join(input_axis_labels) + output_axis_labels) |
| |
| for a in axis_labels: |
| for input_labels in input_axis_labels: |
| if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and |
| input_labels == input_labels[::-1] and '->' not in equation): |
| return math_ops.trace(inputs[0]) |
| if input_labels.count(a) > 1: |
| raise ValueError( |
| f'Subscript not supported: the axis {a} appears more than once' |
| f' in {input_labels}.') |
| for a in axis_labels: |
| input_count = sum(1 for s in input_axis_labels if a in s) |
| if input_count > 2 and a not in output_axis_labels: |
| logging.warn( |
| f'Falling back to exponential-space implementation of einsum()' |
| f' because index {a} is summed over more than two inputs.') |
| return _exponential_space_einsum_v1(equation, *inputs) |
| |
    # Use xla_einsum if executing on TPU and if the operation is a two-input
    # einsum supported by XlaEinsumOp.
| if _enclosing_tpu_context() is not None and len(inputs) == 2: |
| return gen_xla_ops.xla_einsum( |
| inputs[0], inputs[1], input_axis_labels[0] + ',' + |
| input_axis_labels[1] + '->' + output_axis_labels) |
| temp = inputs[0] |
| temp_axis_labels = input_axis_labels[0] |
| for i in range(len(inputs) - 1): |
| axes_to_sum = ( |
| set(temp_axis_labels) & |
| set(input_axis_labels[i + 1]) - set(output_axis_labels)) |
| temp, temp_axis_labels = _einsum_v1_reduction(temp, temp_axis_labels, |
| inputs[i + 1], |
| input_axis_labels[i + 1], |
| axes_to_sum) |
| |
| missing_indices = set(temp_axis_labels) - set(output_axis_labels) |
| if missing_indices: |
| axis = [ |
| i for i, a in enumerate(temp_axis_labels) |
| if a not in output_axis_labels |
| ] |
| temp = math_ops.reduce_sum(temp, axis=axis) |
| temp_axis_labels = ''.join( |
| a for a in temp_axis_labels if a in output_axis_labels) |
| if sorted(temp_axis_labels) != sorted(output_axis_labels): |
| raise ValueError( |
| f'Invalid equation: {equation}. The computed and specified output ' |
| f'labels do not match: {temp_axis_labels} vs {output_axis_labels}.') |
| |
| perm = [temp_axis_labels.index(a) for a in output_axis_labels] |
| return _transpose_if_necessary(temp, perm) |
| |
| |
| def _einsum_v1_parse_and_resolve_equation(equation, input_shapes): |
| """Helper for einsum() that splits/resolves inputs & outputs. |
| |
| Args: |
| equation: Equation string given as argument to einsum(). |
| input_shapes: List of the shapes of all inputs given to einsum() |
| |
| Returns: |
| input_axis_labels, output_axis_labels where: |
| input_axis_labels: List of length len(input_shapes) of strings |
| representing the character label for each dimension of each given input, |
| resolving any broadcast (...) axes, |
      output_axis_labels: A string of character labels for each axis of the
      output tensor, filling in missing output subscripts and broadcast axes.
| |
| Raises: |
    ValueError: If the equation is in an incorrect format, an incorrect number
      of inputs is given, or broadcast axes "..." or output axes could not be
      resolved.
| """ |
| equation = equation.replace(' ', '') |
| match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation) |
| if not match: |
| raise ValueError(f'Indices have incorrect format. Received: {equation}.') |
| |
| input_axis_labels = match.group(1).split(',') |
| output_axis_labels = match.group(2)[2:] if match.group(2) else None |
| |
| if len(input_shapes) != len(input_axis_labels): |
| raise ValueError( |
| f'Got {len(input_shapes)} arguments for equation "{equation}", ' |
| f'expecting {len(input_axis_labels)}.') |
| |
  # Resolve Ellipsis
  # Assign axis labels to the unspecified dimensions of each input, drawing
  # labels from those not used elsewhere in the equation. Follow numpy einsum
  # broadcasting conventions for tensors of different rank and unlabeled
  # output.
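  # For example, with equation 'a...b,...c' and input ranks 4 and 3, both
  # ellipses resolve to the same two trailing unused labels (say 'YZ'),
  # yielding 'aYZb,YZc'; a rank-2 second input would instead receive only
  # 'Z', mirroring numpy's right-aligned broadcasting.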
| ellipsis_axes = '' |
| if '...' in equation: |
| unused = ''.join( |
| c for c in string.ascii_letters if c not in ''.join(input_axis_labels)) |
| for i, ax in enumerate(input_axis_labels): |
| if '...' in ax: |
| parts = ax.split('...') |
| if len(parts) != 2: |
| raise ValueError(f'Unable to resolve ellipsis. ' |
| f'Excess number found: {len(parts)-1} vs 1.') |
| if input_shapes[i].ndims is None: |
          raise ValueError('Unable to statically infer ellipsis axes. The '
                           'input shape has an unknown rank.')
| n = input_shapes[i].ndims - len(''.join(parts)) |
| if n < 0: |
| raise ValueError('Ellipses lengths do not match.') |
| if len(unused) < n: |
| raise ValueError( |
| 'Unable to resolve ellipsis, too many distinct labels.') |
| replace_axes = unused[-n:] if n > 0 else '' |
| input_axis_labels[i] = input_axis_labels[i].replace('...', |
| replace_axes) |
| if len(replace_axes) > len(ellipsis_axes): |
| ellipsis_axes = replace_axes |
| |
| if any('.' in ax for ax in input_axis_labels): |
| raise ValueError( |
| f'Period "." found outside of ellipsis in input {input_axis_labels}.') |
| |
| if output_axis_labels is not None: |
| output_axis_labels = output_axis_labels.replace('...', ellipsis_axes) |
| if '.' in output_axis_labels: |
| raise ValueError(f'Period "." found outside of ellipsis in output ' |
| f'{output_axis_labels}.') |
| |
  if output_axis_labels is None:
    # Infer the output subscripts if not given, assuming alphabetical order,
    # but always place ellipsis axes first.
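    # For example, 'ab,bc' infers the output subscripts 'ac' (the labels that
    # appear exactly once across the inputs, in sorted order).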
| axis_labels = set(''.join(input_axis_labels)) - set(ellipsis_axes) |
| indices = ''.join(sorted(axis_labels)) |
| counts = {ax: 0 for ax in indices} |
| for axes_ in input_axis_labels: |
| for ax in axes_: |
| if ax not in ellipsis_axes: |
| counts[ax] += 1 |
| |
| output_axis_labels = ellipsis_axes + ''.join( |
| sorted(ax for ax in axis_labels if counts[ax] == 1)) |
| |
| return input_axis_labels, output_axis_labels |
| |
| |
| def _einsum_v1_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum): |
| """Helper for einsum() that computes the result of a two-argument einsum(). |
| |
| Args: |
| t0: a `Tensor` |
| t0_axis_labels: a string of axis labels. This string's length must equal |
| the rank of t0. |
| t1: a `Tensor` |
| t1_axis_labels: a string to axis labels. This string's length must equal |
| the rank of t1. |
| axes_to_sum: set of labels of axes to be summed over |
| |
| Returns: |
    A `Tensor` whose elements are obtained by multiplying corresponding
    elements of `t0` and `t1`, then summing over all axes in `axes_to_sum`.

    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
    axes_to_sum == {j,k}, this will return a tensor `out` where

    out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]
| |
| Raises: |
| ValueError: if the rank of `t0` does not match the length of |
| `t0_axis_labels`, or that of `t1` does not match the length of |
| `t1_axis_labels`. |
| """ |
| if len(t0_axis_labels) != len(t0.shape): |
| raise ValueError( |
| f'Tensor `t0` of rank {len(t0.shape)} does not match einsum reduction ' |
| f'of length {len(t0_axis_labels)}.') |
| if len(t1_axis_labels) != len(t1.shape): |
| raise ValueError( |
| f'Tensor `t1` of rank {len(t1.shape)} does not match einsum reduction ' |
        f'of length {len(t1_axis_labels)}.')
| |
| # This function computes the result of a two-argument einsum() using batch |
| # matrix multiplication. This involves |
| # 1. transposing t0 and t1 so that axes are in the correct order for |
| # batch matrix multiplication, and |
| # 2. reshaping t0 and t1 so that they are both of rank 3. |
| |
| # First, we divide axes into three groups: |
| # * "preserved" axes are present in both inputs and the output |
| # * "summed" axes are present in both inputs but not the output |
| # * "broadcast" axes are present in exactly one input and the output |
| # |
| # As an example, if the einsum is abijk,acjkl->abcil, then "a" is a |
| # preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are |
| # summed axes. |
| assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum) |
| preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum |
| broadcast_axes = {} |
| for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]): |
| broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum |
| |
| # Reorder the axes so that: |
| # 1. preserved axes come first in both inputs |
| # 2. in input 0, broadcast axes come next, followed by summed axes |
| # 3. in input 1, summed axes come next, followed by broadcast axes |
| def sort_key(input_index, a): |
| if a in preserved_axes: |
| return (-1, a) |
| elif ((input_index == 0 and a in broadcast_axes[0]) or |
| (input_index == 1 and a in axes_to_sum)): |
| return (0, a) |
| else: |
| return (1, a) |
| |
| axis_labels = [t0_axis_labels, t1_axis_labels] |
| sorted_axes = [ |
| sorted(sym_list, key=lambda a: sort_key(i, a)) |
| for i, sym_list in enumerate(axis_labels) |
| ] |
| inputs = [t0, t1] |
| for i, axes_str in enumerate(axis_labels): |
| perm = [axes_str.find(a) for a in sorted_axes[i]] |
| inputs[i] = _transpose_if_necessary(inputs[i], perm) |
| t0, t1 = inputs |
| |
| if not axes_to_sum: |
| # In the special case where there are no axes to sum over, reduce to mul() |
| # rather than to batch matrix multiplication. |
| for _ in broadcast_axes[1]: |
| t0 = array_ops.expand_dims(t0, -1) |
| for _ in broadcast_axes[0]: |
| t1 = array_ops.expand_dims(t1, len(preserved_axes)) |
| product = math_ops.multiply(t0, t1) |
| product_axes = sorted_axes[0] + sorted_axes[1][len(preserved_axes):] |
| return product, ''.join(product_axes) |
| else: |
| # Reduce to matmul(). |
| |
| # Reshape both inputs so as to combine multiple broadcast axes |
| # into a single axis, and combine multiple summed axes into a |
| # single axis. |
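    # For example, for sorted labels t0='abjk' and t1='ajkc' (preserved 'a',
    # summed 'jk', broadcast 'b' and 'c'), t0 of shape [A, B, J, K] is
    # reshaped below to [A, B, J*K] and t1 of shape [A, J, K, C] to
    # [A, J*K, C], so the matmul produces shape [A, B, C].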
| |
| t0_shape = _get_shape(t0) |
| num_broadcast_elements_t0 = _total_size( |
| t0_shape[len(preserved_axes):-len(axes_to_sum)]) |
| num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):]) |
| new_shape = ( |
| t0_shape[:len(preserved_axes)] + |
| [num_broadcast_elements_t0, num_summed_elements]) |
| t0 = _reshape_if_necessary(t0, new_shape) |
| |
| t1_shape = _get_shape(t1) |
| num_broadcast_elements_t1 = _total_size( |
| t1_shape[len(preserved_axes) + len(axes_to_sum):]) |
| new_shape = ( |
| t1_shape[:len(preserved_axes)] + |
| [num_summed_elements, num_broadcast_elements_t1]) |
| t1 = _reshape_if_necessary(t1, new_shape) |
| |
| product = math_ops.matmul(t0, t1) |
| |
| # Undo compaction of broadcast axes |
| uncompacted_shape = ( |
| t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] + |
| t1_shape[len(t1_shape) - len(broadcast_axes[1]):]) |
| product = _reshape_if_necessary(product, uncompacted_shape) |
| |
| product_axes = ( |
| sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] + |
| sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):]) |
| |
| return product, ''.join(product_axes) |
| |
| |
| def _transpose_if_necessary(tensor, perm): |
| """Like transpose(), but avoids creating a new tensor if possible.""" |
| if perm != list(range(len(perm))): |
| return array_ops.transpose(tensor, perm=perm) |
| else: |
| return tensor |
| |
| |
| def _reshape_if_necessary(tensor, new_shape): |
| """Like reshape(), but avoids creating a new tensor if possible.""" |
| # Accept None as an alias for -1 in new_shape. |
| new_shape = tuple(-1 if x is None else x for x in new_shape) |
| cur_shape = tuple(x.value for x in tensor.shape.dims) |
| if (len(new_shape) == len(cur_shape) and |
| all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1) |
| for d0, d1 in zip(cur_shape, new_shape))): |
| return tensor |
| else: |
| return array_ops.reshape(tensor, new_shape) |
| |
| |
| def _get_shape(tensor): |
| """Like get_shape().as_list(), but explicitly queries the shape of a tensor |
| if necessary to ensure that the returned value contains no unknown value.""" |
| |
| shape = tensor.shape.as_list() |
| none_indices = [i for i, d in enumerate(shape) if d is None] |
| if none_indices: |
| # Query the shape if shape contains None values |
| shape_tensor = array_ops.shape(tensor) |
| for i in none_indices: |
| shape[i] = shape_tensor[i] |
| return shape |
| |
| |
| def _total_size(shape_values): |
| """Given list of tensor shape values, returns total size. |
| If shape_values contains tensor values (which are results of |
| array_ops.shape), then it returns a scalar tensor. |
| If not, it returns an integer.""" |
| |
| result = 1 |
| for val in shape_values: |
| result *= val |
| return result |
| |
| |
| def _exponential_space_einsum_v1(equation, *inputs): |
| """Fallback implementation that supports summing an index over > 2 inputs.""" |
| inputs = list(inputs) |
| input_shapes = [x.shape for x in inputs] |
| idx_in, idx_out = _einsum_v1_parse_and_resolve_equation( |
| equation, input_shapes) |
| |
| idx_all = set(''.join(idx_in) + idx_out) |
| indices = ''.join(sorted(idx_all)) |
| |
  missing_idx = set(idx_out).difference(''.join(idx_in))
| if missing_idx: |
| raise ValueError(f'Unknown output axes: {missing_idx}.') |
| |
| axis_order = {} |
| for ax in indices: |
| if ax not in idx_out: |
| axis_order[ax] = len(axis_order) |
| for ax in idx_out: |
| axis_order[ax] = len(axis_order) |
| |
| # transpose inputs so axes are in order |
| for i, (input_, axes_) in enumerate(zip(inputs, idx_in)): |
| if input_.shape.ndims != len(axes_): |
| raise ValueError( |
| f'Input {i} with axes {axes_} has incorrect number of dimensions ' |
| f'(expected {len(axes_)}, got {input_.shape.ndims}).') |
| |
| sorted_idx = sorted(axes_, key=axis_order.get) |
| |
| if len(set(axes_)) != len(axes_): |
| raise ValueError( |
| f'Subscript not supported: an axis appears more than once: {axes_}.') |
| |
| if list(axes_) != sorted_idx: |
| permuted = [axes_.find(ax) for ax in sorted_idx] |
| inputs[i] = array_ops.transpose(input_, permuted) |
| idx_in[i] = sorted_idx |
| |
| reduction_idx = [] |
| shapes = [[dim if dim else -1 |
| for dim in tensor.shape.as_list()] |
| for tensor in inputs] |
| |
| # validate shapes for broadcasting |
| for j, ax in enumerate(sorted(idx_all, key=axis_order.get)): |
| dims = [] |
| for i, idx in enumerate(idx_in): |
| if ax not in idx: |
| shapes[i].insert(j, 1) |
| else: |
| dim = shapes[i][j] |
| if isinstance(dim, int) and dim > 1: |
| dims.append(dim) |
| |
| if len(set(dims)) > 1: |
      raise ValueError(f'Dimension mismatch on axis: {ax}. '
                       f'Found {len(set(dims))} distinct sizes, expected 1.')
| |
| if ax not in idx_out: |
| reduction_idx.append(j) |
| |
| # reshape, multiply |
| expanded_inputs = [ |
| array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes) |
| ] |
| expanded_output = 1 |
| for input_ in expanded_inputs: |
| expanded_output *= input_ |
| |
| # contract |
| return math_ops.reduce_sum(expanded_output, reduction_idx) |
| |
| |
| def _einsum_v2(equation, *inputs, **kwargs): |
| """Implementation of einsum utilizing opt_einsum and EinsumOp.""" |
| name = kwargs.pop('name', None) |
| optimize = kwargs.pop('optimize', 'greedy') |
| if kwargs: |
| raise TypeError( |
| f'Invalid keyword arguments for einsum: {", ".join(kwargs)}. ' |
        f'Valid arguments: name, optimize.')
| |
| with ops.name_scope(name, 'einsum', [equation, inputs]) as name: |
| inputs = list(inputs) |
| input_shapes = [] |
| for operand in inputs: |
| if isinstance(operand.shape, tensor_shape.TensorShape): |
| input_shapes.append(operand.shape.as_list() if operand.shape else None) |
| else: |
| input_shapes.append(list(operand.shape)) |
| # Validate and sanitize the equation and resolve static input shapes, as |
| # opt_einsum requires that all shapes be a tuple of positive integers. |
    # Also remove ellipsis from the equation, because opt_einsum would replace
    # it with named labels, and broadcasting between different shapes or ranks
    # wouldn't then work. (E.g. [1, 1, 2] wouldn't broadcast with [3, 1].)
| resolved_equation, resolved_input_shapes, ellipsis_label = ( |
| _einsum_v2_parse_and_resolve_equation(equation, input_shapes)) |
| |
| if len(inputs) <= 2: # No need to call opt_einsum. |
| # Replace back ellipses that were removed for opt_einsum. |
| if ellipsis_label: |
| resolved_equation = resolved_equation.replace(ellipsis_label, '...') |
| return gen_linalg_ops.einsum(inputs, resolved_equation) |
| |
| # Send fully specified shapes to opt_einsum, since it cannot handle unknown |
| # dimensions. For unknown dimensions, we guess that the dimension equals 1. |
| # Instead of creating Tensors or NumPy arrays with the specified shape, |
| # create a dummy `shaped` object with a `shape` property. |
| shaped = collections.namedtuple('shaped', ['shape']) |
| shaped_inputs = tuple( |
| [shaped(tuple(shape)) for shape in resolved_input_shapes]) |
| # opt_einsum breaks down an n-ary einsum operation into n-1 binary einsums. |
| # Obtain the sequence of equations and the indices of operands involved in |
| # each einsum operation. |
| indices_and_equations = _get_opt_einsum_contract_path( |
| resolved_equation, shaped_inputs, optimize) |
| for operand_indices, binary_equation in indices_and_equations: |
| if ellipsis_label: |
| # Replace back ellipses that were removed for opt_einsum. |
| binary_equation = binary_equation.replace(ellipsis_label, '...') |
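      # Note: map(inputs.pop, operand_indices) relies on opt_einsum listing
      # the operand indices in decreasing order, so that each pop removes the
      # intended tensor; the intermediate result is appended for later steps.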
| operands = list(map(inputs.pop, operand_indices)) |
| inputs.append(gen_linalg_ops.einsum(operands, binary_equation)) |
| return inputs[0] |
| |
| |
| def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize): |
| """Returns the (memoized) result of opt_einsum.contract_path.""" |
| # Note: We use einsum_call=True, which is an internal api for opt_einsum, |
| # to get the contraction path without having opt_einsum perform the actual |
| # contractions. |
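  # Each returned entry pairs the positions of the operands consumed by one
  # binary contraction with that step's binary einsum equation; e.g. an entry
  # like ((2, 0), 'ab,bc->ac') contracts the first and third remaining
  # operands.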
| _, contractions = opt_einsum.contract_path( |
| equation, |
| *shaped_inputs_tuple, |
| optimize=optimize, |
| einsum_call=True, |
| use_blas=True) |
| # Return a tuple so that the cached value is not mutable. |
| indices_and_equations = tuple([(expr[0], expr[2]) for expr in contractions]) |
| return indices_and_equations |
| |
| |
# Cache the possibly expensive opt_einsum.contract_path call using
# functools.lru_cache from the Python standard library.
| _get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)( |
| _get_opt_einsum_contract_path) |
| |
| |
| def _einsum_v2_parse_and_resolve_equation(equation, input_shapes): |
| """Helper which validates einsum equation and resolves input shapes.""" |
| resolved_equation = equation.replace(' ', '') |
| ellipsis_label = None |
| if '...' in equation: |
| # Replace ellipsis ('...') with '0' for (a) ease of parsing and (b) to |
| # prevent opt_einsum from resolving them into named labels; as it doesn't |
| # support broadcasting. |
| ellipsis_label = '0' |
| if ellipsis_label in resolved_equation: |
| raise ValueError( |
| f'Invalid character "{ellipsis_label}" in equation: {equation}.') |
| resolved_equation = resolved_equation.replace('...', ellipsis_label) |
| |
  # Ensure there are no non-alphanumeric characters in the equation, including
  # periods (`.`) outside of ellipses. This is not a hard requirement, but we
  # use the special character '0' to stand in for ellipsis, so it must not
  # appear as a label.
| allowed_labels = 'a-zA-Z' |
| if ellipsis_label: |
| allowed_labels += ellipsis_label |
| match = re.match('^([{0},]*)(->[{0}]*)?$'.format(allowed_labels), |
| resolved_equation) |
| if not match: |
| raise ValueError( |
| 'Subscripts have incorrect format: {}'.format(resolved_equation)) |
| input_labels = match.group(1).split(',') |
| output_labels = match.group(2)[2:] if match.group(2) else None |
| |
| if len(input_shapes) != len(input_labels): |
| raise ValueError('Got {} inputs for equation "{}", expecting {}'.format( |
| len(input_shapes), equation, len(input_labels))) |
| |
| # Special case: if there are no '->', then we create output subscripts from |
| # labels appearing only once. |
| if '->' not in resolved_equation: |
| label_counts = collections.Counter(match.group(1)) |
| output_labels = ''.join([ |
| x for x in sorted(list(label_counts)) |
| if x != ',' and label_counts[x] == 1 |
| ]) |
| resolved_equation += '->' + output_labels |
| # Validate output_labels. |
| if output_labels and len(set(output_labels)) != len(output_labels): |
| raise ValueError( |
| 'Output subscripts contain a label appearing more than once: {}'.format( |
| equation)) |
| input_label_set = set(match.group(1)) |
| for label in output_labels: |
| if label != ellipsis_label and label not in input_label_set: |
| raise ValueError('Output subscripts contain the label {} not present ' |
| 'in the input subscripts.'.format(label)) |
| if ellipsis_label and output_labels: |
| num_output_ellipses = output_labels.count(ellipsis_label) |
| if num_output_ellipses > 1: |
| raise ValueError( |
| 'Output subscripts contain multiple ellipsis: {}'.format(equation)) |
| |
| # Early return if <= 2 inputs. Resolved shapes are not needed. |
| if len(input_shapes) <= 2: |
| return resolved_equation, None, ellipsis_label |
| |
| # Create a map from axis labels to known dimensions. This is used to infer |
| # unknown dimensions if a known dimension also has the same label. |
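  # For example, if label 'a' is paired with dimension None in one input and
  # with 4 in another, it resolves to 4; a label never seen with a known
  # dimension defaults to 1.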
| label_to_dim = collections.defaultdict(lambda: 1) |
| for i, (labels, shape) in enumerate(zip(input_labels, input_shapes)): |
| if shape is None: |
| continue |
| ellipsis_start = labels.find(ellipsis_label) if ellipsis_label else -1 |
| if ellipsis_start != -1: # This input contains an ellipsis. |
| if ellipsis_start != labels.rfind(ellipsis_label): |
| raise ValueError(f'Too many ellipses in input label ' |
| f'{labels.replace(ellipsis_label, "...")}.') |
| if len(labels) > len(shape) + 1: |
| raise ValueError('Too many named labels in {}th subscript string of' |
| ' equation {} for input shape {} '.format( |
| i, equation, shape)) |
| ellipsis_end = ellipsis_start + len(shape) + 1 - len(labels) |
| shape[ellipsis_start:ellipsis_end] = ([ |
| np.prod( |
| list(filter(None, shape[ellipsis_start:ellipsis_end])), |
| dtype=np.int64) |
| ]) |
| else: |
| # This input does not contain an ellipsis. |
| if len(labels) != len(shape): |
| raise ValueError( |
| 'Number of named labels in input #{} of equation {} ' |
| 'must be equal to the number of dimensions in shape {}'.format( |
| i, equation, shape)) |
| for dim, label in zip(shape, labels): |
| if dim is not None: |
| label_to_dim[label] = max(label_to_dim[label], dim) |
| |
| resolved_shapes = [] |
| for labels in input_labels: |
| resolved_shapes.append([label_to_dim[label] for label in labels]) |
| return resolved_equation, resolved_shapes, ellipsis_label |