| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
"""Test cases for TensorFlow functions."""
| |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import function |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.platform import googletest |
| |
| |
class FunctionTest(xla_test.XLATestCase):
  """Tests graph-mode `function.Defun` calls placed under `self.test_scope()`."""

  @staticmethod
  def _Matrix2x2(values):
    """Returns the four `values` reshaped row-major into a 2x2 float32 array."""
    return np.array(values).reshape([2, 2]).astype(np.float32)

  def _RunBinaryDefun(self, py_func, aval, bval):
    """Wraps `py_func` in a two-input float32 Defun and evaluates one call.

    The inputs enter the graph as constants named "a" and "b". Only the Defun
    call itself is created inside `self.test_scope()`.

    Args:
      py_func: Python callable taking two tensors. Its return value (a tensor
        or a tuple of tensors) becomes the Defun's output(s).
      aval: NumPy value fed as the Defun's first argument.
      bval: NumPy value fed as the Defun's second argument.

    Returns:
      The evaluated output(s) of the Defun call.
    """
    with self.session():

      @function.Defun(dtypes.float32, dtypes.float32)
      def Foo(a, b):
        return py_func(a, b)

      lhs = constant_op.constant(aval, name="a")
      rhs = constant_op.constant(bval, name="b")
      with self.test_scope():
        output = Foo(lhs, rhs)
      return self.evaluate(output)

  def testFunction(self):
    """Runs a Defun computing a + 2 * b under the test device scope."""

    def APlus2B(a, b):
      return a + b * 2

    lhs_val = self._Matrix2x2([4, 3, 2, 1])
    rhs_val = self._Matrix2x2([5, 6, 7, 8])
    want = APlus2B(lhs_val, rhs_val)
    got = self._RunBinaryDefun(APlus2B, lhs_val, rhs_val)
    self.assertAllClose(got, want, rtol=1e-3)

  def testNestedFunctions(self):
    """Runs a Defun whose body is built from nested Python helper calls."""

    def TimesTwo(x):
      return x * 2

    def APlus2B(a, b):
      return a + TimesTwo(b)

    lhs_val = self._Matrix2x2([4, 3, 2, 1])
    rhs_val = self._Matrix2x2([4, 3, 2, 1])
    want = APlus2B(lhs_val, rhs_val)
    got = self._RunBinaryDefun(APlus2B, lhs_val, rhs_val)
    self.assertAllClose(got, want, rtol=1e-3)

  def testFunctionMultipleRetvals(self):
    """Runs a Defun that produces two outputs."""

    # Traced into the Defun (so it runs under the test device scope) and also
    # applied directly to the NumPy inputs to produce the expected values.
    def Func(a, b):
      return a + b, a - b

    lhs_val = self._Matrix2x2([4, 3, 2, 1])
    rhs_val = self._Matrix2x2([5, 6, 7, 8])
    want = Func(lhs_val, rhs_val)
    got = self._RunBinaryDefun(Func, lhs_val, rhs_val)
    self.assertAllClose(got, want, rtol=1e-3)

  def testCompileTimeConstantsInDefun(self):
    """Passes slice begin/size into a Defun as values XLA needs at compile time."""
    with self.session() as sess:

      @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32)
      def Foo(a, c, d):
        # The slice begin (c) and size (d) must be known at compile time.
        return array_ops.slice(a, c, d)

      data = array_ops.placeholder(dtypes.float32)
      begin = array_ops.placeholder(dtypes.int32, shape=[4])
      size = array_ops.placeholder(dtypes.int32, shape=[4])
      with self.test_scope():
        sliced = Foo(data, begin, size)
      feeds = {
          data: np.ones([1, 4, 4, 1]),
          begin: [0, 0, 0, 0],
          size: [1, 2, 2, 1],
      }
      got = sess.run(sliced, feed_dict=feeds)

    self.assertAllEqual(np.ones([1, 2, 2, 1]), got)

  # TODO(b/36139787): Re-enable this test when noinline works again.
  def DISABLED_testFunctionsNoInline(self):
    """Calls a noinline Defun from inside another Defun (currently disabled)."""

    @function.Defun(dtypes.float32, noinline=True)
    def TimesTwo(x):
      return x * 2

    @function.Defun(dtypes.float32, dtypes.float32)
    def APlus2B(a, b):
      return a + TimesTwo(b)

    lhs_val = self._Matrix2x2([4, 3, 2, 1])
    rhs_val = self._Matrix2x2([4, 3, 2, 1])
    want = lhs_val + rhs_val * 2

    with self.session() as sess, self.test_scope():
      lhs = array_ops.placeholder(dtypes.float32, name="a")
      rhs = array_ops.placeholder(dtypes.float32, name="b")
      got = sess.run(APlus2B(lhs, rhs), {lhs: lhs_val, rhs: rhs_val})
    self.assertAllClose(got, want, rtol=1e-3)
| |
| |
if __name__ == "__main__":
  # Run as a script: hand control to TensorFlow's googletest entry point.
  googletest.main()