blob: 1c6f3de8023878f756f9fd011edcd87221a596d8 [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for TPU."""
import contextlib
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.tpu import tpu_replication
def enclosing_tpu_context():
  """Returns the enclosing TPUReplicateContext, or None if there is none.

  This is the first element of the pair returned by
  `enclosing_tpu_context_and_graph()`; the graph element is discarded.
  """
  tpu_ctx, _ = enclosing_tpu_context_and_graph()
  return tpu_ctx
def enclosing_tpu_context_and_graph():
  """Finds the enclosing TPUReplicateContext and the graph that holds it.

  Starting from the default graph, walks each graph's chain of control flow
  contexts (via `outer_context`) looking for a
  `tpu_replication.TPUReplicateContext`. If the chain is exhausted, the search
  continues in the graph's `outer_graph`, when it has one.

  Returns:
    A `(context, graph)` tuple for the first match, or `(None, None)` if no
    graph in the chain has a TPUReplicateContext.
  """
  current_graph = ops.get_default_graph()
  while current_graph is not None:
    candidate = current_graph._get_control_flow_context()  # pylint: disable=protected-access
    # Climb outward until we either hit a TPUReplicateContext or run out.
    while candidate is not None and not isinstance(
        candidate, tpu_replication.TPUReplicateContext):
      candidate = candidate.outer_context
    if candidate is not None:
      return candidate, current_graph
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    current_graph = getattr(current_graph, "outer_graph", None)
  return None, None
@contextlib.contextmanager
def outside_or_skip_tpu_context():
  """Context manager that steps outside the enclosing TPUReplicateContext.

  If `enclosing_tpu_context_and_graph()` finds no TPUReplicateContext, this is
  a no-op. Otherwise, for the duration of the `with` body, the control flow
  context of the graph holding the TPUReplicateContext is replaced with that
  context's `outer_context`. The previously active control flow context is
  put back on exit, including when the body raises.

  Yields:
    None.
  """
  ctx, graph = enclosing_tpu_context_and_graph()
  if ctx is None:
    yield
    return

  saved_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  graph._set_control_flow_context(ctx.outer_context)  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Restore unconditionally: without this, an exception raised inside the
    # `with` body would leave the graph stuck on the outer context.
    graph._set_control_flow_context(saved_context)  # pylint: disable=protected-access
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  """Enters `tensor.graph` as the default graph only when that is needed.

  The body runs in the current scope when executing eagerly, when `tensor` is
  an `ops.EagerTensor`, or when a default graph is already set. Otherwise it
  runs under `tensor.graph.as_default()`.

  Args:
    tensor: the object whose `graph` attribute is entered when required.

  Yields:
    None.
  """
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  stay_in_current_scope = (
      context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph())
  if stay_in_current_scope:
    yield
    return

  with tensor.graph.as_default():
    yield
@contextlib.contextmanager
def _maybe_on_device(var):
  """Adds an `ops.device(var.device)` scope when `var` is a packed variable.

  Args:
    var: the variable being operated on. Only instances of
      `packed.PackedVarAndDevice` get a device scope; anything else runs the
      body with no extra scope.

  Yields:
    None.
  """
  if not isinstance(var, packed.PackedVarAndDevice):
    yield
    return

  # Add a device scope for packed variables.
  with ops.device(var.device):
    yield
def make_raw_assign_fn(raw_assign_fn, use_handle=True):
  """Wrap `raw_assign_fn` with the proper graph context and device scope.

  Args:
    raw_assign_fn: the function to be wrapped. It is called as
      `raw_assign_fn(target, value_tensor, name=name)`.
    use_handle: if True, the `raw_assign_fn` will be applied to the handle of a
      variable; otherwise it will be applied to the variable itself.

  Returns:
    The wrapped function, with signature
    `assign_fn(var, value, use_locking=False, name=None, read_value=True)`.
    `use_locking` is ignored. When `read_value` is True the wrapper returns the
    variable's value, read under a control dependency on the assign op;
    otherwise it returns the assign op itself.
  """

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    del use_locking  # Unused.

    target = var.handle if use_handle else var
    with _maybe_enter_graph(target), _maybe_on_device(var):
      value_tensor = ops.convert_to_tensor(value, dtype=var.dtype)
      assign_op = raw_assign_fn(target, value_tensor, name=name)
      # Anything created below is ordered after the assignment.
      with ops.control_dependencies([assign_op]):
        if not read_value:
          return assign_op
        if use_handle:
          return var._read_variable_op()  # pylint: disable=protected-access
        return var.read_value()

  return assign_fn
def make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wrap `raw_scatter_xxx_fn` so that it can be called w/ and w/o packed handle.

  Args:
    raw_scatter_xxx_fn: the function to be wrapped. It is called as
      `raw_scatter_xxx_fn(handle, indices, values, name=name)`.

  Returns:
    The wrapped function, with signature
    `scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None)`.
    `use_locking` is ignored. The wrapper returns the variable's value, read
    under a control dependency on the scatter op.
  """

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    var_handle = var.handle
    with _maybe_enter_graph(var_handle), _maybe_on_device(var):
      indices = sparse_delta.indices
      updates = ops.convert_to_tensor(sparse_delta.values, var.dtype)
      scatter_op = raw_scatter_xxx_fn(var_handle, indices, updates, name=name)
      # Read the variable only after the scatter update has been applied.
      with ops.control_dependencies([scatter_op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn