blob: 5067aa7f6be823947255422ff9a0cf6d806c2119 [file] [log] [blame]
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for ExtensionTypes (aka Composite Tensors)."""
from tensorflow.core.protobuf import composite_tensor_variant_pb2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_composite_tensor_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import nest
def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  Args:
    value: The `ExtensionType` value to encode.
    type_spec: Information about the value's type that should be included in
      the encoding.  If `None`, defaults to `value._type_spec`.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")
  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` "
                     f"{value!r}.")
  # The TypeSpec is serialized into the op's `metadata` attr so that the
  # encoding carries the value's type alongside its flattened components.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=nest.flatten(value, expand_composites=True),
      metadata=metadata.SerializeToString(),
      name=name)
def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value.  This is used to determine
      the number and types of the component tensors that comprise the decoded
      value.  Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with a
      scalar shape `()`.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  encoded.shape.assert_is_compatible_with(())

  # `type_spec` is serialized into the same metadata proto that
  # `composite_tensor_to_variants` builds; presumably the kernel compares it
  # with the metadata stored in `encoded` (source of InvalidArgumentError).
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)
  # One output dtype per flattened component tensor of `type_spec`.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]
  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)
@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient for `CompositeTensorVariantFromComponents`.

  The forward op packs a list of component tensors into one variant tensor.
  The backward pass therefore unpacks the incoming variant gradient `grad`
  into per-component gradients. It reuses the forward op's `metadata` and
  `Tcomponents` attrs.

  Args:
    op: The forward `CompositeTensorVariantFromComponents` operation.
    grad: The gradient with respect to the op's variant output.

  Returns:
    The per-component gradients produced by
    `CompositeTensorVariantToComponents`.
  """
  serialized_metadata = op.get_attr("metadata")
  component_types = op.get_attr("Tcomponents")
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad,
      metadata=serialized_metadata,
      Tcomponents=component_types)
@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient for `CompositeTensorVariantToComponents`.

  The forward op decodes a variant tensor into component tensors, so the
  backward pass re-encodes the per-component gradients into a single variant
  using the forward op's `metadata` attr.

  Args:
    op: The forward `CompositeTensorVariantToComponents` operation.
    *grad: One gradient per output of `op`; any entry may be `None`.

  Returns:
    A variant tensor produced by `CompositeTensorVariantFromComponents`.

  Raises:
    ValueError: If the number of gradients differs from the number of outputs
      of `op`.
  """
  # Explicit check rather than `assert`: under `python -O` an assert would be
  # stripped and `zip` below would silently truncate on a length mismatch.
  if len(grad) != len(op.outputs):
    raise ValueError(
        f"Expected {len(op.outputs)} gradients for "
        f"CompositeTensorVariantToComponents, got {len(grad)}.")
  # `components` is `op.outputs`, but with any tensors for which we're
  # taking the gradient replaced by the corresponding value from `grad`.
  # NOTE(review): forward outputs (not zeros) stand in for missing gradients,
  # presumably so the re-encoded components still match `metadata` — confirm.
  components = [
      output if g is None else g for output, g in zip(op.outputs, grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=op.get_attr("metadata"))