# Source blob: d42417bde25cde19900e9ab9b75abf85f6fa1a80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import timeit
from tensorflow.python.eager import context
from tensorflow.python.eager.polymorphic_function import polymorphic_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@contextlib.contextmanager
def options(optimizer_options):
  """Temporarily overrides the eager context's experimental optimizer options.

  The options in effect on entry are captured and written back on exit, even
  if the managed block raises.

  Args:
    optimizer_options: Value passed straight through to
      `set_optimizer_experimental_options`; the tests in this file pass a dict
      such as `{'disable_meta_optimizer': True}`.

  Yields:
    Nothing; the override is active for the duration of the `with` body.
  """
  old_opts = context.context().get_optimizer_experimental_options()
  context.context().set_optimizer_experimental_options(optimizer_options)
  try:
    yield
  finally:
    # NOTE(review): this assumes the getter returns every option the setter
    # may have changed, so that writing the snapshot back fully undoes the
    # override -- confirm against the context API.
    context.context().set_optimizer_experimental_options(old_opts)
class FunctionTest(test.TestCase):
  """Tests for constant-branch optimizations of traced functions.

  The timing tests compare wall-clock time of a function that branches on a
  boolean input against a variant expected to have the branch folded away,
  and assert a conservative minimum speedup. The remaining tests check how
  the `runtime_constant_optimization` attribute reacts to inputs it cannot
  handle.
  """

  def _best_time(self, fn):
    """Returns the fastest of 5 timing runs, each making 100 calls to `fn`.

    Taking the minimum over the repeats reduces the influence of scheduling
    noise on the comparison.

    Args:
      fn: Zero-argument callable to time.

    Returns:
      The smallest total time, in seconds, reported by `timeit.repeat`.
    """
    return min(timeit.repeat(fn, repeat=5, number=100))

  @test_util.run_v2_only
  def test_grappler_optimization(self):
    """Branching on an in-graph constant is faster than on an input."""

    @polymorphic_function.function
    def brancher(inp):
      x = constant_op.constant(1)
      for _ in range(1000):
        if inp:
          x = x + constant_op.constant(1)
        else:
          x = x + constant_op.constant(2)
      return x

    @polymorphic_function.function
    def brancher_true():
      left = constant_op.constant(True)
      x = constant_op.constant(1)
      for _ in range(1000):
        if left:
          x = x + constant_op.constant(1)
        else:
          x = x + constant_op.constant(2)
      return x

    x = constant_op.constant(True)
    self.assertEqual(brancher(x), brancher_true())  # Trace each function once.

    benchmark = self._best_time(lambda: brancher(x))
    opt_benchmark = self._best_time(brancher_true)

    # Constant folded execution is usually 15 - 20 times faster. Here we check
    # for a 3x speedup to account for various machines the test might run on.
    self.assertLess(opt_benchmark * 3, benchmark)

  @test_util.run_v2_only
  def test_small_constants_optimization_with_grappler(self):
    """`runtime_constant_optimization` speeds up a boolean-input function."""

    def func(inp):
      x = constant_op.constant(1)
      for _ in range(1000):
        if inp:
          x = x + constant_op.constant(1)
        else:
          x = x + constant_op.constant(2)
      return x

    brancher = polymorphic_function.function(func)
    brancher_opt = polymorphic_function.function(
        func, experimental_attributes={'runtime_constant_optimization': True}
    )

    # Create the boolean input on the CPU; the disabled test in this class
    # expects an error mentioning that the tensor must be on host.
    with ops.device_v2('CPU'):
      x = constant_op.constant(True)
    # Trace each function once.
    self.assertEqual(brancher(x), brancher_opt(x))

    benchmark = self._best_time(lambda: brancher(x))
    opt_benchmark = self._best_time(lambda: brancher_opt(x))

    # Constant folded execution is usually 15 - 20 times faster. Here we check
    # for a 2x speedup to account for various machines the test might run on.
    # In particular, the kokoro machines seem to run much slower.
    self.assertLess(opt_benchmark * 2, benchmark)

  @test_util.run_v2_only
  @test_util.run_gpu_only
  def test_small_constants_optimization_disabled(self):
    """A boolean input created without CPU placement is rejected on GPU."""

    @polymorphic_function.function(
        experimental_attributes={'runtime_constant_optimization': True}
    )
    def func(inp):
      return inp

    # No explicit device scope here, unlike the other tests in this class.
    x = constant_op.constant(True)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        (
            'Expecting boolean tensor to be on host when'
            ' small_constants_optimizer is enabled.'
        ),
    ):
      func(x)

  @test_util.run_v2_only
  def test_small_constants_optimization_invalid_input(self):
    """A non-scalar boolean input passes through without crashing."""

    @polymorphic_function.function(
        experimental_attributes={'runtime_constant_optimization': True}
    )
    def func(inp):
      return inp

    with ops.device_v2('CPU'):
      x = constant_op.constant([True, True])
    # runtime_constant_optimization should not crash when the tf.function
    # is passed a boolean tensor having more than one element.
    self.assertAllEqual(func(x), x)

  @test_util.run_v2_only
  def test_small_constants_optimization_without_grappler(self):
    """The optimization still pays off when the meta optimizer is disabled."""

    def func(inp):
      x = constant_op.constant(1)
      for _ in range(1000):
        if inp:
          x = x + constant_op.constant(1)
        else:
          x = x + constant_op.constant(2)
      return x

    brancher = polymorphic_function.function(func)
    brancher_opt = polymorphic_function.function(
        func, experimental_attributes={'runtime_constant_optimization': True}
    )

    # Create the boolean input on the CPU; the disabled test in this class
    # expects an error mentioning that the tensor must be on host.
    with ops.device_v2('CPU'):
      x = constant_op.constant(True)
    # Trace each function once.
    self.assertEqual(brancher(x), brancher_opt(x))

    # Disable grappler and check that performance is still good with
    # small_constants_optimizer.
    with options({'disable_meta_optimizer': True}):
      benchmark = self._best_time(lambda: brancher(x))
      opt_benchmark = self._best_time(lambda: brancher_opt(x))

    # Constant folded execution is usually 150 times faster (against a base
    # that has no grappler optimization). Here we check for a 5x speedup to
    # account for various machines the test might run on.
    # In particular, the kokoro machines seem to run much slower.
    self.assertLess(opt_benchmark * 5, benchmark)
if __name__ == '__main__':
  # Switch on eager execution before handing control to the test runner.
  ops.enable_eager_execution()
  test.main()