blob: 8a2e69e46961c5d290cb9b9c66d62216d30a4e78 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""A utility to trace tensor values on TPU."""
import collections
import hashlib
import operator
import os
import os.path
import sys
import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_case
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import analytics
from tensorflow.python.platform import gfile
from tensorflow.python.platform import remote_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import training_util
# Device types accepted by TensorTracer.check_device_type.
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'

# NOTE(review): presumably the number of elements kept per tensor by the
# part-tensor trace mode; its consumer is not visible here — confirm.
_TRACE_MODE_PART_TENSOR_SIZE = 3

# Reason tags for ops/tensors that are NOT traced.
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

# Reason tags for ops/tensors that ARE traced.
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'

# Output streams, graph collections, storage and file names.
_OUTPUT_STREAM_ESCAPE = 'file://'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
# Fill value used to initialize tensor tracer cache variables.
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '
_SKIP_REPORT_FILE = 'None'  # Do not write report proto if --report_file=None

# Local aliases for the summary signature names defined in tensor_tracer_flags.
_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE
_TT_SUMMARY_SPARSITY = tensor_tracer_flags.TT_SUMMARY_SPARSITY

# Names used for summaries, the TensorBoard plugin, host calls and events.
_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'
_TT_SUMMARY_MAX_QUEUE = 10

# Boolean gauge recording Tensor Tracer API usage (label: 'method').
tt_gauge = monitoring.BoolGauge(
    '/tensorflow/api/tensor_tracer/v1', 'tensor tracer usage', 'method')
def _graph_summary_tag(graph):
"""Generates and returns a summary tag name for the given graph."""
if graph is None:
raise RuntimeError('graph is None')
# The chance of collision with md5 is effectively 0.
hash_id = hashlib.md5()
hash_id.update(repr(graph).encode('utf-8'))
# hexdigest() returns a string.
return hash_id.hexdigest()
def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and sets its parameters.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This sets up the parameters for tensor tracer. On CPUs and GPUs, debugging
  also requires a call to tensor tracer as shown below. On TPUs this call can
  be skipped, because it is hooked into tpu.rewrite.
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)

  The parameters are serialized as space-separated '--key=value' flags. The
  result is stored in the environment variable named by
  tensor_tracer_flags.FLAGS_ENV_VAR, replacing any previous value. Values are
  not quoted or escaped. NOTE(review): values that contain whitespace are
  therefore presumably not parsed correctly; confirm against the flag parser.

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. Examples of
      these parameters are given below; see tensor_tracer_flags.py for all
      parameters.
      - enable: If set, tensor tracer will be enabled. set_parameters always
        adds --enable=1, so this does not need to be passed explicitly.
      - trace_mode: The trace_mode to be used by tensor tracer. These include:
        - summary: Collects multiple statistics for traced tensors and writes
          them into a summary file that can be visualized using tensorboard.
          This mode currently only works for TPUEstimator. It can also be
          used for other models, but outfeed must be handled by the user.
        - norm: Collects the norm of each traced tensor and writes them into
          a text file pointed to by the 'trace_dir' flag. (Default mode).
        - nan-inf: Checks for the existence of NaNs and Infs in the tensor,
          and writes a boolean value to a text file pointed to by the
          'trace_dir' flag. Note that 'norm' mode can also capture this
          information with more numerical info.
        - max-abs: Collects the absolute max of each traced tensor and writes
          it into a text file pointed to by the 'trace_dir' flag.
        - full-tensor: Writes the full tensor content of the traced tensors
          into a text file pointed to by the 'trace_dir' flag.
        - part-tensor: Writes a part of the tensor content of the traced
          tensors into a text file pointed to by the 'trace_dir' flag.
        - full_tensor_summary: Writes the full tensors as binary event files.
          The outputs can be read using:
            trace = tensor_tracer.read_tensor_tracer_event_file(
                event_file_path)
      - report_file: Path to the metadata file that is written during graph
        construction. If not set, metadata will be printed to stdout during
        graph construction.
      - trace_dir: Path where the execution traces will be written during
        graph execution. If not set, traces will be printed to stderr.
      - trace_level: Tensor tracer aims to trace everything it can. This
        introduces some overhead on graph execution and graph compilation
        times. The trace_level parameter selects operations to trace based
        on their priorities (see op_priority). For example:
        - trace_level=7 is the highest trace_level, in which every op is
          traced.
        - trace_level=6 skips constant operations such as tf.constant.
        - trace_level=5 also skips less important ops such as tf.identity.
        - The default trace_level=3 also skips concat ops and random number
          generators.
        - trace_level=0 further skips additions, subtractions, and
          multiplications. Because op_priority assigns every op type a
          priority of at least 1, this effectively disables priority-based
          tracing. Use it to reduce graph compile time overhead.
      - excluded_opnames: If set, any matching op name will not be traced.
        excluded_opnames can be set as a regular expression. E.g.,
        excluded_opnames=.* will exclude everything.
      - excluded_optypes: If set, any matching op type will not be traced.
        excluded_optypes can be set as a regular expression. E.g.,
        excluded_optypes=.* will exclude everything, and
        excluded_optypes=MatMul will exclude all MatMul ops from tracing.
      - included_opnames: If set, any matching op name will be forced to be
        traced. included_opnames can be set as a regular expression. E.g.,
        '--included_opnames=some_op --excluded_opnames=.*' will only trace
        some_op.
      - included_optypes: If set, any matching op type will be forced to be
        traced. included_optypes can be set as a regular expression. E.g.,
        '--included_optypes=some_op_type --excluded_optypes=.*' will trace
        only the ops with type 'some_op_type'.
      - flush_summaries: If summary mode is used, flush_summaries=1 will
        flush summaries using outside compilation. Note that, if used with
        low level APIs, flush_summaries=1 is necessary to obtain results.
      Advanced Flags:
      - trace_scalar: Scalar values are not traced by default. If this flag
        is set, scalar values will also be traced.
      - op_range: In the form of '%d:%d', limits tracing to the ops within
        this range. --op_range='5:10' will trace only the ops whose
        topological order is between 5 and 10.
      - submode: 'brief' or 'detailed'. If the trace mode is not compact,
        brief mode prints only the id of each traced tensor to save space,
        while 'detailed' mode prints the full tensor name.
      - use_fingerprint_subdirectory: If set, traces are written to a
        subdirectory of trace_dir that is chosen using the fingerprint of
        the trace metadata.
  """
  # The enable flag always comes first, followed by one flag per parameter.
  flags = ['--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE]
  if tensor_tracer_params:
    flags.extend('--%s=%s' % (key, value)
                 for key, value in tensor_tracer_params.items())
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = ' '.join(flags)
def op_priority(op_type):
  """Returns the trace priority of an op type.

  An op whose priority is k is traced only when trace_level >= k. A smaller
  number therefore means the op is traced at more trace levels.

  Args:
    op_type: String name of the operation type.

  Returns:
    An integer priority between 1 and 7. Op types not listed below get 2.
  """
  priority_groups = (
      # Lowest priority ops, e.g., constant ops across different steps;
      # traced only if trace_level >= 7.
      (7, ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
           'VariableShape', 'Fill', 'OneHot', 'ShapeN')),
      # Operations without numerical effects; traced if trace_level >= 6.
      (6, ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
           'PreventGradient', 'Squeeze', 'Gather', 'GatherNd')),
      # Operations that merge or slice an input; traced if trace_level >= 5.
      (5, ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
           'CollectivePermute', 'SplitV', 'DynamicPartition')),
      # Operations less likely to provide useful information; traced if
      # trace_level >= 4.
      (4, ('Pad', 'RandomUniformInt', 'GreaterEqual')),
      # Add operations that are less likely to create issues; traced if
      # trace_level >= 3 (the default).
      (3, ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum')),
      # Sub operations that are less likely to create issues; traced if
      # trace_level >= 2.
      (2, ('Neg', 'Sub')),
      # Multiplication and some other operations; traced if trace_level >= 1.
      (1, ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
           'Maximum', 'Mean', 'Variance', 'Exp', 'Rsqrt')),
  )
  for priority, op_types in priority_groups:
    if op_type in op_types:
      return priority
  # Unclassified op_types default to being traced at level 2 and above.
  return 2
def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict_list = tensor_tracer.read_tensor_tracer_event_file(
        event_file_path)
    for result_dict in result_dict_list:
      for step, tensor_dict in result_dict.items():
        for tensor_name, full_tensor_content in tensor_dict.items():
          logging.info('%s: %s', tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer events.

  Returns:
    A list of event dictionaries, each of the form
    {step_number: {tensor_name: tensor_content}}. This is a list instead of
    a single event dictionary because an event file may contain multiple
    event traces, each covering the same step range. The k-th time a step
    number is seen, its tensors go into the k-th dictionary. Tensor contents
    are numpy arrays decoded from the raw `tensor_content` bytes.

  Raises:
    ValueError: If an unexpected trace is found.
  """
  # Keeps track of how many times a step number shows up in these events.
  step_occurrence_count = collections.defaultdict(int)
  # List of step occurrences.
  step_occurrence_list = []
  for trace_event in summary_iterator.summary_iterator(event_file):
    # First event is an event with file_version: "brain.Event:2"
    if not trace_event.HasField('summary'):
      continue
    if len(trace_event.summary.value) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(trace_event.summary.value))
    step = trace_event.step
    step_occurrence_count[step] += 1  # a new occurrence for this step.
    occurrence_idx = step_occurrence_count[step] - 1
    occurrence_size = len(step_occurrence_list)
    if occurrence_idx == occurrence_size:
      # This occurrence isn't yet recorded on step_occurrence_list, so append
      # a new occurrence to the end of the list.
      step_occurrence_list.append(collections.defaultdict(dict))
    elif occurrence_idx > occurrence_size:
      # Otherwise the occurrence must already be recorded
      # (occurrence_idx < occurrence_size).
      raise ValueError('Unexpected: occurrence_idx (%d) > '
                       'occurrence_size (%d)' % (occurrence_idx,
                                                 occurrence_size))
    tensor_value = trace_event.summary.value[0]
    tensor_name = tensor_value.tag
    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
    # `as_numpy_dtype` is a property returning the numpy type itself. Calling
    # it would build a scalar instance, which only worked as a dtype by
    # accident (for object dtypes it yields None, i.e. float64).
    numpy_dtype = dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype
    tensor_content = np.frombuffer(
        tensor_value.tensor.tensor_content, numpy_dtype).reshape(real_shape)
    step_occurrence_list[occurrence_idx][step][tensor_name] = tensor_content
  return step_occurrence_list
def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This
  function can be used to limit the traced tensors: if it is called for a
  subset of the tensors, only those will be traced.

  For example, Tensor Tracer will only trace c below.
    c = tf.MatMul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)

  Args:
    tensor: the tensor object for which the tracing is requested.
    tracepoint_name: an optional tensor tracepoint name string. A tracepoint
      name is a Tensor Tracer internal name for the tensor. It is useful when
      comparing equivalent traces from different models that have different
      tensor namings. Equivalent tensors (with different names) can be mapped
      to each other by assigning a common tracepoint_name. Defaults to
      tensor.name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  # Registers (tensor, tracepoint_name) in the tensor's graph collection.
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, tracepoint_name))
  return tensor
def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates trace_tensor. A single-tensor output is traced under
  checkpoint_name. Otherwise, each tensor element of the output sequence is
  traced under '<checkpoint_name>_<position>'.

  Args:
    layer: A keras layer.
    checkpoint_name: a string name for the checkpoint. This name has to be a
      unique name if used within model comparison. The tensors that have the
      same checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tf_type(outputs):
      trace_tensor(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        # Check each element, not the container: in this branch `outputs` is
        # already known not to be a tensor.
        if tensor_util.is_tf_type(output_tensor):
          trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except (AttributeError, RuntimeError):
    # Best effort: layers without a usable symbolic output (e.g. not yet
    # connected) are returned untraced.
    pass
  return layer
class TensorTracer:
"""A software construct for tracing tensor values in a TF graph.
This utility is disabled by default. It is hooked into tpu.rewrite, so it can
easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as
below without a code change.
export TENSOR_TRACER_FLAGS="--enable=1"
Below is the use example to enable it on CPUs or GPUs, or for more advance use
cases on TPUs.
a = x + 1
b = a * 2
rs = tf.reduce_sum(b)
tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir',
'report_file: 'path/to/report/file'})
tt = tensor_tracer.TensorTracer()
if on_tpu:
rs = tt.trace_tpu(tf.get_default_graph(),
tensor_fetches=rs)
else:
rs = tt.trace_cpu(tf.get_default_graph(),
tensor_fetches=rs)
session.run(rs)
If it is enabled, it will trace the output tensor values of
selected Ops in the graph. It has two outputs: (1) the traces and (2)
a report. The traces are dumped to a specified directory during the graph
execution, while the report is dumped during the graph construction.
By passing options via the env variable, users can change:
(1) the trace mode (e.g., detecting NaN/Inf, printing partial or
full tensor values)
(2) which Ops to be traced (via op.name or op.type)
(3) output trace file path.
"""
# The set of graphs that are rewritten by tensor tracer.
_traced_graphs = set()
@staticmethod
def is_enabled():
  """Reports whether Tensor Tracer is turned on by its flags.

  Returns:
    The value of TTParameters().is_enabled(). Returns True if constructing
    or querying the parameters raises ValueError or RuntimeError (see TODO).
  """
  try:
    enabled = tensor_tracer_flags.TTParameters().is_enabled()
    if enabled:
      # Add metrics to determine API usage.
      tt_gauge.get_cell('is_enabled').set(True)
  except (ValueError, RuntimeError) as e:
    logging.warning(
        'Tensor Tracer V1 flags processing error encountered in is_enabled '
        'check. %s', e)
    # TODO(b/210212559): Find a more robust fix.
    # Should only produce exception if Tensor Tracer is enabled.
    return True
  return enabled
@staticmethod
def check_device_type(device_type):
  """Validates that device_type is one of the supported device types.

  Args:
    device_type: Device type string to validate.

  Raises:
    ValueError: If device_type is neither _DEVICE_TYPE_TPU nor
      _DEVICE_TYPE_CPU.
  """
  supported_types = (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU)
  if device_type in supported_types:
    return
  raise ValueError('Invalid device_type "%s"' % device_type)
@staticmethod
def check_trace_mode(device_type, trace_mode):
  """Checks that the given trace mode works on the given device type.

  Currently only full_tensor_summary is restricted: it requires a TPU.

  Args:
    device_type: Device type, TPU, GPU, CPU.
    trace_mode: Tensor tracer trace mode.

  Raises:
    ValueError: If the given trace mode is not supported for the device.
  """
  is_full_summary = (
      trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
  if is_full_summary and device_type != _DEVICE_TYPE_TPU:
    raise ValueError('Device_type "%s" is not yet supported for '
                     'trace mode "%s"' % (device_type, trace_mode))
@staticmethod
def loop_cond_op(op):
return op.type in ('LoopCond', 'RefLoopCond')
@staticmethod
def while_loop_op(op):
  """Returns a truthy value if op is one of the special ops of a while loop.

  Args:
    op: A tf.Operation.

  Returns:
    Truthy if the given op is one of [Switch, Merge, Enter, Exit,
    NextIteration, LoopCond] as used inside a while loop. These are all
    building blocks for TF while loops. The checks run in order, and the
    first truthy result is returned as-is.
  """
  loop_predicates = (control_flow_util.IsLoopSwitch,
                     control_flow_util.IsLoopMerge,
                     control_flow_util.IsLoopEnter,
                     control_flow_util.IsLoopExit,
                     TensorTracer.loop_cond_op)
  for predicate in loop_predicates:
    result = predicate(op)
    if result:
      return result
  return op.type in ('RefNextIteration', 'NextIteration')
@staticmethod
def control_flow_op(op):
  """Returns True if op is a Switch or Merge op.

  Unlike while_loop_op, this does not require the op to belong to a while
  loop. It relies on control_flow_util.IsSwitch and IsMerge, so cond-style
  Switch/Merge ops are included too.

  Args:
    op: A tf.Operation.

  Returns:
    The result of control_flow_util.IsSwitch(op) or
    control_flow_util.IsMerge(op).
  """
  return (control_flow_util.IsSwitch(op) or
          control_flow_util.IsMerge(op))
@staticmethod
def unsafe_op(op):
"""Returns True if this op is not safe to be traced."""
# Reasons for not including following op types:
# Assign: cause incorrect result with CPU tracing.
if op.type == 'Assign':
return True
return False
@staticmethod
def device_mismatch(device_type, op):
  """Returns True if op does not belong to the device being traced.

  For TPU tracing, an op without the TPU replicate attribute in its NodeDef
  is a mismatch. Other device types never mismatch.
  """
  if device_type != _DEVICE_TYPE_TPU:
    return False
  # pylint: disable=protected-access
  replicate_attr = tpu_replication._TPU_REPLICATE_ATTR
  # pylint: enable=protected-access
  return replicate_attr not in op.node_def.attr
@staticmethod
def unsafe_scalar_trace(op):
"""Return true if scalar output tensor from Op is not safe to be traced."""
# Tracing the following causes cycle in the graph on TPU.
if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
'Switch', 'Less', 'ReadVariableOp'):
return True
# Tracing the following will cause casting-issue
# with the norm tracing mode or other compilation issues on CPU.
if op.type in ('VarHandleOp', 'IteratorToStringHandle',
'IteratorGetNext', 'OneShotIterator',
'IteratorV2', 'MakeIterator',
'BatchDatasetV2', 'MapDataset',
'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
return True
return False
def _is_interesting_op(self, op):
  """Returns True if the given op is interesting enough to be traced.

  An op is interesting when its priority, as given by op_priority, does not
  exceed the configured trace_level.
  """
  return op_priority(op.type) <= self._parameters.trace_level
@staticmethod
def reason(op_idx, details):
"""Returns reason why the Op at op_idx is traced or not."""
return '%d %s'%(op_idx, details)
def __init__(self):
  """Initializes a TensorTracer.

  Reads configuration through TensorTracerConfig and TTParameters (flags if
  given, defaults otherwise) and starts with empty per-graph caches.
  """
  # Configuration objects (constructed in this order: config, then flags).
  self._tt_config = tensor_tracer_report.TensorTracerConfig()
  self._parameters = tensor_tracer_flags.TTParameters()
  self._replica_id = None
  self._host_call_fn = {}
  # Per-graph caches.
  # _cache_variables: key = graph, value = dict (key = name, value = tensors).
  self._cache_variables = {}
  self._history_value_cache = {}
  # _temp_cache_var: key = graph, value = list.
  self._temp_cache_var = {}
  # Tracing bookkeeping and report output.
  self._traced_op_names = set()
  self._report_proto = None
  self._report_proto_path = ''
  self._outmost_context = None
def report_proto(self):
"""Getter for tensor_tracer.proto object for summary and full_tensor_summary modes.
Returns:
A tensor_tracer.proto object.
Raises:
ValueError if called before tracing happens, or when trace mode is not
summary or full_tensor_summary.
"""
if self._report_proto:
return self._report_proto
else:
raise ValueError('Call to report_proto must be done after tracing.'
'Report proto only exists for '
'trace_mode=[summary|full_tensor_summary]')
def report_proto_path(self):
"""Getter for path where tensor_tracer.proto object should be written.
Returns:
A string path.
"""
return self._report_proto_path
def _escape_namescopes(self, variable_name):
return variable_name.replace('/', '_').replace(':', '_')
def _cache_variable_for_graph(self, graph):
if graph not in self._cache_variables:
self._cache_variables[graph] = {}
return self._cache_variables[graph]
def _create_or_get_tensor_history_values_cache(self,
                                               cache_name,
                                               graph,
                                               shape=None,
                                               dtype=dtypes.float32):
  """Returns the variable that caches historic intermediate tensor values.

  Variables are memoized per graph and cache_name. On first request a
  non-trainable resource variable named 'tt_history_<escaped cache_name>'
  is created in the graph's root name scope. It is filled with
  _COMPACT_TRACE_ENTRY_INIT_VALUE and added to the _TENSOR_TRACER_STORAGE
  and LOCAL_VARIABLES collections.

  Args:
    cache_name: Name to be given to the cache (an instance of tf.variable).
    graph: Tensorflow graph.
    shape: A list of dimensions.
    dtype: Data type of created cache.

  Returns:
    A ref to newly created or existing cache with the given dimensions.

  Raises:
    ValueError:
      (1) If graph is None, or
      (2) shape is None when a new cache needs to be created.
  """
  if graph is None:
    raise ValueError('Invalid graph.')
  graph_history = self._history_value_cache.setdefault(graph, {})
  if cache_name in graph_history:
    return graph_history[cache_name]
  if shape is None:
    raise ValueError('shape must be provided at cache creation.')
  # Integer caches need an integer fill value.
  fill_value = (int(_COMPACT_TRACE_ENTRY_INIT_VALUE) if dtype.is_integer
                else _COMPACT_TRACE_ENTRY_INIT_VALUE)
  variable_name = 'tt_history' + '_' + self._escape_namescopes(cache_name)
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    graph_history[cache_name] = variable_scope.get_variable(
        variable_name,
        shape=shape,
        dtype=dtype,
        initializer=init_ops.constant_initializer(fill_value),
        trainable=False,
        use_resource=True,
        collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
  return graph_history[cache_name]
def _create_or_get_tensor_values_cache(self, cache_name, graph,
                                       shape=None, dtype=dtypes.float32):
  """Returns the variable that caches intermediate tensor values.

  Variables are memoized per graph (via _cache_variable_for_graph) and
  cache_name. On first request a non-trainable resource variable named
  '<_TT_SNAPSHOT>_<escaped cache_name>' is created in the graph's root name
  scope. It is filled with _COMPACT_TRACE_ENTRY_INIT_VALUE and added to the
  _TENSOR_TRACER_STORAGE and LOCAL_VARIABLES collections.

  Args:
    cache_name: Name to be given to the cache (an instance of tf.variable).
    graph: Tensorflow graph.
    shape: A list of dimensions.
    dtype: Data type of created cache.

  Returns:
    A ref to newly created or existing cache with the given dimensions.

  Raises:
    ValueError:
      (1) If graph is None, or
      (2) shape is None when a new cache needs to be created.
  """
  if graph is None:
    raise ValueError('Invalid graph.')
  graph_cache = self._cache_variable_for_graph(graph)
  if cache_name in graph_cache:
    return graph_cache[cache_name]
  if shape is None:
    raise ValueError('shape must be provided at cache creation.')
  # Integer caches need an integer fill value.
  fill_value = (int(_COMPACT_TRACE_ENTRY_INIT_VALUE) if dtype.is_integer
                else _COMPACT_TRACE_ENTRY_INIT_VALUE)
  variable_name = _TT_SNAPSHOT + '_' + self._escape_namescopes(cache_name)
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    graph_cache[cache_name] = variable_scope.get_variable(
        variable_name,
        shape=shape,
        dtype=dtype,
        initializer=init_ops.constant_initializer(fill_value),
        trainable=False,
        use_resource=True,
        collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
  return graph_cache[cache_name]
def _add_replica_id_to_graph(self):
"""Adds nodes for computing the replica ID to the graph."""
if self._tt_config.num_replicas:
with ops.control_dependencies(None):
# Uses None as dependency to run outside of TPU graph rewrites.
self._replica_id = tpu_ops.tpu_replicated_input(
list(range(self._tt_config.num_replicas)),
name='tt_replica_id')
else:
self._replica_id = 'unknown'
def _inside_op_range(self, idx):
"""Return True if the given index is inside the selected range."""
if idx < self._parameters.op_range[0]:
return False
return (self._parameters.op_range[1] < 0 or
idx <= self._parameters.op_range[1])
def _is_user_included_op(self, op):
"""Checks whether the op is included in the tensor tracer flags.
Args:
op: tf Operation
Returns:
True, if the op is included.
An op is included if:
- Its op name is given in included_opnames
- Its op type is given in included_optypes
- The op is at most _trace_ops_before_included hops before an included op
- The op is at most _trace_ops_after_included hops after an included op
"""
for opname_re in self._parameters.included_opname_re_list:
if opname_re.match(op.name):
return True
for optype_re in self._parameters.included_optype_re_list:
if optype_re.match(op.type):
return True
return False
def _is_user_excluded_op(self, op):
for opname_re in self._parameters.excluded_opname_re_list:
if opname_re.match(op.name):
return True
for optype_re in self._parameters.excluded_optype_re_list:
if optype_re.match(op.type):
return True
return False
def _signature_types(self):
  """Returns a dictionary holding the order of signatures in the cache for the selected trace mode.

  Single-signature modes (nan-inf, norm, history, max-abs) map the mode to
  index 0. Summary mode uses the configured summary_signatures. Any other
  mode returns an empty dict.
  """
  mode = self._parameters.trace_mode
  single_signature_modes = {
      tensor_tracer_flags.TRACE_MODE_NAN_INF,
      tensor_tracer_flags.TRACE_MODE_NORM,
      tensor_tracer_flags.TRACE_MODE_HISTORY,
      tensor_tracer_flags.TRACE_MODE_MAX_ABS,
  }
  if mode in single_signature_modes:
    return {mode: 0}
  if mode != tensor_tracer_flags.TRACE_MODE_SUMMARY:
    return {}
  return self._parameters.summary_signatures
def _num_signature_dimensions(self):
return len(self._signature_types())
def _use_temp_cache(self):
"""Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable.
Returns:
A boolean, denoting whether to use a temporary cache or not.
"""
# If full tensors need to be stored tf.variables, then do not use temp
# variables to store them.
if self._use_tensor_buffer():
return False
if self._use_tensor_values_cache():
return self._parameters.use_temp_cache_var
else:
# Temporary caches only replaces tf.Variables caches. If no cache is used
# return False.
return False
def _use_tensor_values_cache(self):
"""Returns True if immediate tensors should be first saved to a cache."""
return self._parameters.use_compact_trace
def _use_tensor_buffer(self):
  """Returns True in full_tensor_summary mode, where whole tensors are buffered."""
  full_summary_mode = tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY
  return self._parameters.trace_mode == full_summary_mode
def _merge_tensor_signatures(self, signatures):
"""Returns a tensor that merges the given signatures.
Args:
signatures: A dictionary of the signature updates from signature name to
a tensor of dimension [1].
Returns:
A tensor that concats the signature values in a predefined order.
Raises:
ValueError: Unable to merge signatures.
"""
sorted_update = []
if self._num_signature_dimensions() > 1:
signature_indices = self._signature_types()
for _, val in sorted(signatures.items(),
key=lambda item: signature_indices[item[0]]):
sorted_update.append(val)
updates = array_ops_stack.stack(
sorted_update, axis=0, name='merge_single_op_signatures')
elif self._num_signature_dimensions() == 1:
# Avoid stack operation if there is only a single signature.
(_, val), = signatures.items()
updates = val
else:
raise ValueError('Cannot merge 0 signatures. Check the value passed for '
'flag --signatures.')
return updates
def _save_tensor_value_to_tmp_cache(self, cache_idx, updates, graph):
"""Returns an op that will save the given updates to an entry in the cache.
Args:
cache_idx: The cache index of the tensor within the cache.
updates: A dictionary of the signature updates from signature name to
a tensor of dimension [1].
graph: A TensorFlow graph.
Raises:
RuntimeError:
(1) graph is not already in self._temp_cache_var, or
(2) cache_idx is out of range.
"""
updates = self._merge_tensor_signatures(updates)
updates = array_ops.reshape(updates,
[self._num_signature_dimensions()])
if graph not in self._temp_cache_var:
raise RuntimeError('graph is not in self._temp_cache_var')
if cache_idx >= len(self._temp_cache_var[graph]):
raise RuntimeError('cache_idx (%d) is out of range (%d)' % (
cache_idx, len(self._temp_cache_var[graph])))
self._temp_cache_var[graph][cache_idx] = updates
def _save_tensor_value_to_cache_op(self, cache_idx, updates, graph):
  """Returns an op that will save the given updates to an entry in the cache.

  The cache is the _TT_SUMMARY_TAG values cache of graph. Row cache_idx is
  overwritten with the merged signatures.

  Args:
    cache_idx: The cache index of the tensor within the cache.
    updates: A dictionary of the signature updates.
    graph: A TensorFlow graph.

  Returns:
    Cache update operation.
  """
  # state_ops.scatter_update allows updates only along the first dimension,
  # so all signatures are packed into one [1, num_signatures] row and
  # written together.
  merged = self._merge_tensor_signatures(updates)
  row = array_ops.reshape(merged, [1, self._num_signature_dimensions()])
  row_indices = constant_op.constant([cache_idx])
  cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
  scatter = state_ops.scatter_update(cache, row_indices, row)
  return scatter.op
def _snapshot_tensor(self, tensor):
  """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable.

  The variable is the values cache keyed by tensor.name in tensor's graph,
  created with the tensor's static shape and dtype if it does not exist yet.

  Args:
    tensor: tensor whose values will be stored in a new tf.Variable.

  Returns:
    An assignment operation.
  """
  snapshot_variable = self._create_or_get_tensor_values_cache(
      tensor.name, tensor.op.graph,
      shape=tensor.shape.as_list(), dtype=tensor.dtype)
  assign_tensor = state_ops.assign(snapshot_variable, tensor)
  return assign_tensor.op
def _preprocess_traced_tensor(self, tensor):
  """Computes NAN/Norm/Max on TPUs before sending to CPU.

  Depending on self._parameters.trace_mode, the traced tensor is either
  passed through unchanged (part/full tensor modes) or reduced on the device
  to small signature tensors before it is handed to the trace function.

  Args:
    tensor: The tensor to be traced.
  Returns:
    A dict whose values should be input to the trace_function. Outside
    TRACE_MODE_SUMMARY the dict has one entry keyed by the trace mode; in
    TRACE_MODE_SUMMARY it maps each signature name from
    self._signature_types() to its signature tensor.
  Raises:
    RuntimeError: If the signature is invalid.
    ValueError: In summary mode, if a signature name is not recognized.
  """

  def _detect_nan_inf(tensor):
    """Trace function for detecting any NaN/Inf in the tensor.

    Produces [1.0] if any element is NaN or Inf and [0.0] otherwise.
    Non-floating tensors always produce [0.0].
    """
    if tensor.dtype.is_floating:
      mask = math_ops.reduce_any(
          gen_math_ops.logical_or(
              gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
      output_tensor = cond.cond(
          mask,
          lambda: constant_op.constant([1.0]),
          lambda: constant_op.constant([0.0]))
    else:
      output_tensor = constant_op.constant([0.0])
    return output_tensor

  def _compute_signature(tensor, tf_op, cast_to_f32=True):
    # Applies the reduction `tf_op`, optionally after casting to float32.
    if cast_to_f32:
      tensor = math_ops.cast(tensor, dtypes.float32)
    output_tensor = tf_op(tensor)
    # Return type should be scalar. Set it if it does not have the
    # information.
    if not output_tensor.get_shape().is_fully_defined():
      output_tensor = array_ops.reshape(output_tensor, [])
    return output_tensor

  def _show_size(tensor):
    # In order to check the size of a tensor.
    # Not all sizes are known at the compile time, also, different replicas
    # sometimes get different sizes of tensors.
    # Collect it here to be used in merging replica data.
    tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
    # Cast to float32, so that it can be placed into same cache with other
    # signatures.
    return math_ops.cast(tsize, dtypes.float32)

  def _show_max(tensor, cast_to_f32=True):
    # returns -inf for empty tensor
    return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)

  def _show_min(tensor, cast_to_f32=True):
    # returns inf for empty tensor
    return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)

  def _show_norm(tensor, cast_to_f32=True):
    # returns 0 for empty tensor
    return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)

  def _show_sparsity(tensor, cast_to_f32=True, tolerance=1e-06):
    # Fraction of elements whose magnitude is below `tolerance`.
    # returns nan for empty tensor and treats nans as non-zero numbers
    def sparsity_fn(tensor):
      non_zeros = math_ops.greater_equal(math_ops.abs(tensor), tolerance)
      nans = math_ops.is_nan(tensor)
      # zero_fraction of the boolean mask = share of "near-zero, not NaN".
      return nn_impl.zero_fraction(math_ops.logical_or(non_zeros, nans))

    return _compute_signature(tensor, sparsity_fn, cast_to_f32)

  def _show_mean_and_variance(tensor, cast_to_f32=True):
    """Returns the mean and variance of the given tensor."""
    if cast_to_f32:
      tensor = math_ops.cast(tensor, dtypes.float32)
    # returns nan for empty tensor
    # Moments are taken over the flattened tensor.
    mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
    # Results must have a fully defined (scalar) shape; reshape to [] when
    # the static shape information is missing.
    if not mean.get_shape().is_fully_defined():
      mean = array_ops.reshape(mean, [])
    if not var.get_shape().is_fully_defined():
      var = array_ops.reshape(var, [])
    return mean, var

  def _show_max_abs(tensor, cast_to_f32=True):
    return _compute_signature(
        tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32)

  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
    return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
  if (self._parameters.trace_mode ==
      tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
    return {self._parameters.trace_mode: tensor}
  if (self._parameters.trace_mode in (
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
    return {self._parameters.trace_mode: tensor}
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
    return {self._parameters.trace_mode: array_ops.reshape(
        _show_norm(tensor), [1])}
  # History mode records the same [1]-shaped norm as NORM mode.
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_HISTORY:
    return {self._parameters.trace_mode: array_ops.reshape(
        _show_norm(tensor), [1])}
  # NOTE(review): unlike NORM/HISTORY, the max-abs value is not reshaped to
  # [1] here; confirm consumers accept a scalar.
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
    return {self._parameters.trace_mode: _show_max_abs(tensor)}

  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
    # Cast once up front; the signature helpers below are then called with
    # cast_to_f32=False (a cast to the same dtype is a no-op for the others).
    tensor = math_ops.cast(tensor, dtypes.float32)
    result_dict = {}
    # Call mean and variance computation here to avoid adding the same nodes
    # twice.
    if (_TT_SUMMARY_MEAN in self._signature_types() or
        _TT_SUMMARY_VAR in self._signature_types()):
      mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)

    # Signatures are built in increasing order of their cache column index.
    for signature_name, _ in sorted(self._signature_types().items(),
                                    key=lambda x: x[1]):
      if signature_name == _TT_SUMMARY_NORM:
        signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_MAX:
        signature_result_tensor = _show_max(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_MAX_ABS:
        signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_MIN:
        signature_result_tensor = _show_min(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_SPARSITY:
        signature_result_tensor = _show_sparsity(tensor)
      elif signature_name == _TT_SUMMARY_SIZE:
        signature_result_tensor = _show_size(tensor)
      elif signature_name == _TT_SUMMARY_MEAN:
        signature_result_tensor = mean
      elif signature_name == _TT_SUMMARY_VAR:
        signature_result_tensor = variance
      else:
        raise ValueError('Unknown signature type :%s.' % signature_name)

      result_dict[signature_name] = signature_result_tensor
    return result_dict

  raise RuntimeError(
      'Unsupported signature for trace mode %s.'
      % self._parameters.trace_mode)
def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
  """Builds the trace function that outside compilation calls for a tensor.

  Args:
    tensor_name: name of the tensor being traced.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
  Returns:
    A one-argument function, to be passed as the first argument to outside
    compilation, that prints the tensor it receives.
  Raises:
    RuntimeError: If the trace mode has no printing trace function.
  """

  def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
    """Emits a print_v2 call that writes `output_tensor` to the trace output.

    Args:
      tensor_name: name of the traced tensor, used to label the output.
      num_elements: `summarize` value for print_v2; -1 prints all elements.
      tensor: the traced tensor; only used in the error message.
      output_tensor: the tensor whose shape and value are printed.
    Returns:
      The result of logging_ops.print_v2.
    Raises:
      ValueError: In brief mode, if tensor_name has no entry in
        tensor_trace_order.tensorname_to_cache_idx.
    """
    # Brief mode labels the line with the cache index instead of the name.
    if not self._parameters.is_brief_mode():
      msg = '"%s"' % tensor_name
    elif tensor_name in tensor_trace_order.tensorname_to_cache_idx:
      msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
    else:
      raise ValueError(
          'Tensor %s with name %s is not in the tensorname_to_cache_idx' %
          (tensor, tensor_name))

    trace_dir = self._parameters.trace_dir
    if trace_dir:
      output_stream = _OUTPUT_STREAM_ESCAPE + os.path.join(
          trace_dir, _TRACE_FILE_NAME + self._get_outfile_suffix())
    else:
      output_stream = sys.stderr
    return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                '@', self._replica_id,
                                '\n', output_tensor, '\n',
                                summarize=num_elements,
                                output_stream=output_stream)

  def _show_part_tensor(tensor):
    """Prints `tensor` summarized to _TRACE_MODE_PART_TENSOR_SIZE elements."""
    return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                         tensor, tensor)

  def _show_full_tensor(tensor):
    """Prints every element of `tensor`."""
    return _print_tensor(tensor_name, -1, tensor, tensor)

  trace_mode = self._parameters.trace_mode
  if trace_mode == tensor_tracer_flags.TRACE_MODE_PART_TENSOR:
    return _show_part_tensor

  # These modes print whatever reaches the host in full. For NAN_INF, NORM
  # and MAX_ABS that is already a value reduced on the device by
  # _preprocess_traced_tensor.
  full_print_modes = (
      tensor_tracer_flags.TRACE_MODE_NAN_INF,
      tensor_tracer_flags.TRACE_MODE_NORM,
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
      tensor_tracer_flags.TRACE_MODE_MAX_ABS,
      tensor_tracer_flags.TRACE_MODE_SUMMARY,
      tensor_tracer_flags.TRACE_MODE_HISTORY,
  )
  if trace_mode in full_print_modes:
    return _show_full_tensor

  raise RuntimeError('Full tensor support is not available with trace mode %s'
                     % trace_mode)
def _is_in_control_flow(self, op):
  """Returns true if the given op is inside a tf.cond.

  Only cond membership is checked here (control_flow_util.IsInCond). Whether
  the op lives in an inner tf.while_loop is checked separately with
  _is_in_outmost_while_loop.

  Args:
    op: A tensorflow op that should be checked for being inside a tf.cond.
  Returns:
    A boolean value whether the op is inside a tf.cond or not.
  """
  return control_flow_util.IsInCond(op)
def _is_in_outmost_while_loop(self, op):
  """Checks whether `op` belongs to the same while loop as the training loop.

  The innermost while context containing the op is compared with the one
  containing self._outmost_context. An op in an inner while loop, or outside
  the training loop, gives False.

  Args:
    op: tf.Operation
  Returns:
    A boolean.
  """
  containing_while = control_flow_util.GetContainingWhileContext
  op_loop = containing_while(self._get_op_control_flow_context(op))
  training_loop = containing_while(self._outmost_context)
  return op_loop == training_loop
def _should_trace_in_control_flow(self):
"""Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop."""
# As different from the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY
# forces the execution of the traced tensors. We should not trace the ops
# that may not be executed due to control flow.
if self._use_temp_cache():
return False
elif self._tt_config.device_type == _DEVICE_TYPE_TPU:
# On TPUs do not trace in control flow unless we use caches to store
# intermediate values as calling outside compilation within an inner loop
# causes errors.
return self._use_tensor_values_cache() or self._use_tensor_buffer()
return True
def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
  """Decides whether `op` is left untraced, recording the reason.

  Checks run in a fixed order and the first match decides; every returning
  path reports its reason through report_handler.instrument_op.

  Args:
    op_id: Topological index of the op.
    op: tf.Operation
    ops_in_exec_path: Set of operations that are in the execution path.
    report_handler: An instance of tensor_tracer_report.TTReportHandle.
  Returns:
    True if the op should not be traced, false otherwise.
  """

  def _record(reason):
    report_handler.instrument_op(op, TensorTracer.reason(op_id, reason))

  # Structural exclusions that no user filter can override, in order.
  hard_exclusions = (
      (lambda: TensorTracer.while_loop_op(op), _REASON_WHILELOOP_OP),
      (lambda: TensorTracer.control_flow_op(op), _REASON_CONTROLFLOW_OP),
      (lambda: TensorTracer.unsafe_op(op), _REASON_UNSAFE_OP),
      (lambda: TensorTracer.device_mismatch(self._tt_config.device_type, op),
       _REASON_DEVICE_MISMATCH),
      (lambda: op not in ops_in_exec_path, _REASON_NOT_EXECUTED),
  )
  for is_excluded, reason in hard_exclusions:
    if is_excluded():
      _record(reason)
      return True

  # Ops inside a tf.cond, or in a while context other than the training
  # loop, are skipped when tracing there is unsafe. With a temporary cache
  # the number of traced slots must be static, the taken cond branch is not
  # known ahead of time, and no data dependency may cross while contexts.
  nested_flow = (self._is_in_control_flow(op) or
                 not self._is_in_outmost_while_loop(op))
  if nested_flow and not self._should_trace_in_control_flow():
    _record(_REASON_IN_CONTROL_FLOW)
    return True

  # An explicit user include overrides the remaining filters.
  if self._is_user_included_op(op):
    _record(_REASON_USER_INCLUDED)
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_INCLUDED op %s', op.name)
    return False

  if not self._inside_op_range(op_id):
    _record(_REASON_OUTSIDE_OP_RANGE)
    return True
  if not self._is_interesting_op(op):
    _record(_REASON_LESS_INTERESTING_OP)
    return True
  if self._is_user_excluded_op(op):
    _record(_REASON_USER_EXCLUDED)
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_EXCLUDED op %s', op.name)
    return True
  return False
def _skip_tensor(self, op_id, out_tensor, report_handler):
  """Decides whether out_tensor is left untraced, recording the reason.

  Every returning path reports its reason through
  report_handler.instrument_tensor.

  Args:
    op_id: Topological index of the op producing tensor.
    out_tensor: tf.Tensor
    report_handler: An instance of tensor_tracer_report.TTReportHandle.
  Returns:
    True if the tensor should not be traced, false otherwise.
  """

  def _record(reason):
    report_handler.instrument_tensor(
        out_tensor, TensorTracer.reason(op_id, reason))

  # Only variant, resource and string tensors count as non-numeric here.
  # check_ops.is_numeric_tensor is not used because it would also reject
  # bool and float32_ref tensors, which should still be traced.
  non_numeric_dtypes = frozenset(
      [dtypes.variant, dtypes.resource, dtypes.string])
  if out_tensor.dtype in non_numeric_dtypes:
    _record(_REASON_NON_NUMERIC_TENSOR)
    return True

  # Tensors consumed by a special while-loop op are never traced.
  while_loop_consumers = [
      consumer for consumer in out_tensor.consumers()
      if TensorTracer.while_loop_op(consumer)]
  if while_loop_consumers:
    _record(_REASON_FEEDS_WHILELOOP_OP)
    return True

  if self._is_user_included_op(out_tensor.op):
    _record(_REASON_USER_INCLUDED)
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_INCLUDED tensor %s', out_tensor.name)
    return False

  if self._is_user_excluded_op(out_tensor.op):
    _record(_REASON_USER_EXCLUDED)
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_EXCLUDED tensor %s', out_tensor.name)
    return True

  if not out_tensor.get_shape().is_fully_defined():
    # In these modes _preprocess_traced_tensor reduces the tensor before the
    # outside compilation call, so a dynamic shape is acceptable.
    reduced_before_host = self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_HISTORY,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
        tensor_tracer_flags.TRACE_MODE_SUMMARY
    )
    if reduced_before_host:
      _record(_REASON_TENSOR_GET_TRACED)
      return False
    _record(_REASON_DYNAMIC_SHAPE)
    return True

  if len(out_tensor.shape) >= 1:
    _record(_REASON_TENSOR_GET_TRACED)
    return False

  # Rank-0 tensors: traced only when trace_scalar_ops is enabled and the
  # producing op is safe for scalar tracing.
  if not self._parameters.trace_scalar_ops:
    _record(_REASON_SKIP_SCALAR)
    return True
  if TensorTracer.unsafe_scalar_trace(out_tensor.op):
    _record(_REASON_UNSAFE_SCALAR)
    return True
  _record(_REASON_SCALAR_GET_TRACED)
  return False
def _filter_execution_path_operations(self, operations, fetches):
"""Returns the set of ops in the execution path to compute given fetches."""
# If no fetch provided, then return all operations.
if fetches is None:
return set(operations)
# Convert to list, if a single element is provided.
if not isinstance(fetches, (list, tuple)):
fetches = [fetches]
# If a tensor is given as fetch, convert it to op.
op_fetches = []
for fetch in fetches:
if isinstance(fetch, ops.Operation):
op_fetches.append(fetch)
elif isinstance(fetch, ops.Tensor):
op_fetches.append(fetch.op)
else:
raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
%fetch)
execution_path_operations = set(op_fetches)
traverse_stack = list(op_fetches)
while True:
if not traverse_stack:
break
head_op = traverse_stack.pop()
input_ops = [tensor_input.op for tensor_input in head_op.inputs]
input_ops.extend(head_op.control_inputs)
for input_op in input_ops:
if input_op not in execution_path_operations:
# Filter out loop condition operations, tracing them causes a cycle.
# Trace only the loop-body.
if TensorTracer.loop_cond_op(input_op):
continue
execution_path_operations.add(input_op)
traverse_stack.append(input_op)
return execution_path_operations
def _determine_and_instrument_traced_tensors(self, graph_order,
ops_in_exec_path,
tensor_trace_points,
report_handler):
"""Determines the tensors to trace and instruments the trace details.
Args:
graph_order: graph_order tuple containing graph (tf.graph), operations
(list of operations), op_to_idx (op id mapping), (tensors) list of
tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
there is a cycle in the graph), topological_order_or_cycle (list of ops
in topological order or list of ops creating a cycle).
ops_in_exec_path: Set of ops in the execution path.
tensor_trace_points: Collection of programatic tensor trace points.
report_handler: An instance of tensor_tracer_report.TTReportHandle.
Returns:
List of tensors to be traced.
"""
traced_tensors = []
checkpoint_operations = set([tensor.op
for (tensor, _) in tensor_trace_points])
for op_id, op in enumerate(graph_order.operations):
if checkpoint_operations and op not in checkpoint_operations:
continue
if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
continue
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
if not self._skip_tensor(op_id, out_tensor, report_handler):
traced_tensors.append(out_tensor)
return traced_tensors
def _check_trace_files(self):
"""Checks if any requirements for trace files are satisfied."""
if not self._parameters.trace_dir:
# traces will be written to stderr. No need to check trace files.
return
if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
# Output files are handled by tf.summary operations, no need to precreate
# them.
return
if not gfile.Exists(self._parameters.trace_dir):
file_io.recursive_create_dir(self._parameters.trace_dir)
if not gfile.Exists(self._parameters.trace_dir):
raise RuntimeError('Failed to create trace directory at %s' %
self._parameters.trace_dir)
def _create_temp_cache(self, num_traced_tensors, num_signatures, graph):
  """Initializes the per-graph temporary (non-variable) trace cache.

  self._temp_cache_var[graph] becomes a Python list with one slot per traced
  tensor. Every slot starts as the same float32 constant of shape
  [num_signatures] filled with _COMPACT_TRACE_ENTRY_INIT_VALUE. Slots are
  later replaced, not mutated, when a tensor's signatures are written.

  Args:
    num_traced_tensors: Int, denoting total number of traced tensors.
    num_signatures: Int, denoting the number of statistics collected per
      tensors.
    graph: TensorFlow graph.
  """
  initial_row = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                     dtype=dtypes.float32,
                                     shape=[num_signatures])
  slots = [initial_row for _ in range(num_traced_tensors)]
  self._temp_cache_var[graph] = slots
def _determine_trace_and_create_report(self, graph, ops_in_exec_path,
                                       graph_summary_tag):
  """Work needs to be done prior to TPU or CPU tracing.

  Selects the tensors to trace, creates the trace cache when compact tracing
  is used, and creates or writes the trace report.

  Args:
    graph: tf.graph
    ops_in_exec_path: Set of operations in the execution path.
    graph_summary_tag: the summary tag name for the given graph.
  Returns:
    An instance of tensor_tracer_report.TensorTraceOrder, containing list of
    tensors to be traced with their topological order information.
  Raises:
    RuntimeError: If opname filtering is incorrectly set, or when
      TT_CHECK_FILTER is enabled and some tensors were selected for tracing.
  """

  self._check_trace_files()

  graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
  tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)

  report_handler = tensor_tracer_report.TTReportHandle()
  traced_tensors = self._determine_and_instrument_traced_tensors(
      graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
  logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors))
  # With TT_CHECK_FILTER the include/exclude decisions have been logged by
  # _skip_op/_skip_tensor; stop here so they can be verified.
  if traced_tensors and tensor_tracer_flags.TT_CHECK_FILTER.value:
    raise RuntimeError('Verify ops being traced by tensor tracer.')

  tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                                              traced_tensors)
  num_signatures = self._num_signature_dimensions()
  # Create a cache variable if compact_tracing is used.
  if num_signatures and self._use_tensor_values_cache():
    if self._use_temp_cache():
      self._create_temp_cache(len(traced_tensors), num_signatures, graph)
    else:
      self._create_or_get_tensor_values_cache(
          _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
      # Compare with ==: `in (X)` without a trailing comma is a substring
      # test on the string X, not a membership test.
      if (self._parameters.trace_mode ==
          tensor_tracer_flags.TRACE_MODE_HISTORY):
        self._create_or_get_tensor_history_values_cache(
            _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
  if self._parameters.trace_mode in (
      tensor_tracer_flags.TRACE_MODE_SUMMARY,
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
    self._report_proto = report_handler.create_report_proto(
        self._tt_config, self._parameters, tensor_trace_order,
        tensor_trace_points, self._signature_types())
    if self._parameters.use_fingerprint_subdir:
      self._parameters.trace_dir = os.path.join(
          self._parameters.trace_dir, self._report_proto.fingerprint)
      logging.info('TensorTracer updating trace_dir to %s',
                   self._parameters.trace_dir)
    self._report_proto_path = report_handler.report_proto_path(
        self._parameters.trace_dir, graph_summary_tag)
    if self._parameters.report_file_path != _SKIP_REPORT_FILE:
      report_handler.write_report_proto(self._report_proto_path,
                                        self._report_proto, self._parameters)
  elif (self._parameters.trace_mode !=
        tensor_tracer_flags.TRACE_MODE_HISTORY):
    # History mode produces no text report.
    report_handler.create_report(self._tt_config, self._parameters,
                                 tensor_trace_order, tensor_trace_points)
  return tensor_trace_order
def _create_host_call(self):
  """Returns whether the trace mode is one of the two summary trace modes.

  True for TRACE_MODE_SUMMARY and TRACE_MODE_FULL_TENSOR_SUMMARY.
  """
  summary_modes = (tensor_tracer_flags.TRACE_MODE_SUMMARY,
                   tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
  return self._parameters.trace_mode in summary_modes
def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream,
                           tensor_trace_order):
  """Generates a print operation to print trace inspection.

  Prints a per-step verdict (whether the cache holds any NaN/Inf) followed
  by column 0 of each traced tensor's cache row, ordered by cache index.

  Args:
    cache: Tensor storing the trace results for the step.
    replica_id: Tensor storing the replica id of the running core.
    step_num: Step number.
    output_stream: Where to print the outputs, e.g., file path, or sys.stderr.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.

  Returns:
    The Op to flush the cache to file (a no_op when nothing is traced).
  """
  def _inspect_tensor(tensor):
    """Returns the text to be printed for inspection output."""
    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_NAN_INF):
      return cond.cond(
          math_ops.greater(tensor, 0.0),
          lambda: 'has NaNs/Infs!',
          lambda: 'has no NaNs or Infs.')
    else:
      return tensor

  # Check if there are graph operations being profiled.
  if not tensor_trace_order.traced_tensors:
    logging.warning('Inspect mode has no tensors in the cache to check.')
    # Return an op rather than the no_op function, so callers can put the
    # result into control dependencies.
    return control_flow_ops.no_op()

  # Check if the cache includes any nan or inf
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
    # Cache has 1s or 0s if the mode is NaN_INF
    step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0)
  else:
    # Cache has the actual numerics for other modes.
    step_has_nan_or_inf = math_ops.reduce_any(
        gen_math_ops.logical_or(
            gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache)))

  # Summarizing message for each step.
  step_error_message = cond.cond(
      step_has_nan_or_inf,
      lambda: 'NaNs or Infs in the step!',
      lambda: 'No numerical issues have been found for the step.')

  # No need to print core numbers if the cache is merged already.
  if self._parameters.collect_summary_per_core:
    stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
             step_error_message,
             'Printing tensors for mode:%s...' % self._parameters.trace_mode]
  else:
    stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
             'Printing tensors for mode:%s...' % self._parameters.trace_mode]

  for tensor_name, cache_idx in sorted(
      tensor_trace_order.tensorname_to_cache_idx.items(),
      key=lambda item: item[1]):
    if self._parameters.collect_summary_per_core:
      stats.extend([
          '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
          tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
    else:
      stats.extend([
          '\n', 'step:', step_num, ',',
          tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
  return logging_ops.print_v2(*stats, summarize=-1,
                              output_stream=output_stream)
def _inspect_history_cache(self, cache, replica_id, step_num,
                           tensor_trace_order):
  """Generates a conditional print operation to log differences in tensor values.

  For each traced tensor, column 0 of its cache row is compared with the
  value currently held in a per-tensor history variable, read via
  assign_add(..., 0.0). The variable is then overwritten with the new value.
  The collected lines are printed only when the largest absolute difference
  exceeds tensor_tracer_flags.DELTA_THRESHOLD.

  Args:
    cache: Tensor storing the trace results for the step.
    replica_id: Tensor storing the replica id of the running core.
    step_num: Step number.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.

  Returns:
    The result of cond.cond, which prints the stats when the threshold is
    exceeded and is a no_op otherwise. When nothing is traced, a no_op.
  """
  # Check if there are graph operations being profiled.
  if not tensor_trace_order.traced_tensors:
    logging.warning('TT history mode has no tensors in the cache to check.')
    # Return an op rather than the no_op function, so callers can put the
    # result into control dependencies.
    return control_flow_ops.no_op()

  stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num]
  diffs = []
  for tensor_name, cache_idx in sorted(
      tensor_trace_order.tensorname_to_cache_idx.items(),
      key=lambda item: item[1]):

    tensor_to_write = cache[cache_idx, 0]
    snapshot_variable = self._create_or_get_tensor_history_values_cache(
        tensor_to_write.name, tensor_to_write.op.graph,
        tensor_to_write.shape.as_list(), tensor_to_write.dtype)

    # Read the stored value (assign_add of 0.0) before overwriting it.
    with ops.control_dependencies([snapshot_variable]):
      old_value = state_ops.assign_add(snapshot_variable, 0.0)

    with ops.control_dependencies([old_value]):
      new_value = math_ops.cast(tensor_to_write, dtypes.float32)
      delta = math_ops.abs(math_ops.subtract(old_value, new_value))
      updated = state_ops.assign(snapshot_variable, new_value)
      diffs.append(delta)
    # Re-read after the update so the printed "new" value comes from the
    # variable itself.
    with ops.control_dependencies([updated]):
      new_value_from_var = state_ops.assign_add(snapshot_variable, 0.0)

    stats.extend([
        '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
        tensor_name, '-->', old_value, new_value_from_var, delta])

  diff_stack = array_ops_stack.stack(diffs)
  step_max = math_ops.reduce_max(diff_stack)

  return cond.cond(
      math_ops.greater(step_max, tensor_tracer_flags.DELTA_THRESHOLD.value),
      lambda: logging_ops.print_v2(*stats, summarize=-1),
      lambda: control_flow_ops.no_op())  # pylint: disable=unnecessary-lambda
def _get_outfile_suffix(self):
  """Returns the suffix appended to trace output file names.

  Remote trace directories get remote_utils.get_appendable_file_encoding();
  all others get an empty suffix.
  """
  is_remote = remote_utils.is_remote_path(self._parameters.trace_dir)
  return remote_utils.get_appendable_file_encoding() if is_remote else ''
def _generate_flush_cache_op(self, num_replicas, on_tpu,
                             tensor_trace_order, graph):
  """Generates an Op that will flush the cache to file.

  Each replica selects, through a control_flow_case on its replica id, a
  print function bound to its own output file. When the cache lives in a
  tf.Variable, that variable is reset to _COMPACT_TRACE_ENTRY_INIT_VALUE
  after the flush.

  Args:
    num_replicas: total number of replicas.
    on_tpu: if the graph is executed on TPU.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    graph: TensorFlow graph.

  Returns:
    The Op to flush the cache to file.
  """

  def _flush_fun(cache, replica_id, step_num):
    """Flushes the cache to a file corresponding to replica_id."""

    def _f(file_index):
      """Generates a func that flushes the cache to a file."""
      def _print_cache():
        """Flushes the cache to a file."""
        replica_str = ('%d' % file_index)
        if self._parameters.trace_dir:
          output_path = (os.path.join(self._parameters.trace_dir,
                                      _COMPACT_TRACE_FILE_PREFIX)
                         + replica_str + self._get_outfile_suffix())
          output_stream = _OUTPUT_STREAM_ESCAPE + output_path
        else:
          output_stream = sys.stderr

        new_step_line = _REPLICA_ID_TAG + replica_str
        print_ops = []
        if self._parameters.inspect_trace:
          if self._num_signature_dimensions() > 1:
            raise ValueError('Inspecting multi signatures are not supported.')
          # Compare with ==: `in (X)` without a trailing comma is a
          # substring test on the string X.
          if (self._parameters.trace_mode ==
              tensor_tracer_flags.TRACE_MODE_HISTORY):
            print_ops.append(
                self._inspect_history_cache(
                    cache=cache,
                    replica_id=replica_id,
                    step_num=step_num,
                    tensor_trace_order=tensor_trace_order))
          else:
            print_ops.append(
                self._inspect_summary_cache(
                    cache=cache,
                    replica_id=replica_id,
                    step_num=step_num,
                    output_stream=output_stream,
                    tensor_trace_order=tensor_trace_order))
        else:
          # One print per signature column.
          for i in range(self._num_signature_dimensions()):
            print_ops.append(logging_ops.print_v2(
                new_step_line, '\n',
                cache[:, i], '\n',
                summarize=-1,
                output_stream=output_stream))
        with ops.control_dependencies(print_ops):
          return constant_op.constant(0).op
      return _print_cache

    def _eq(file_index):
      return math_ops.equal(replica_id, file_index)

    flush_op_cases = {}
    flush_op_cases[_eq(0)] = _f(0)
    for i in range(1, num_replicas):
      if on_tpu and not self._parameters.collect_summary_per_core:
        # If this is the case, the cache is already merged for all cores.
        # Only first core flushes the cache.
        flush_op_cases[_eq(i)] = control_flow_ops.no_op
      else:
        flush_op_cases[_eq(i)] = _f(i)
    # Each replica needs to determine where to write their output.
    # To do this, we check if replica_id is 0, then 1, ..., and then
    # num_replicas - 1 statically; and return the corresponding static file
    # name. We cannot simply set the file name in python, as replica_id is
    # only known during tf runtime, and we cannot create dynamic filenames.
    return control_flow_case.case(flush_op_cases, exclusive=True)

  cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
  if self._use_temp_cache():
    cache_val = cache
  else:
    cache_val = cache.value()

  if on_tpu:
    # If we do not need to collect traces for all cores, merge and aggregate
    # per core trace.
    if not self._parameters.collect_summary_per_core:
      cache_val = self.merge_caches_on_tpu(cache_val)
      cache_val = self.aggregate_global_cache(cache_val)[0]

    flush_op = tpu_replication.outside_compilation(
        _flush_fun, cache_val, self._replica_id,
        array_ops.identity(training_util.get_or_create_global_step()))
  else:
    global_step = training_util.get_or_create_global_step()
    flush_op = _flush_fun(cache_val, self._replica_id, global_step)

  if self._use_temp_cache():
    with ops.control_dependencies([flush_op]):
      return constant_op.constant(0).op
  else:
    # Re-initialize the local cache variable.
    with ops.control_dependencies([flush_op]):
      reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                         dtype=cache.dtype,
                                         shape=cache.shape)
      assign_op = state_ops.assign(cache, reset_value).op
      with ops.control_dependencies([assign_op]):
        return constant_op.constant(0).op
def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu,
tensor_trace_order, graph):
"""Flushes the intermediate tensor values in the graph to the cache.
Args:
tensor_fetches: list of tensor results returned by the model_fn.
op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
on_tpu: if the graph is executed on TPU.
tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
graph: TensorFlow graph.
Returns:
An identical copy of tensor_fetches.
"""
# Add a dependency to op and tensor fetches to make sure that all tracing
# ops are executed before flushing trace results.
if not tensor_trace_order.traced_tensors:
logging.warn('No tensor values being traced. No flush cache op added.')
return tensor_fetches
with ops.control_dependencies(op_fetches +
[tensor.op for tensor in tensor_fetches]):
flush_cache_op = self._generate_flush_cache_op(
self._tt_config.num_replicas, on_tpu, tensor_trace_order, graph)
return control_flow_ops.tuple(tensor_fetches,
control_inputs=[flush_cache_op])
def _process_tensor_fetches(self, tensor_fetches):
"""Check that tensor_fetches is not empty and have valid tensors."""
# If none or empty list.
if tensor_fetches is None:
raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
'None.')
if not isinstance(tensor_fetches, (list, tuple)):
tensor_fetches = [tensor_fetches]
elif not tensor_fetches:
raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
'empty list.')
fetches = []
for fetch in tensor_fetches:
if isinstance(fetch, ops.Tensor):
fetches.append(fetch)
else:
raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
return fetches
def _process_op_fetches(self, op_fetches):
"""Check that op_fetches have valid ops."""
if op_fetches is None:
return []
if not isinstance(op_fetches, (list, tuple)):
op_fetches = [op_fetches]
fetches = []
for fetch in op_fetches:
if isinstance(fetch, ops.Operation):
fetches.append(fetch)
elif isinstance(fetch, ops.Tensor):
fetches.append(fetch.op)
else:
logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
fetch)
return fetches
def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
  """Changes current_fetches' format, so that it matches input_fetches.

  Args:
    input_fetches: the fetches originally supplied by the caller: a single
      tf.Tensor, or a list/tuple of them.
    current_fetches: a sequence of replacement fetches, one per input fetch.
  Returns:
    current_fetches[0] if input_fetches is a single tf.Tensor, a tuple if
    input_fetches is a tuple, and current_fetches itself otherwise.
  Raises:
    RuntimeError: if current_fetches does not hold exactly one entry per
      input fetch.
  """
  if isinstance(input_fetches, ops.Tensor):
    if len(current_fetches) != 1:
      raise RuntimeError('Tensor tracer input/output fetches do not match.')
    return current_fetches[0]
  # Each input fetch must map to exactly one output fetch.
  if len(input_fetches) != len(current_fetches):
    raise RuntimeError('Tensor tracer input/output fetches do not match.')
  if isinstance(input_fetches, tuple):
    return tuple(current_fetches)
  return current_fetches
def _get_op_control_flow_context(self, op):
  """Returns the control flow context that `op` belongs to.

  For a LoopExit op, the outer context of the op's own context is returned
  instead.

  Args:
    op: tf.Operation for which the control flow context is requested.
  Returns:
    The op's control flow context, or its outer context for a LoopExit op.
  """
  # pylint: disable=protected-access
  context = op._control_flow_context
  # pylint: enable=protected-access
  return (context.outer_context if control_flow_util.IsLoopExit(op)
          else context)
def merge_caches_on_tpu(self, local_tpu_cache_tensor):
  """Combines this core's cache tensor with the caches of the other cores.

  The local cache is first broadcast to a new leading axis of size
  `num_replicas`. Unless single-core summaries are requested, the copies are
  then exchanged between replicas with `tpu_ops.all_to_all` along that axis.

  Args:
    local_tpu_cache_tensor: The local cache tensor of this TPU core.

  Returns:
    A tf.Tensor whose leading dimension has size `num_replicas`.
  """
  num_replicas = self._tt_config.num_replicas
  replicated_shape = [num_replicas] + local_tpu_cache_tensor.shape.as_list()
  replicated = array_ops.broadcast_to(local_tpu_cache_tensor,
                                      shape=replicated_shape)
  if tensor_tracer_flags.TT_SINGLE_CORE_SUMMARIES.value:
    # Single-core summaries: skip the cross-replica exchange.
    return replicated
  all_replicas = list(range(num_replicas))
  return tpu_ops.all_to_all(
      replicated,
      concat_dimension=0,
      split_dimension=0,
      split_count=num_replicas,
      group_assignment=[all_replicas])
def aggregate_global_cache(self, global_tt_summary_cache):
  """Aggregates the per-core summary caches across cores, signature by signature.

  Args:
    global_tt_summary_cache: The global tensor tracer summary cache tensor
      with shape (num_cores, num_traced_tensors, num_traced_signatures).
      Entry i along the first axis is the local cache of core i.

  Returns:
    A tf.Tensor laid out as (1, num_traced_tensors, num_signatures),
    assuming each signature's aggregation function reduces away axis 0.

  Raises:
    RuntimeError: if there is no aggregate function defined for a signature.
  """
  agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
  # Visit signatures in the order of their column index in the cache.
  ordered_signatures = sorted(self._signature_types().items(),
                              key=operator.itemgetter(1))
  per_signature_results = []
  for signature, column in ordered_signatures:
    if signature not in agg_fn_map:
      raise RuntimeError('No aggregation function is defined for '
                         'signature %s.' % signature)
    # Slice out this signature's column for every core and tensor, then
    # reduce over the core axis.
    column_values = global_tt_summary_cache[:, :, column]
    per_signature_results.append(agg_fn_map[signature](column_values, axis=0))
  # Stacking yields (num_signatures, num_traced_tensors); transpose back to
  # (num_traced_tensors, num_signatures) and restore a leading core axis.
  stacked = array_ops_stack.stack(per_signature_results)
  return array_ops.expand_dims(array_ops.transpose(stacked), axis=0)
def _prepare_host_call_fn(self, processed_t_fetches,
                          op_fetches, graph, graph_summary_tag):
  """Creates a host call function that will write the cache as tb summary.

  `self._host_call_fn` is reset to a new dict. The nested `_write_cache`
  function and the keyword arguments it should be called with are then stored
  as a `(fn, kwargs)` pair under `_TT_HOSTCALL_KEY`. This method writes no
  summary itself; `_write_cache` only builds write ops when it is called.

  Args:
    processed_t_fetches: List of tensor provided to session.run.
    op_fetches: List of operations provided to session.run.
    graph: TensorFlow graph.
    graph_summary_tag: the summary_tag name for the given graph.
  Raises:
    ValueError: if trace_dir is not set.
  """
  if self._parameters.trace_dir is None:
    raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
                     '--trace_dir=/model/dir')

  def _write_cache(step, event_file_suffix=None, **kwargs):
    """Writes the given caches as tensor summary.

    Args:
      step: Step tensor with dimension [num_cores].
      event_file_suffix: Event filename suffix tensor.
      **kwargs: The dictionary of tensors that needs to be written as
        summaries. Key and value pairs within kwargs correspond to the tag
        name, and tensor content that will be written using summary.write.
        The trace_modes that use this function are:
          - summary: In summary mode, kwargs includes a single (tag, content)
          pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
          variable. The dimension of the signature_cache is:
            num_cores x num_traced_tensors x num_signatures.
          - full_tensor_summary: kwargs will include all traced tensors. Tag
          and content correspond to the name of the tensor, and its actual
          content.
    Returns:
      A tf.Operation that needs to be executed for the host call dependencies.
    """
    file_suffix = _TT_EVENT_FILE_SUFFIX
    if event_file_suffix is not None:
      file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
                                           separator='.')
    # TODO(deveci): Parametrize max_queue, so that flushing op can be called
    # less frequently.
    # Setting max_queue to 100 appears to be safe even when the number of
    # iterations are much lower, as the destructor of the writer flushes it.
    summary_write_ops = []
    summary_writer = summary.create_file_writer_v2(
        self._parameters.trace_dir,
        filename_suffix=file_suffix,
        max_queue=_TT_SUMMARY_MAX_QUEUE)
    # Record the writer in a graph collection.
    graph.add_to_collection(
        TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)

    # Only the first entry of the step tensor is used as the summary step.
    step_value = step[0]
    dt = step_value.dtype

    # The step parameter to a summary write call must be 64-bit.
    if dt.__ne__(dtypes.int64) and dt.__ne__(
        dtypes.uint64) and dt.__ne__(dtypes.float64):
      step_value = math_ops.cast(step_value, dtypes.int64)

    with summary_writer.as_default():
      summary_metadata = summary_pb2.SummaryMetadata(
          plugin_data=summary_pb2.SummaryMetadata.PluginData(
              plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
      for key, value in kwargs.items():
        # Check whether we need to compute aggregated statistics that merge
        # all cores statistics.
        if not self._parameters.collect_summary_per_core:
          # Merge only the statistics tensor; any other tensor is written
          # as is (concatenated across cores).
          # Also, if there is only a single core (first dim. is 1), then skip
          # aggregation.
          if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
            value = self.aggregate_global_cache(value)
        # Each write depends on the writer's init op (a new init op is
        # created for every key).
        with ops.control_dependencies([summary_writer.init()]):
          summary_write_ops.append(summary.write(
              _TT_SUMMARY_TAG + '/' + key + '.' + graph_summary_tag,
              value, metadata=summary_metadata,
              step=step_value))
    return control_flow_ops.group(summary_write_ops)

  global_step = training_util.get_or_create_global_step()
  step = array_ops.reshape(global_step, [1])
  self._host_call_fn = {}

  host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]

  caches_to_write = {}
  # The cache reshapes below are created under control dependencies on all
  # fetches, so they run only after those ops have executed.
  with ops.control_dependencies(host_call_deps):
    all_caches = self._cache_variable_for_graph(graph)
    for cache_name, cache_variable in all_caches.items():
      # Increase the cache rank by 1, so that when host call concatenates
      # tensors from different replicas, we can identify them with [core_id].
      new_cache_shape = [1]
      new_cache_shape.extend(cache_variable.shape.as_list())
      cache = array_ops.reshape(cache_variable, new_cache_shape)
      caches_to_write[cache_name] = cache
  # Add step to parameter dictionary.
  caches_to_write['step'] = step
  # Other options without adding step to parameter dictionary are
  #  * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
  #    considers caches_to_write as a single parameter, rather than a keyword
  #    parameters.
  #  * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
  #    a syntax error.
  self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)
def host_call_deps_and_fn(self):
  """Returns the host-call dictionary kept in `self._host_call_fn`.

  The dictionary is returned by reference, not copied. `_prepare_host_call_fn`
  stores a `(write_fn, kwargs)` pair in it. `_trace_execution` removes that
  entry again when it runs the write itself.

  Returns:
    The object stored in `self._host_call_fn`.
  """
  host_call_fn = self._host_call_fn
  return host_call_fn
def get_traced_op_names(self):
  """Returns the names of the ops whose outputs were selected for tracing.

  Returns:
    The collection stored in `self._traced_op_names`, returned by reference.
  """
  traced_names = self._traced_op_names
  return traced_names
def _trace_execution(self, graph,
                     tensor_fetches,
                     op_fetches=None,
                     on_tpu=True):
  """Common tracing function for both CPU and TPUs.

  The caller function should set device_type, num_replicas,
  num_replicas_per_host, num_hosts and replica_id before calling
  _trace_execution.

  A tensor is traced when it meets both of these conditions:
    - It is an output of an op on the execution path.
    - `tensor_trace_order` selects it for tracing.
  It is skipped if it has no consumer on the execution path and is not itself
  fetched.

  For every traced tensor, trace ops are built according to the configured
  mode: value cache, tensor buffer, or a per-tensor trace function. They are
  then wired in as control dependencies, either of the tensor's consumers or,
  for fetched tensors, of the returned fetches.

  Args:
    graph: the graph of Ops executed on the TPU.
    tensor_fetches: a (list,tuple,or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with at least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.
    on_tpu: True if executing on TPU.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies, in the same format (single tensor, list or tuple) as the
      input.
  Raises:
    RuntimeError: If tensor_fetches is None or empty, or if a mode that allows
      only a single statistic per tensor receives several.
    ValueError: If outside-compiled flushing is requested on TPU but the host
      call caches lack the summary or step entries.
  """
  def _cast_unsupported_dtypes(tensor):
    """Casts tensor to a supported type."""

    if tensor.dtype.__eq__(dtypes.int64):
      # outside-compilation doesn't support int64 input yet.
      return math_ops.cast(tensor, dtypes.int32)
    if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
        dtypes.float16):
      # Since host can't handle bf16, convert tensor to f32.
      return math_ops.cast(tensor, dtypes.float32)
    return tensor

  trace_mode = self._parameters.trace_mode
  device_type = self._tt_config.device_type
  # Remember the graph's current control flow context so it can be restored
  # after trace ops are built inside the traced ops' contexts.
  # pylint: disable=protected-access
  self._outmost_context = graph._get_control_flow_context()
  # pylint: enable=protected-access

  analytics.track_usage('tensor_tracer', [trace_mode, device_type])
  TensorTracer.check_device_type(device_type)
  TensorTracer.check_trace_mode(device_type, trace_mode)
  # Check in_tensor_fetches, and op_fetches and convert them to lists.
  processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
  op_fetches = self._process_op_fetches(op_fetches)
  all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

  # Filter out the operations that won't be executed.
  # if fetches=None, then ops_in_exec_path = set(operations)
  exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
                                                       all_fetches)
  graph_summary_tag = _graph_summary_tag(graph)

  # Write report file, and determine the traced tensors.
  tensor_trace_order = self._determine_trace_and_create_report(
      graph, exec_op_set, graph_summary_tag)

  tensor_fetch_set = set(processed_t_fetches)
  tracing_ops = []

  # Visit ops sorted by name so the traversal order does not depend on set
  # iteration order.
  sorted_exec_op_list = list(exec_op_set)
  sorted_exec_op_list.sort(key=lambda op: op.name)
  # Trace ops only if they are in the execution path.
  for op in sorted_exec_op_list:
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      tensor_name = out_tensor.name
      if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
        continue
      self._traced_op_names.add(op.name)
      # Create the list of consumers before calling _preprocess_traced_tensor.
      # Otherwise, adding control input below, will introduce a cycle in the
      # graph.
      consumers = out_tensor.consumers()
      # Not all consumers may be in the exec path. Filter out the consumers
      # to keep the graph simpler.
      consumers = [cop for cop in consumers if cop in exec_op_set]

      # If there is no consumer of the tensor, there is no need to trace it;
      # unless the tensor itself is one of the fetches.
      is_a_fetched_tensor = out_tensor in tensor_fetch_set
      if (not consumers) and (not is_a_fetched_tensor):
        continue

      # Build the trace ops inside the traced op's control flow context.
      op_control_flow_context = self._get_op_control_flow_context(op)
      if op_control_flow_context:
        # pylint: disable=protected-access
        graph._set_control_flow_context(op_control_flow_context)
        # pylint: enable=protected-access

      processed_tensors = self._preprocess_traced_tensor(out_tensor)

      if on_tpu:
        for signature in processed_tensors.keys():
          processed_tensors[signature] = _cast_unsupported_dtypes(
              processed_tensors[signature])

      if self._use_tensor_values_cache():
        # Use a small cache (either temp cache or tf local variable) to store
        # the characteristics of the tensor.
        if self._use_temp_cache():
          cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
          self._save_tensor_value_to_tmp_cache(cache_idx,
                                               processed_tensors,
                                               graph)
          # No trace op here: temp-cache values are stacked into the cache
          # variable after the loop.
          trace_op = None
        else:
          cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
          trace_op = self._save_tensor_value_to_cache_op(cache_idx,
                                                         processed_tensors,
                                                         graph)
      elif self._use_tensor_buffer():
        if len(processed_tensors) != 1:
          raise RuntimeError('Multiple stats are only allowed in compact '
                             'mode.')
        processed_out_tensor = list(processed_tensors.values())[0]
        # Store the whole tensor in a buffer.
        trace_op = self._snapshot_tensor(processed_out_tensor)
      else:

        def tpu_wrap_trace_fn(tensor, out_tensor_name):
          """Wraps the trace_fn with outside compilation if on TPUs."""
          tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                        tensor_trace_order)
          if on_tpu:
            return tpu_replication.outside_compilation(
                tensor_trace_fn, tensor)
          else:
            return tensor_trace_fn(tensor)

        if len(processed_tensors) != 1:
          raise RuntimeError('Multiple stats are only allowed in compact '
                             'mode.')
        # Collecting multiple statistics are only supported in the summary
        # mode that uses compact format(self._use_tensor_values_cache = true).
        # Non-compact mode currently allows single stat per tensor.
        processed_out_tensor = next(iter(processed_tensors.values()))
        trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)

      if op_control_flow_context:
        # pylint: disable=protected-access
        graph._set_control_flow_context(self._outmost_context)
        # pylint: enable=protected-access
      if trace_op:
        # Fetched tensors get their dependency through control_flow_ops.tuple
        # on the fetches after the loop, not through consumer control inputs.
        if is_a_fetched_tensor:
          tracing_ops.append(trace_op)
          continue
        # Add it to all consumers, as some consumers may not be executed if
        # they are in a control flow.
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access

  # Restore the graph's original control flow context.
  # pylint: disable=protected-access
  graph._set_control_flow_context(self._outmost_context)
  # pylint: enable=protected-access
  if tracing_ops:
    # If we are tracing a fetched tensor, their dependency is stored in
    # tracing_ops.
    processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                 control_inputs=tracing_ops)
  if self._use_tensor_values_cache() or self._use_tensor_buffer():
    if self._use_temp_cache():
      # Create the temporary tf cache variable by concatenating all
      # statistics.
      graph_cache_var = self._cache_variable_for_graph(graph)
      if graph not in self._temp_cache_var:
        raise RuntimeError('graph is not in self._temp_cache_var')
      graph_cache_var[_TT_SUMMARY_TAG] = array_ops_stack.stack(
          self._temp_cache_var[graph], axis=0, name='stack_all_op_signatures')
    if self._create_host_call():
      self._prepare_host_call_fn(processed_t_fetches, op_fetches, graph,
                                 graph_summary_tag)
      if not on_tpu:
        # Not on TPU: build the cache write ops directly, make the fetches
        # depend on them, and drop the stored host call.
        write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
        cache_write_op = write_cache(**caches_to_write)
        processed_t_fetches = control_flow_ops.tuple(
            processed_t_fetches, control_inputs=[cache_write_op])
        del self._host_call_fn[_TT_HOSTCALL_KEY]
      elif self._parameters.flush_summaries_with_outside_compile:
        # On TPU with outside-compiled flushing: merge the caches across
        # cores and write them via outside compilation from replica 0 only.
        write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
        if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
          step = caches_to_write['step']
          tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
          tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
          if not self._parameters.collect_summary_per_core:
            tt_core_summary = self.aggregate_global_cache(tt_core_summary)

          def write_if_core_0(step, replica_id, tt_summary):

            return cond.cond(
                math_ops.equal(replica_id, 0),
                lambda: write_cache(step=step, event_file_suffix=None,  # pylint: disable=g-long-lambda
                                    tensor_tracer_summary=tt_summary),
                control_flow_ops.no_op)

          write_op = tpu_replication.outside_compilation(
              write_if_core_0,
              step=step,
              replica_id=self._replica_id,
              tt_summary=tt_core_summary)
          processed_t_fetches = control_flow_ops.tuple(
              processed_t_fetches, control_inputs=[write_op])
          del self._host_call_fn[_TT_HOSTCALL_KEY]
        else:
          raise ValueError('Outside compiled flush in only supported for '
                           'summary mode')
    else:
      processed_t_fetches = self._flush_tensor_values_cache(
          processed_t_fetches, op_fetches, on_tpu=on_tpu,
          tensor_trace_order=tensor_trace_order,
          graph=graph)

  # processed_t_fetches is a list at this point. Convert it to the same
  # format as given in tensor_fetches.
  return self._convert_fetches_to_input_format(tensor_fetches,
                                               processed_t_fetches)
def trace_tpu(self, graph,
              tensor_fetches,
              op_fetches=None,
              num_replicas=None,
              num_replicas_per_host=None,
              num_hosts=None):
  """Rewrites a TF graph so that tensors produced by its TPU ops are traced.

  FuncGraphs are not traced. A graph that was already rewritten is also left
  alone, so repeated calls on the same graph are no-ops. In both cases
  tensor_fetches is returned unchanged.

  Args:
    graph: the graph of Ops executed on the TPU.
    tensor_fetches: a (list,tuple,or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with at least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.
    num_replicas: number of replicas used on the TPU.
    num_replicas_per_host: number of replicas per TPU host. Defaults to 8
      when num_replicas is given.
    num_hosts: total number of TPU hosts. Derived as
      ceil(num_replicas / num_replicas_per_host) when num_replicas is given.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies.
  """
  # pylint: disable=protected-access
  unsupported_graph_types = (func_graph.FuncGraph, function._FuncGraph)
  # pylint: enable=protected-access
  if isinstance(graph, unsupported_graph_types):
    logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                    'Ignoring tracing.')
    return tensor_fetches

  if graph in TensorTracer._traced_graphs:
    logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                    'multiple calls.')
    return tensor_fetches
  TensorTracer._traced_graphs.add(graph)

  def _dump_graph_if_requested(file_name):
    # Read the dump path at call time, as the original code did.
    dump_dir = self._parameters.graph_dump_path
    if dump_dir:
      graph_io.write_graph(graph, dump_dir, file_name)

  # Reload the flags in case they changed since the previous call.
  self._parameters = tensor_tracer_flags.TTParameters()
  config = self._tt_config
  config.device_type = _DEVICE_TYPE_TPU
  config.num_replicas = num_replicas
  config.num_replicas_per_host = num_replicas_per_host
  config.num_hosts = num_hosts
  if config.num_replicas is not None:
    if config.num_replicas_per_host is None:
      config.num_replicas_per_host = 8
    if config.num_hosts is None:
      # Round up: a partially filled host still counts as a host.
      full_hosts, leftover = divmod(num_replicas,
                                    config.num_replicas_per_host)
      config.num_hosts = full_hosts + (leftover > 0)

  _dump_graph_if_requested('graph_before_tt.pbtxt')
  with graph.as_default():
    self._add_replica_id_to_graph()
    traced_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                           on_tpu=True)
  _dump_graph_if_requested('graph_after_tt.pbtxt')
  return traced_fetches
def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
  """Rewrites a TF graph so that tensors produced by its CPU ops are traced.

  FuncGraphs are not traced. A graph that was already rewritten is also left
  alone. In both cases tensor_fetches is returned unchanged.

  Args:
    graph: the graph of Ops executed on the CPU.
    tensor_fetches: a (list,tuple,or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with at least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies.
  """
  # pylint: disable=protected-access
  unsupported_graph_types = (func_graph.FuncGraph, function._FuncGraph)
  # pylint: enable=protected-access
  if isinstance(graph, unsupported_graph_types):
    logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                    'Ignoring tracing.')
    return tensor_fetches

  if graph in TensorTracer._traced_graphs:
    logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                    'multiple calls.')
    return tensor_fetches
  TensorTracer._traced_graphs.add(graph)

  def _dump_graph_if_requested(file_name):
    # Read the dump path at call time, as the original code did.
    dump_dir = self._parameters.graph_dump_path
    if dump_dir:
      graph_io.write_graph(graph, dump_dir, file_name)

  # Reload the flags in case they changed since the previous call.
  self._parameters = tensor_tracer_flags.TTParameters()
  config = self._tt_config
  config.device_type = _DEVICE_TYPE_CPU
  # A CPU run is configured as a single replica on a single host.
  config.num_replicas = 1
  config.num_replicas_per_host = 1
  config.num_hosts = 1
  self._replica_id = 0

  _dump_graph_if_requested('graph_before_tt.pbtxt')
  with graph.as_default():
    traced_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                           on_tpu=False)
  _dump_graph_if_requested('graph_after_tt.pbtxt')
  return traced_fetches