| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ======================================================================== |
| """Utilities to handle tensor tracer parameters.""" |
| |
| |
| import os |
| import os.path |
| import re |
| from absl import flags |
| from tensorflow.python.ops import linalg_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.platform import tf_logging as logging |
| |
| TRACE_MODE_PART_TENSOR = 'part-tensor' |
| TRACE_MODE_FULL_TENSOR = 'full-tensor' |
| TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary' |
| |
| TRACE_MODE_NAN_INF = 'nan-inf' |
| TRACE_MODE_NORM = 'norm' |
| TRACE_MODE_MAX_ABS = 'max-abs' |
| TRACE_MODE_SUMMARY = 'summary' |
| TRACE_MODE_HISTORY = 'history' |
| # summary mode to collects a finite set of signatures for each traced tensor, |
| # (such as norm, max, min, mean) and dumps it using tb summaries. |
| |
| # Full tensor mode dumps the whole tensor values for the traced tensors without |
| # any processing on them; using tb summaries. |
| |
| _SUBMODE_BRIEF = 'brief' |
| _SUBMODE_DETAILED = 'detailed' |
| |
| _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") |
| _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') |
| _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') |
| _FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') |
| |
| FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' |
| FLAG_NAME_ENABLE = 'enable' |
| FLAG_NAME_TRACE_MODE = 'trace_mode' |
| FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar' |
| FLAG_NAME_SUBMODE = 'submode' |
| FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' |
| FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' |
| FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' |
| FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' |
| FLAG_NAME_TRACE_LEVEL = 'trace_level' |
| FLAG_NAME_TRACE_DIR = 'trace_dir' |
| FLAG_NAME_REPORT_FILE = 'report_file' |
| FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' |
| FLAG_NAME_OP_RANGE = 'op_range' |
| # Folder to dump the pre (before tensor tracer updates) and post graphs (after |
| # tensor tracer updates). |
| FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs' |
| FLAG_NAME_SUMMARY_SIGNATURES = 'signatures' |
| FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core' |
| FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache' |
| FLAG_NAME_INSPECT_TRACE = 'inspect_trace' |
| FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory' |
| FLAG_FLUSH_SUMMARY = 'flush_summaries' |
| |
| |
| VALID_FLAG_NAMES = [ |
| FLAG_NAME_ENABLE, FLAG_NAME_TRACE_MODE, |
| FLAG_NAME_TRACE_SCALAR_OPS, |
| FLAG_NAME_SUBMODE, FLAG_NAME_EXCLUDED_OPNAMES, |
| FLAG_NAME_EXCLUDED_OPTYPES, FLAG_NAME_INCLUDED_OPNAMES, |
| FLAG_NAME_INCLUDED_OPTYPES, FLAG_NAME_TRACE_DIR, |
| FLAG_NAME_REPORT_FILE, |
| FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, |
| FLAG_NAME_OP_RANGE, |
| FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL, |
| FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE, |
| FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR, |
| FLAG_NAME_INSPECT_TRACE, FLAG_FLUSH_SUMMARY, |
| ] |
| |
| _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') |
| _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' |
| |
| _TT_DEFAULT_TRACE_LEVEL = 3 |
| _TT_PREFIX = 'tensor_tracer' |
| |
| _TT_NORM = 'norm' |
| _TT_MAX = 'max' |
| _TT_MAX_ABS = 'max-abs' |
| _TT_MIN = 'min' |
| _TT_SPARSITY = 'sparsity' |
| _TT_MEAN = 'mean' |
| _TT_VAR = 'var' |
| _TT_SIZE = 'size' |
| |
| TT_SUMMARY_NORM = '%s_%s' % (_TT_PREFIX, _TT_NORM) |
| TT_SUMMARY_MAX = '%s_%s' % (_TT_PREFIX, _TT_MAX) |
| TT_SUMMARY_MAX_ABS = '%s_%s' % (_TT_PREFIX, _TT_MAX_ABS) |
| TT_SUMMARY_MIN = '%s_%s' % (_TT_PREFIX, _TT_MIN) |
| TT_SUMMARY_SPARSITY = '%s_%s' % (_TT_PREFIX, _TT_SPARSITY) |
| TT_SUMMARY_MEAN = '%s_%s' % (_TT_PREFIX, _TT_MEAN) |
| TT_SUMMARY_VAR = '%s_%s' % (_TT_PREFIX, _TT_VAR) |
| TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE) |
| |
| TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN, |
| TT_SUMMARY_SPARSITY, TT_SUMMARY_MEAN, TT_SUMMARY_VAR, |
| TT_SUMMARY_SIZE, TT_SUMMARY_MAX_ABS) |
| |
| FLAGS = flags.FLAGS |
| |
| DELTA_THRESHOLD = flags.DEFINE_float( |
| 'delta_threshold', |
| default=0.5, |
| help=('Log if history based diff crosses this threshold.')) |
| TT_CHECK_FILTER = flags.DEFINE_bool( |
| 'tt_check_filter', |
| default=False, |
| help='Terminate early to check op name filtering.') |
| TT_SINGLE_CORE_SUMMARIES = flags.DEFINE_bool( |
| 'tt_single_core_summaries', |
| default=False, |
| help='Report single core metric and avoid aggregation.') |
| |
| |
| class TTParameters(object): |
| """A class that handles the parameters of Tensor Tracer.""" |
| |
| def __init__(self, env=None): |
| if env: |
| self._env = env |
| else: |
| self._env = os.environ |
| self._validate_flag_names() |
| self.trace_mode = self._get_trace_mode() |
| self.submode = self._get_submode() |
| self.trace_dir = self._get_trace_dir() |
| self.report_file_path = self._get_report_filepath() |
| self.op_range = self._get_op_range() |
| self.excluded_opname_re_list = self._flag_value_to_re_list( |
| FLAG_NAME_EXCLUDED_OPNAMES) |
| self.excluded_optype_re_list = self._flag_value_to_re_list( |
| FLAG_NAME_EXCLUDED_OPTYPES) |
| |
| self.included_opname_re_list = self._flag_value_to_re_list( |
| FLAG_NAME_INCLUDED_OPNAMES) |
| self.included_optype_re_list = self._flag_value_to_re_list( |
| FLAG_NAME_INCLUDED_OPTYPES) |
| |
| self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS) |
| self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF, |
| TRACE_MODE_NORM, |
| TRACE_MODE_HISTORY, |
| TRACE_MODE_MAX_ABS, |
| TRACE_MODE_SUMMARY) |
| self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR) |
| self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE) |
| self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR) |
| |
| _, self.graph_dump_path = self.get_flag_value( |
| FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS) |
| self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL, |
| _TT_DEFAULT_TRACE_LEVEL) |
| self.summary_signatures = self._get_summary_signatures() |
| self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE) |
| # TODO(b/199284834): Will be resolved with referenced bug. |
| if self.collect_summary_per_core: |
| logging.warning('Aggregate signatures are approximate for mean, variance' |
| ' and sparsity.') |
| self.flush_summaries_with_outside_compile = self.is_flag_on( |
| FLAG_FLUSH_SUMMARY) |
| # Do not produce errors or warnings if Tensor Tracer is not enabled. |
| if self.is_enabled(): |
| self._check_flag_errors() |
| |
| def _check_flag_errors(self): |
| if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY): |
| if not self.trace_dir: |
| raise ValueError('trace_dir must be explicitly provided in ' |
| 'TENSOR_TRACER_FLAGS when summary mode is used.') |
| |
| def _get_report_filepath(self): |
| """Sets the path of the output report file.""" |
| |
| found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE) |
| if found and report_file_path and self.use_test_undeclared_outputs_dir(): |
| if os.path.isabs(report_file_path): |
| raise ValueError('If use_test_undeclared_outputs_dir is set,' |
| 'report_file_path cannot be an absolute path (%s)' |
| %report_file_path) |
| outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) |
| report_file_path = os.path.join(outputs_dir, report_file_path) |
| return report_file_path |
| |
| def _get_op_range(self): |
| """Sets the index range of the Ops that we will consider tracing.""" |
| found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE) |
| if not found or not op_range: |
| op_range = (-1, -1) # this means including all ops. |
| return op_range |
| match = _OP_RANGE_PAT.match(op_range) |
| if not match: |
| op_range = (-1, -1) # this means including all ops. |
| return op_range |
| op_range = (int(match.group(1)), int(match.group(2))) |
| return op_range |
| |
| def _get_trace_dir(self): |
| found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR) |
| if found and trace_dir and self.use_test_undeclared_outputs_dir(): |
| raise ValueError( |
| 'Cannot not use --%s and --%s at the same time' % |
| (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)) |
| if self.use_test_undeclared_outputs_dir(): |
| trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) |
| return trace_dir |
| |
| def _get_trace_mode(self): |
| """Checks if the given trace mode is valid.""" |
| |
| found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE) |
| if not found or not trace_mode: |
| trace_mode = TRACE_MODE_NORM |
| valid_trace_modes = [ |
| TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR, |
| TRACE_MODE_NORM, TRACE_MODE_MAX_ABS, |
| TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY, |
| TRACE_MODE_HISTORY |
| ] |
| if trace_mode not in valid_trace_modes: |
| raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' |
| 'Valid trace modes are: %s'%(trace_mode, |
| valid_trace_modes)) |
| return trace_mode |
| |
| def is_brief_mode(self): |
| return self.submode == _SUBMODE_BRIEF |
| |
| def _get_submode(self): |
| """Checks if the given submode is valid.""" |
| |
| found, submode = self.get_flag_value(FLAG_NAME_SUBMODE) |
| if not found or not submode: |
| submode = _SUBMODE_DETAILED |
| if not submode: |
| return |
| valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF] |
| if submode not in valid_submodes: |
| raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.' |
| 'Valid submodes are: %s'%(submode, |
| valid_submodes)) |
| return submode |
| |
| @staticmethod |
| def match_next_flag(tt_flags, pos): |
| """Returns the match for the next TensorTracer flag. |
| |
| Args: |
| tt_flags: a string that contains the flags. |
| pos: where in flags to start the search. |
| |
| Returns: |
| A pair where the first element is the regular-expression |
| match found and the second element indicates if the match |
| has a value. |
| """ |
| |
| match = _FLAG_DOUBLE_QUOTE_PAT.match(tt_flags, pos) |
| if match: |
| return match, True |
| match = _FLAG_SINGLE_QUOTE_PAT.match(tt_flags, pos) |
| if match: |
| return match, True |
| match = _FLAG_NO_QUOTE_PAT.match(tt_flags, pos) |
| if match: |
| return match, True |
| match = _FLAG_NO_EQUAL_PAT.match(tt_flags, pos) |
| if match: |
| # The flag is found but is not given a value. |
| return match, False |
| # The flag is not found. |
| return None, False |
| |
| def _validate_flag_names(self): |
| """Validates if the TensorTrace flags passed are valid.""" |
| tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) |
| if not tensor_tracer_flags: |
| return |
| pos = 0 |
| while True: |
| match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos) |
| if not match: |
| break |
| flag_name = match.group(1) |
| if flag_name not in VALID_FLAG_NAMES: |
| raise ValueError( |
| 'The flag name "%s" passed via the environment variable "%s" ' |
| 'is invalid. Valid flag names are:' |
| '\n%s' % (flag_name, FLAGS_ENV_VAR, VALID_FLAG_NAMES)) |
| pos = match.end() |
| |
| def _supported_signatures(self): |
| """Returns a tuple of supported signatures.""" |
| return TT_SUMMARY_SIGNATURES |
| |
| def _get_summary_signatures(self): |
| """Verifies and returns the summary signatures. |
| |
| Returns: |
| A dictionary of the signature identifiers {signature: index} that will be |
| computed when trace_mode is summary. |
| """ |
| signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES) |
| supported_signatures = self._supported_signatures() |
| |
| tt_signatures = [] |
| for signature in signatures: |
| signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature) |
| if signature in supported_signatures: |
| tt_signatures.append(signature) |
| elif signature_with_prefix in supported_signatures: |
| tt_signatures.append(signature_with_prefix) |
| else: |
| logging.warning('Unknown signature:%s. Supported signatures: %s' % |
| (signature, supported_signatures)) |
| if not tt_signatures: |
| # Default case collects norm and max only. |
| return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1} |
| else: |
| return {signature: idx for idx, signature in enumerate(tt_signatures)} |
| |
| def get_signature_to_agg_fn_map(self): |
| """Returns a map that contains the aggregate function for each signature.""" |
| # TODO(b/199284834): Aggregations are not accurate for mean and sparsity if |
| # cores have a different number of elements. Variance uses the maximal core |
| # variance. |
| return {TRACE_MODE_NORM: linalg_ops.norm, |
| TRACE_MODE_HISTORY: math_ops.reduce_max, |
| TRACE_MODE_MAX_ABS: math_ops.reduce_max, |
| TRACE_MODE_NAN_INF: math_ops.reduce_max, |
| TT_SUMMARY_NORM: linalg_ops.norm, |
| TT_SUMMARY_MAX: math_ops.reduce_max, |
| TT_SUMMARY_MAX_ABS: |
| lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t), # pylint: disable=g-long-lambda |
| axis=axis), |
| TT_SUMMARY_MIN: math_ops.reduce_min, |
| # Exact if each part has the same number of values. |
| TT_SUMMARY_SPARSITY: math_ops.reduce_mean, |
| TT_SUMMARY_MEAN: math_ops.reduce_mean, |
| TT_SUMMARY_VAR: math_ops.reduce_max, # Simply reduce max variance. |
| TT_SUMMARY_SIZE: math_ops.reduce_sum} |
| |
| def _flag_value_as_list(self, wanted_flag_name): |
| """Returns the string list of a TensorTracer flag. |
| |
| Args: |
| wanted_flag_name: the name of the flag we are looking for. |
| |
| Returns: |
| The list value of the flag. |
| """ |
| string_value_list = [] |
| found, flag_value = self.get_flag_value(wanted_flag_name) |
| |
| if found: |
| assert flag_value is not None |
| string_value_list = flag_value.split(',') |
| return string_value_list |
| |
| def _flag_value_as_int_list(self, wanted_flag_name): |
| """Returns the integer list of a TensorTracer flag. |
| |
| Args: |
| wanted_flag_name: the name of the flag we are looking for. |
| |
| Returns: |
| the value of the flag. |
| Raises: |
| RuntimeError: If supposedly deadcode is reached. |
| """ |
| int_list = [] |
| found, flag_value = self.get_flag_value(wanted_flag_name) |
| |
| if found and flag_value: |
| try: |
| integer_values = flag_value.split(',') |
| int_list = [int(int_val) for int_val in integer_values] |
| except ValueError: |
| logging.warning('Cannot convert %s to int for flag %s', int_list, |
| wanted_flag_name) |
| return int_list |
| |
| def _get_flag_int_value(self, wanted_flag_name, default_value): |
| """Returns the int value of a TensorTracer flag. |
| |
| Args: |
| wanted_flag_name: the name of the flag we are looking for. |
| default_value: the default value for the flag, if not provided. |
| Returns: |
| the value of the flag. |
| Raises: |
| RuntimeError: If supposedly deadcode is reached. |
| """ |
| flag_int_value = default_value |
| found, flag_value = self.get_flag_value(wanted_flag_name) |
| |
| if found: |
| try: |
| flag_int_value = int(flag_value) |
| except ValueError: |
| logging.warning('Cannot convert %s to int for flag %s' % ( |
| flag_int_value, wanted_flag_name)) |
| return flag_int_value |
| |
| def get_flag_value(self, wanted_flag_name): |
| """Returns the value of a TensorTracer flags. |
| |
| Args: |
| wanted_flag_name: the name of the flag we are looking for. |
| |
| Returns: |
| A pair where the first element indicates if the flag is |
| found and the second element is the value of the flag. |
| |
| Raises: |
| RuntimeError: If supposedly deadcode is reached. |
| """ |
| |
| tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) |
| if not tensor_tracer_flags: |
| return False, None |
| pos = 0 |
| while True: |
| match, has_value = TTParameters.match_next_flag( |
| tensor_tracer_flags, pos) |
| if not match: |
| return False, None |
| flag_name = match.group(1) |
| if has_value: |
| flag_value = match.group(2) |
| else: |
| flag_value = None |
| if flag_name == wanted_flag_name: |
| return True, flag_value |
| pos = match.end() |
| raise RuntimeError('Invalid tensor tracer flag. Could not recognize %s.' % |
| flag_name) |
| |
| def _flag_value_to_re_list(self, flag_name): |
| """Converts list of strings to compiled RE.""" |
| |
| re_list = [] |
| found, flag_value = self.get_flag_value(flag_name) |
| if not found or not flag_value: |
| return re_list |
| list_of_values = flag_value.split(',') |
| for v in list_of_values: |
| r = re.compile(v) |
| re_list.append(r) |
| return re_list |
| |
| def is_flag_on(self, flag_name): |
| """Returns True if the given flag is on.""" |
| |
| found, flag_value = self.get_flag_value(flag_name) |
| if not found: |
| return False |
| if flag_value is None: |
| return True |
| # Depends on the flag value. |
| flag_value = flag_value.lower() |
| enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] |
| return enabled |
| |
| def is_enabled(self): |
| """Returns True if TensorTracer is enabled.""" |
| |
| if self.is_flag_on(FLAG_NAME_ENABLE): |
| logging.debug('Tensor Tracer is enabled with flags %s.', |
| self._env.get(FLAGS_ENV_VAR)) |
| return True |
| else: |
| return False |
| |
| def use_test_undeclared_outputs_dir(self): |
| """Decides the output directory of the report and trace files. |
| |
| Args: |
| None. |
| |
| Returns: |
| True if the output files should be written to the |
| test-undeclared-outputs-directory defined via an |
| env variable. |
| """ |
| |
| return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) |