# blob: d99491abca37aa8755c9d762499b78384e40eb96 [file] [log] [blame]
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Autograph specific overrides for objects covered by tensor_util.is_tf_type."""
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as parallel_ops
def wrap_py_func(f, args, kwargs=None):
  """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments are captured when the wrapper is built, so they will not update
  every time the wrapper is called (this is consistent with its argument list,
  which only includes the tensor arguments). In general, it's safest not to
  reuse this wrapper.

  Args:
    f: Callable
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.

  Returns:
    The result of `script_ops.eager_py_func` applied to the wrapper, declared
    with dtype int32. The wrapper calls `f` for its side effects only: the
    return value of `f` is discarded and the wrapper always returns 1.
  """
  # Treat a missing kwargs dict as empty so that every lookup table below is
  # defined on all paths.
  kwargs = kwargs or {}
  kwarg_keys = tuple(kwargs.keys())

  arg_is_tensor = tuple(tensor_util.is_tf_type(a) for a in args)
  kwarg_is_tensor = {k: tensor_util.is_tf_type(kwargs[k]) for k in kwarg_keys}

  # Only the tensor arguments are passed through the py_func. `slot_of` maps
  # the original position (int, for positional arguments) or name (str, for
  # keyword arguments) to the index of that tensor in `passthrough`.
  passthrough = []
  slot_of = {}
  for i, (a, is_tensor) in enumerate(zip(args, arg_is_tensor)):
    if is_tensor:
      slot_of[i] = len(passthrough)
      passthrough.append(a)

  # We essentially take the tensor kwargs, if any, and add them to the list of
  # positional arguments. The kwargs are then reconstructed inside the py_func.
  #
  # For example, if
  #
  #     args = [Tensor(1), 'foo']
  #     kwargs = {'a': Tensor(2), 'b': 'bar'}
  #
  # Then
  #
  #     passthrough = [Tensor(1), Tensor(2)]
  #     kwarg_keys = ('a', 'b')
  for k in kwarg_keys:
    if kwarg_is_tensor[k]:
      slot_of[k] = len(passthrough)
      passthrough.append(kwargs[k])

  def f_wrapper(*tensor_values):
    # Rebuild the original call: tensor slots come from the values handed in
    # by the py_func, everything else from the captured args/kwargs.
    f_args = tuple(
        tensor_values[slot_of[i]] if arg_is_tensor[i] else a
        for i, a in enumerate(args)
    )
    f_kwargs = {
        k: tensor_values[slot_of[k]] if kwarg_is_tensor[k] else kwargs[k]
        for k in kwarg_keys
    }
    f(*f_args, **f_kwargs)
    # A fixed value matching the int32 output dtype declared below.
    return 1

  return script_ops.eager_py_func(f_wrapper, passthrough, dtypes.int32)
def _tf_py_func_print(*objects, **kwargs):
  """Implements the print_ overload by running `print` inside a py_func.

  Keyword arguments left at `py_builtins.UNSPECIFIED` are dropped; the rest are
  forwarded to `print`. `flush` defaults to True when the caller did not set it.
  """
  print_options = {}
  for name, value in kwargs.items():
    if value is not py_builtins.UNSPECIFIED:
      print_options[name] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_options.setdefault('flush', True)

  def _to_numpy(value):
    if tensor_util.is_tf_type(value):
      return value.numpy()
    return value

  def _to_text(value):
    # TensorFlow doesn't seem to generate Unicode when passing strings to
    # py_func. This causes the print to add a "b'" wrapper to the output,
    # which is probably never what you want.
    if isinstance(value, bytes):
      return value.decode('utf-8')
    return value

  def print_wrapper(*vals, **print_kwargs):
    # Two separate passes: first unwrap every tensor, then decode any bytes.
    unwrapped = [_to_numpy(v) for v in vals]
    printable = [_to_text(v) for v in unwrapped]
    print(*printable, **print_kwargs)

  return wrap_py_func(print_wrapper, objects, print_options)
def _tf_sorted(iterable, key, reverse):
  """Overload of sorted_ for Tensor iterable.

  Args:
    iterable: The tensor to sort. Only rank-1 tensors are accepted.
    key: Either `py_builtins.UNSPECIFIED`, or a callable that is passed to
      `parallel_ops.vectorized_map` together with `iterable` to compute the
      sort keys.
    reverse: Either `py_builtins.UNSPECIFIED`, or a value whose truthiness
      selects descending (truthy) or ascending (falsy) order, mirroring the
      `reverse` argument of Python's built-in `sorted`. The sort direction is
      fixed when this function runs, so the value must support `bool()` at
      that point.

  Returns:
    `iterable` sorted in the requested direction. When `key` is given, the
    elements of `iterable` are gathered in the argsort order of the keys.

  Raises:
    ValueError: If the statically known rank of the values being sorted (the
      mapped keys when `key` is given, otherwise `iterable`) is not 1.
  """
  rank_error = 'sort only supports only 1D tensors'

  # An explicit falsy `reverse` (e.g. reverse=False) must sort ascending; only
  # a specified, truthy value requests descending order.
  descending = reverse is not py_builtins.UNSPECIFIED and bool(reverse)
  direction = 'DESCENDING' if descending else 'ASCENDING'

  has_key = key is not py_builtins.UNSPECIFIED
  sort_keys = parallel_ops.vectorized_map(key, iterable) if has_key else iterable

  # Fail early when the rank is statically known; otherwise the assert op in
  # the control dependency below performs the same check.
  static_rank = sort_keys.shape.rank
  if static_rank is not None and static_rank != 1:
    raise ValueError(rank_error)

  with ops.control_dependencies(
      [check_ops.assert_rank_v2(sort_keys, 1, rank_error)]
  ):
    if has_key:
      order = sort_ops.argsort(sort_keys, direction=direction)
      return array_ops.gather_v2(iterable, order)
    return sort_ops.sort(iterable, direction=direction)
# Module-level side effect: register the overloads defined in this file with
# the py_builtins registries, keyed by tensor_util.tf_type_classes.
# NOTE(review): presumably these registries dispatch print_ / sorted_ calls on
# the argument type -- confirm against py_builtins.
py_builtins.print_registry.register(
tensor_util.tf_type_classes, _tf_py_func_print
)
# Same registration for the sorted_ overload.
py_builtins.sorted_registry.register(
tensor_util.tf_type_classes, _tf_sorted
)