# blob: 7fdc8bfb7a34cd9e31be94cb81086f766bfb6afd [file] [log] [blame]
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for summary op transformations."""
import os
import os.path
from absl import flags
from tensorflow.core.function.runtime_client import runtime_client
from tensorflow.core.util import event_pb2
from tensorflow.python.data.ops import readers
from tensorflow.python.eager.polymorphic_function import polymorphic_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
# absl flag registry; `FLAGS.test_tmpdir` is read in `setUp` to decide where
# the summary directory lives.
FLAGS = flags.FLAGS
class SummaryOpsTransformationTest(test.TestCase):
  """Tests the function variant that has its summary ops stripped at runtime."""

  def setUp(self):
    """Sets `self.summary_dir` and removes any summaries left on disk."""
    super().setUp()
    self.summary_dir = os.path.join(FLAGS.test_tmpdir, 'mylogs')
    # Clean up any summary directories before starting the test so we can
    # validate that summaries are only written when enabled.
    try:
      gfile.DeleteRecursively(self.summary_dir)
    except Exception:  # pylint: disable=broad-exception-caught
      # Deliberate best-effort cleanup: a failed delete is ignored.
      # NOTE(review): presumably the usual cause is that the directory does
      # not exist yet -- confirm before narrowing the exception type.
      pass

  @test_util.run_v2_only
  def test_strip_summary_ops(self):
    """Compares the `__instance__no_summaries` variant of `add` to a baseline.

    The test:
      1. Calls `add` with summaries on, with summaries left out at trace time,
         and with summaries traced but disabled through `record_summary`, and
         checks that all calls return the same value.
      2. Fetches the FunctionDef named `<name>__instance__no_summaries` from
         the runtime, normalizes it and the trace-time-stripped FunctionDef,
         and asserts that the two protos are equal.
      3. Reads the event files in `self.summary_dir` back and asserts that
         exactly 3 summary events were written.
    """
    # Name under which the summary writer shows up in the transformed
    # function's signature and in its While node inputs.
    writer_input = 'write_summary_summary_cond_input_1'

    def normalize_while_node(fndef):
      """Helper method to normalize the while node for comparison."""
      # Inputs that only the transformed While node is expected to carry.
      # Built once rather than once per visited node.
      stripped_inputs = frozenset((writer_input, 'record_summary'))
      for node in fndef.node_def:
        if node.op != 'While':
          continue
        # The names of the nested functions are expected to be different
        # because they will have a uid appended to them.
        node.attr['body'].func.name = 'while_body'
        node.attr['cond'].func.name = 'while_cond'

        # The summary_writer and `include_summary` args are expected to be
        # passed in and out of the transformed function as we do not modify
        # the function signatures.
        # Expect a mismatch in input and output types/shapes.
        node.attr['T'].ClearField('list')
        node.attr['output_shapes'].ClearField('list')
        if 'record_summary' not in node.input:
          continue
        # Take a plain-list snapshot of the inputs to keep *before* clearing
        # the field, so nothing is read back from a container that has
        # already been cleared.
        kept_inputs = [inp for inp in node.input if inp not in stripped_inputs]
        node.ClearField('input')
        node.input.extend(kept_inputs)
        # Two inputs were dropped above, so lower the recorded count by two.
        node.attr['_num_original_outputs'].i -= 2

      return fndef

    def normalize_fdef(fndef):
      """Method to normalize the tf.function's FunctionDefs for comparison."""
      # Normalize the names for comparison as they have a uid appended.
      fndef.signature.name = '__inference_add'

      # The summary writer is expected to be passed into the transformed fn.
      # Delete matching entries in place, walking backwards so the indexes
      # still to be visited are not shifted by a deletion.
      input_args = fndef.signature.input_arg
      for idx in reversed(range(len(input_args))):
        if input_args[idx].name == writer_input:
          del input_args[idx]

      # The disable_summaries_at_runtime attr is expected to be cleared.
      fndef.attr['disable_summaries_at_runtime'].ClearField('list')
      return fndef

    writer = summary_ops_v2.create_file_writer_v2(self.summary_dir)
    var = variables.Variable(1.0)

    def remove_writer_attr(fndef):
      """Drops the `arg_attr` entry whose user-specified name is `input_1`."""
      name_key = '_user_specified_name'
      arg_attr = fndef.arg_attr
      attr_idx = None
      # tf.function uses TraceType to create placeholder for captures.
      # An extra "_user_specified_name" attr will be added to the placeholder.
      for idx in arg_attr:
        attrs = arg_attr[idx].attr
        # Check membership first: indexing a proto map with a missing key
        # inserts an empty entry, which would silently change the proto that
        # is compared afterwards.
        if name_key in attrs and attrs[name_key].s == b'input_1':
          attr_idx = idx
          break

      if attr_idx is not None:
        # Copy subsequent arg_attr to ensure indexes are continuous
        for idx in range(attr_idx, len(arg_attr) - 1):
          fndef.arg_attr[idx].CopyFrom(fndef.arg_attr[idx + 1])
        del fndef.arg_attr[len(arg_attr) - 1]
      return fndef

    @polymorphic_function.function(
        autograph=False,
        experimental_attributes={
            'disable_summaries_at_runtime': ['record_summary', False]
        },
    )
    def add(x, y, record_summary, include_summary):
      """Adds `x` and `y`, optionally writing summaries along the way.

      Args:
        x: First addend.
        y: Second addend.
        record_summary: Condition handed to `summary_ops_v2.record_if`.
        include_summary: Python bool; when false no summary ops are traced.

      Returns:
        The result of `math_ops.add(x, y)`.
      """

      def body(step, result):
        result += math_ops.cast(step, dtypes.float32)
        var.assign(result)
        if include_summary:
          # Perform a summary write in a nested function.
          with writer.as_default():
            summary_ops_v2.set_step(step)
            summary_ops_v2.write('my_metric', result, step=step)
            writer.flush()
        return math_ops.add(step, 1)

      result = math_ops.add(x, y)
      step = constant_op.constant(0, dtypes.int64)
      with summary_ops_v2.record_if(record_summary):
        if include_summary:
          # Perform a summary write in the main function body.
          with writer.as_default():
            summary_ops_v2.set_step(step)
            summary_ops_v2.write('my_metric', result, step=step)
            writer.flush()
        step = math_ops.add(step, 1)
        loop_cond = lambda i: math_ops.less(i, 3)
        loop_body = lambda i: body(i, result)
        step = while_loop.while_loop_v2(loop_cond, loop_body, [step])
        var.assign(result)
      return result

    one = constant_op.constant(1.0, dtypes.float32)
    inputs_with_summaries = [one, one, constant_op.constant(True), True]
    inputs_without_summaries = [one, one, constant_op.constant(False), False]
    inputs_without_summaries_at_runtime = [
        one,
        one,
        constant_op.constant(False),
        True,
    ]

    # Ensure the result of `add` is the same with and without summaries.
    self.assertEqual(
        add(*inputs_with_summaries), add(*inputs_without_summaries)
    )

    # Ensure the result of `add` is the same whether summaries are disabled
    # at runtime or were left out at trace time.
    self.assertEqual(
        add(*inputs_without_summaries_at_runtime),
        add(*inputs_without_summaries),
    )

    # Force a trace of `add` where summaries have been stripped at trace time.
    expected = add.get_concrete_function(*inputs_without_summaries).function_def

    # Extract the trace of `add` where summaries have been stripped in the
    # runtime.
    function_name = add.get_concrete_function(
        *inputs_without_summaries_at_runtime
    ).function_def.signature.name
    ctx = runtime_client.GlobalPythonEagerContext()
    rt = runtime_client.Runtime(ctx)
    fndef = rt.GetFunctionProto(function_name + '__instance__no_summaries')

    # Normalize the fndefs and compare them for equivalence.
    fndef = normalize_fdef(normalize_while_node(fndef))
    fndef = remove_writer_attr(fndef)
    expected = normalize_fdef(normalize_while_node(expected))
    self.assertProtoEquals(expected, fndef)

    # Verify that summaries were only written when executing with the
    # `inputs_with_summaries` argument.
    num_summary_events = 0
    summary_files = [
        os.path.join(self.summary_dir, sf)
        for sf in gfile.ListDirectory(self.summary_dir)
    ]
    for record in readers.TFRecordDatasetV2(
        filenames=summary_files
    ).as_numpy_iterator():
      event = event_pb2.Event()
      event.ParseFromString(record)
      if event.HasField('summary'):
        num_summary_events += 1
    # 3 Events are written by `add` when summaries are enabled.
    self.assertEqual(num_summary_events, 3)
if __name__ == '__main__':
  # Run the tests in this module when the file is executed as a script.
  test.main()