blob: db767f0d5546278d1b3e3549d557ca12ed4e50a6 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.cond in XLA."""
from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@test_util.with_control_flow_v2
class CondTest(xla_test.XLATestCase):
def testCondAndTensorArrayInDefun(self):
# TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@def_function.function
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = cond.cond(
constant_op.constant(True),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t = f()
self.assertAllEqual([5.], self.evaluate(output_t))
xla_context.Exit()
def testCondAndTensorArrayInDefun_constFolding(self):
g = ops.Graph()
with session.Session(graph=g), g.as_default(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@def_function.function
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = cond.cond(
constant_op.constant(False),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t = f()
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testCondAndTensorArray_xlaCompile(self):
self.skipTest("b/127846988")
# Fails with "Uninitialized arguments" in XlaIfOp::Compile
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = cond.cond(
constant_op.constant(True),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t, = xla.compile(f)
self.assertAllEqual([5.], self.evaluate(output_t))
xla_context.Exit()
def testCondConstPropagation(self):
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = array_ops.placeholder(dtypes.int32)
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[p]
def if_false():
return 5.
output = cond.cond(
constant_op.constant(True), if_true, if_false)
self.assertAllEqual(1.,
sess.run(output, feed_dict={
x: [0., 1., 2.],
p: 1
}))
xla_context.Exit()
def testCondConstPropagation_xlaCompile(self):
self.skipTest("b/132430685")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3])
p = constant_op.constant(1)
def f():
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[p]
def if_false():
return 5.
return cond.cond(
constant_op.constant(True), if_true, if_false)
output = xla.compile(f)
self.assertAllEqual(1., self.evaluate(output))
xla_context.Exit()
def testCondConstPropagation_errorMsg(self):
self.skipTest("b/132430685")
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[:p]
def if_false():
return array_ops.fill([p], 5.)
output = cond.cond(
constant_op.constant(True), if_true, if_false)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"must be a compile-time constant"):
sess.run(
output, feed_dict={
x: [0., 1., 2.],
})
xla_context.Exit()
def testCondConstPropagation_errorMsg_xlaCompile(self):
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
condition = math_ops.cast(
random_ops.random_uniform([], minval=0, maxval=2, dtype=dtypes.int32),
dtypes.bool)
def f():
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[:p]
def if_false():
return array_ops.fill([p], 5.)
return cond.cond(condition, if_true, if_false)
output = xla.compile(f)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"must be a compile-time constant"):
sess.run(
output, feed_dict={
x: [0., 1., 2.],
})
xla_context.Exit()
def testSwitchCaseAndTensorArrayInDefun(self):
self.skipTest("b/127846988")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@def_function.function
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_switch_case.switch_case(
constant_op.constant(1), {
0: lambda: ta.write(0, 5.),
1: lambda: ta.write(0, 10.),
2: lambda: ta.write(0, 15.),
})
return output.stack()
output_t = f()
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testSwitchCaseAndTensorArray_xlaCompile(self):
self.skipTest("b/127846988")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_switch_case.switch_case(
constant_op.constant(1), {
0: lambda: ta.write(0, 5.),
1: lambda: ta.write(0, 10.),
2: lambda: ta.write(0, 15.),
})
return output.stack()
output_t, = xla.compile(f)
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testSwitchCaseConstPropagation(self):
self.skipTest("b/127846988")
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = array_ops.placeholder(dtypes.int32)
def branch0():
return 5.
def branch1():
return 15.
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def branch2():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[p]
output = control_flow_switch_case.switch_case(
constant_op.constant(2), {
0: branch0,
1: branch1,
2: branch2,
})
self.assertAllEqual(7.,
sess.run(output, feed_dict={
x: [0., 1., 7.],
p: 2,
}))
xla_context.Exit()
def testCondNoInputs(self):
"""Verifies against `Failed precondition: Expected one input shape`."""
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
for pred in True, False:
cond_out = cond.cond(
array_ops.placeholder_with_default(pred, []),
lambda: constant_op.constant(2.),
lambda: constant_op.constant(1.))
self.assertEqual(int(pred) + 1., self.evaluate(cond_out))
xla_context.Exit()
if __name__ == '__main__':
test.main()