blob: 25350dedb03e958de6e02b5dc4366f7f83e6646d [file] [log] [blame]
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for case statements in XLA."""
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import test
class CaseTest(xla_test.XLATestCase):
  """Checks `switch_case` inside functions built with `jit_compile=True`."""

  def testCaseBasic(self):
    # Indices 0 and 1 have their own branch; every other index is expected to
    # produce the value of the default branch.

    @def_function.function(jit_compile=True)
    def switch_case_test(branch_index):

      def seventeen():
        return array_ops.constant(17)

      def thirty_one():
        return array_ops.constant(31)

      def minus_one():
        return array_ops.constant(-1)

      indexed_branches = {0: seventeen, 1: thirty_one}
      return control_flow_switch_case.switch_case(
          branch_index, branch_fns=indexed_branches, default=minus_one)

    # (branch index fed to the function, value the selected branch returns),
    # checked in this order.
    index_and_expected = ((0, 17), (1, 31), (2, -1), (3, -1))
    with ops.device(self.device):
      for index, expected in index_and_expected:
        outcome = switch_case_test(array_ops.constant(index))
        self.assertEqual(outcome.numpy(), expected)

  def testBranchIsPruned(self):

    @def_function.function(jit_compile=True)
    def switch_case_test():

      def compilable_branch():
        return array_ops.constant(17)

      def uncompilable_branch():
        # Some operations that XLA cannot compile.
        file_contents = io_ops.read_file('/tmp/bmp')
        image_ops.decode_image(file_contents)
        return array_ops.constant(31)

      constant_index = array_ops.constant(0)
      # This tests that we do not try to compile all branches if the branch
      # index is trivially constant.
      return control_flow_switch_case.switch_case(
          constant_index,
          branch_fns={0: compilable_branch, 1: uncompilable_branch},
          default=uncompilable_branch)

    with ops.device(self.device):
      selected = switch_case_test()
      self.assertEqual(selected.numpy(), 17)
# Script entry point: runs only when this file is executed directly.
if __name__ == '__main__':
  # Switch on eager execution first, then hand control to the test runner.
  ops.enable_eager_execution()
  test.main()