blob: 67bab1fe329331b4a264f11ece5321b38218e4ae [file] [log] [blame]
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA call module op wrapper."""
from typing import Tuple
import unittest
import numpy as np
from tensorflow.compiler.mlir.stablehlo import stablehlo
from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
def serialize(module_str: str) -> Tuple[str, int]:
  """Serializes a textual StableHLO module for use with `xla.call_module`.

  The module is serialized as a portable artifact that targets the StableHLO
  version reported by `stablehlo.get_minimum_version()`.

  Args:
    module_str: the MLIR text of the module.

  Returns:
    A pair `(serialized_module, version)`. The tests pass these as the
    `module=` and `version=` arguments of `xla.call_module`. `version` is
    always 4.
  """
  # NOTE(review): the annotation says `str`, but the serialized artifact may
  # actually be `bytes` -- confirm against `serialize_portable_artifact`.
  artifact = stablehlo.serialize_portable_artifact(
      module_str, stablehlo.get_minimum_version())
  call_module_version = 4
  return artifact, call_module_version
class XlaCallModuleOpTest(xla_test.XLATestCase):
  """Tests for the XlaCallModule op.

  Each test serializes a textual StableHLO module with `serialize` and wraps
  it in an XlaCallModule op (via `xla.call_module` or the raw
  `gen_xla_ops.xla_call_module`). The op is then run in a session on the
  device under test.
  """

  def _assertOpOutputMatchesExpected(self,
                                     op,
                                     args,
                                     expected,
                                     equality_fn=None):
    """Asserts op(*args) == expected.

    Creates one placeholder per element of `args` (same dtype and shape),
    builds `op(*placeholders)` inside the test device scope, runs it in a
    session fed with `args`, and compares the result with `expected`.

    Args:
      op: callable taking one tensor per element of `args`.
      args: sequence of numpy arrays/scalars fed to the placeholders.
      expected: the expected value of `session.run(op(*placeholders))`.
      equality_fn: comparison invoked as
        `equality_fn(result, expected, rtol=1e-3)`. Defaults to
        `self.assertAllClose`.
    """
    if not equality_fn:
      equality_fn = self.assertAllClose
    with self.session() as session:
      with self.test_scope():
        placeholders = [
            array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
            for arg in args
        ]
        feeds = dict(zip(placeholders, args))
        output = op(*placeholders)
      result = session.run(output, feeds)
      equality_fn(result, expected, rtol=1e-3)

  def testing_platform(self):
    """Returns the current testing platform: one of 'CPU', 'CUDA', 'TPU'.

    These names match the entries of the `platforms` lists that the tests
    pass to `xla.call_module`.

    Raises:
      AssertionError: if `self.device` is not a recognized CPU/GPU/TPU device.
    """
    if self.device in ['CPU', 'XLA_CPU']:
      return 'CPU'
    elif self.device in ['GPU', 'XLA_GPU']:
      return 'CUDA'
    elif self.device in ['TPU', 'XLA_TPU']:
      return 'TPU'
    else:
      # Raise explicitly instead of `assert False`: asserts are stripped under
      # `python -O`, which would make this method silently return None.
      raise AssertionError(f'Unexpected {self.device=}')

  def test_basic(self):
    """A single-input, single-output module computing sin(cos(x))."""
    x = np.array([1., 2., 3.], dtype=np.float32)

    def f(x):
      # sin(cos(x))
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.cosine %arg0 : tensor<3xf32>
%1 = stablehlo.sine %0 : tensor<3xf32>
return %1 : tensor<3xf32>
}
}
""")
      return xla.call_module([x], version=version,
                             module=module, Tout=[x.dtype], Sout=[x.shape])

    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),))

  def test_basic_with_token(self):
    """Like test_basic, but the module also threads a leading token."""
    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    def f(x):
      # sin(cos(x))
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: !stablehlo.token, %arg1: tensor<3xf32>) -> (!stablehlo.token, tensor<3xf32>) {
%0 = stablehlo.cosine %arg1 : tensor<3xf32>
%1 = stablehlo.sine %0 : tensor<3xf32>
return %arg0, %1 : !stablehlo.token, tensor<3xf32>
}
}
""")
      return xla.call_module(
          [x],
          version=version,
          module=module,
          Tout=[x.dtype],
          Sout=[x.shape],
          has_token_input_output=True,
      )

    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),))

  def test_compare(self):
    """An unsigned comparison `x >= 1` on a ui32 scalar, returning i1."""
    x = np.uint32(2)
    res = np.bool_(True)

    def f(x):
      # return x >= 1
      module, version = serialize("""
module @jit_f_jax.0 {
func.func public @main(%arg0: tensor<ui32>) -> tensor<i1> {
%0 = stablehlo.constant dense<1> : tensor<ui32>
%1 = "stablehlo.compare"(%arg0, %0) {compare_type = #stablehlo<comparison_type UNSIGNED>, comparison_direction = #stablehlo<comparison_direction GE>} : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
return %1 : tensor<i1>
}
}
""")
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[res.shape])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_multiple_args_results(self):
    """Two arguments of different dtypes and two results."""
    x = np.array([1., 2., 3.], dtype=np.float32)
    y = np.array([11., 12., 13., 14.], dtype=np.float64)

    def f(x, y):
      # (sin(x), cos(y))
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<4xf64>) -> (tensor<3xf32>, tensor<4xf64>) {
%0 = stablehlo.sine %arg0 : tensor<3xf32>
%1 = stablehlo.cosine %arg1 : tensor<4xf64>
return %0, %1 : tensor<3xf32>, tensor<4xf64>
}
}
""")
      return xla.call_module([x, y], version=version,
                             module=module,
                             Tout=[x.dtype, y.dtype],
                             Sout=[x.shape, y.shape])

    self._assertOpOutputMatchesExpected(f, (x, y), (np.sin(x), np.cos(y)))

  # TODO(b/283439649): remove dim_args_spec support
  def test_dim_var_basic(self):
    """A dimension variable passed as an explicit i32 module argument."""
    x = np.arange(6, dtype=np.float32).reshape((2, 3))

    def f(x):  # x: f32[2, b]
      # Module takes another argument which is the value of b
      # (sin(x), x.shape[1])
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>) {
%0 = stablehlo.sine %arg1 : tensor<2x?xf32>
return %0, %arg0 : tensor<2x?xf32>, tensor<i32>
}
}
""")
      # Uses the raw op because `dim_args_spec` is set here. Presumably '0.1'
      # means "dimension 1 of argument 0", i.e. b -- TODO confirm.
      return gen_xla_ops.xla_call_module(
          [x],
          version=version,
          module=module,
          Tout=[x.dtype, np.int32],
          Sout=[(None, 3), ()],
          dim_args_spec=['0.1'])

    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1]))

  # TODO(b/283439649): remove dim_args_spec support
  def test_dim_var_basic_dim_arg_i64(self):
    """Like test_dim_var_basic, but the dimension argument is i64."""
    x = np.arange(6, dtype=np.float32).reshape((2, 3))

    def f(x):  # x: f32[2, b]
      # Module takes another argument which is the value of b
      # (sin(x), x.shape[1])
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<i64>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i64>) {
%0 = stablehlo.sine %arg1 : tensor<2x?xf32>
return %0, %arg0 : tensor<2x?xf32>, tensor<i64>
}
}
""")
      return gen_xla_ops.xla_call_module(
          [x],
          module=module, version=version,
          Tout=[x.dtype, np.int64],
          Sout=[(None, 3), ()],
          dim_args_spec=['0.1'])

    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1]))

  def test_dim_var_basic_wrapped(self):
    """Like test_dim_var_basic, but with the wrapper already added.

    `@main` computes the dimension variable itself with
    `get_dimension_size` and forwards it to `@dyn_main`.
    """
    x = np.arange(6, dtype=np.float32).reshape((2, 3))

    def f(x):  # x: f32[2, b]
      # (sin(x), x.shape[1])
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>) {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 1 : i64} : (tensor<2x?xf32>) -> tensor<i32>
%0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>)
return %0, %1 : tensor<2x?xf32>, tensor<i32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>) {
%0 = stablehlo.sine %arg1 : tensor<2x?xf32>
return %0, %arg0 : tensor<2x?xf32>, tensor<i32>
}
}
""")
      return xla.call_module([x],
                             module=module, version=version,
                             Tout=[x.dtype, np.int32],
                             Sout=[(None, 3), ()])

    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1]))

  def test_wrong_actual_args_errors(self):
    """Actual argument types/shapes are checked against the module signature."""
    x = np.arange(6, dtype=np.float32).reshape((3, 2))
    y = np.arange(6, dtype=np.int32).reshape((2, 3))

    # x: f32[a, 2], return x
    module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<?x2xf32>, %arg1: tensor<*xi32>) -> tensor<?x2xf32> {
return %arg0 : tensor<?x2xf32>
}
}
""")

    def f(x, y):
      return xla.call_module(
          [x, y],
          module=module,
          version=version,
          Tout=[x.dtype],
          Sout=[(None, 2)],
      )

    # Well-typed arguments succeed.
    self._assertOpOutputMatchesExpected(f, (x, y), (x,))

    x_bad_etype = x.astype(np.int32)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Element type mismatch for argument 0 passed to XlaCallModule: '
        r'expecting tensor<\?x2xf32>, got tensor<3x2xi32>',
    ):
      self._assertOpOutputMatchesExpected(f, (x_bad_etype, y), (x_bad_etype,))

    y_bad_etype = y.astype(np.float32)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Element type mismatch for argument 1 passed to XlaCallModule: '
        r'expecting tensor<\*xi32>, got tensor<2x3xf32>',
    ):
      self._assertOpOutputMatchesExpected(f, (x, y_bad_etype), (x,))

    x_bad_shape = np.arange(15, dtype=np.float32).reshape(5, 3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Shape mismatch for argument 0 passed to XlaCallModule: '
        r'expecting tensor<\?x2xf32>, got tensor<5x3xf32>',
    ):
      self._assertOpOutputMatchesExpected(f, (x_bad_shape, y), (x_bad_shape,))

  def test_platforms_basic(self):
    """The platform index argument selects a per-platform `case` branch."""
    x = np.float32(0.)
    # returns x + 2. on CPU, x + 3. on GPU and x + 4. on TPU
    module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<f32>) -> tensor<f32> {
%to_add = "stablehlo.case"(%arg_platform_idx) ({
%cpu_val = stablehlo.constant dense<2.> : tensor<f32>
stablehlo.return %cpu_val : tensor<f32>
}, {
%gpu_val = stablehlo.constant dense<3.> : tensor<f32>
stablehlo.return %gpu_val : tensor<f32>
}, {
%tpu_val = stablehlo.constant dense<4.> : tensor<f32>
stablehlo.return %tpu_val : tensor<f32>
}) : (tensor<i32>) -> tensor<f32>
%0 = stablehlo.add %arg0, %to_add : tensor<f32>
return %0 : tensor<f32>
}
}
""")
    # The order of `platforms` must match the order of the `case` branches.
    platforms = ['CPU', 'CUDA', 'TPU']

    def f(x):
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[np.float32],
                             Sout=[()],
                             platforms=platforms)

    expected_value = x + dict(CPU=2., CUDA=3., TPU=4.)[self.testing_platform()]
    self._assertOpOutputMatchesExpected(f, (x,), (expected_value,))

  def test_platforms_errors(self):
    """Error reporting for the platforms attribute."""
    x = np.float32(0.)
    module_str = """
module @jit_f.0 {
func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
}
"""
    module, version = serialize(module_str)
    platforms = []

    # `f` reads `module`, `version` and `platforms` at call time, so the
    # reassignments below change what each subsequent call builds.
    def f(x):
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[np.float32],
                             Sout=[()],
                             platforms=platforms)

    # With empty platforms, there should be no platform_index argument
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Incorrect number of arguments passed to XlaCallModule: 1. '
        'The module takes 2 arguments of which 0 platform index arguments '
        'and 0 dimension arguments.'):
      self._assertOpOutputMatchesExpected(f, (x,), (x,))

    # Same with a single platform
    platforms = ['CPU']
    if self.testing_platform() == 'CPU':
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          'Incorrect number of arguments passed to XlaCallModule: 1. '
          'The module takes 2 arguments of which 0 platform index arguments '
          'and 0 dimension arguments.'):
        self._assertOpOutputMatchesExpected(f, (x,), (x,))

    platforms = ['RANDOM_PLATFORM_1', 'RANDOM_PLATFORM_2']
    with self.assertRaisesRegex(
        errors.NotFoundError,
        'The current platform .* is not among the platforms'):
      self._assertOpOutputMatchesExpected(f, (x,), (x,))

    platforms = ['CPU', 'CUDA']
    if self.testing_platform() not in platforms:
      with self.assertRaisesRegex(
          errors.NotFoundError,
          'The current platform .* is not among the platforms'):
        self._assertOpOutputMatchesExpected(f, (x,), (x,))
    else:
      self._assertOpOutputMatchesExpected(f, (x,), (x,))

    # The module cannot have i64 %arg_platform_idx
    module, version = serialize(module_str.replace('i32', 'i64'))
    platforms = ['CPU', 'CUDA', 'TPU']
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Module argument at index 0 should be a 0-dimensional '
        '32-bit integer-tensor platform index argument .* has type '
        'tensor<i64>'):
      self._assertOpOutputMatchesExpected(f, (x,), (x,))

    # A module without the platform index argument
    module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
return %arg0 : tensor<i32>
}
}
""")
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'The module should have 1 platform index arguments and 0 dimension '
        'arguments, but it has only 1 total arguments'):
      self._assertOpOutputMatchesExpected(f, (x,), (x,))

  def test_dynamic_iota(self):
    """A dynamic_iota whose size is the dynamic leading dimension of x."""
    x = np.ones((3, 5), dtype=np.int32)
    res = np.arange(x.shape[0], dtype=np.int32)

    def f(x):  # x: i32[b, 5]
      # return np.arange(x.shape[0], dtype=np.int32)
      module, version = serialize("""
module @jit_fun.1 {
func.func public @main(%arg1: tensor<?x5xi32>) -> tensor<?xi32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?x5xi32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x5xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x5xi32>) -> tensor<?xi32> {
%0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
}
""")
      return xla.call_module([x,], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None,)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_build_graph_with_any_platform(self):
    """We can construct the tf.Graph on all platforms."""
    x = np.float32(0.)
    module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
}
""")
    platforms = ['TPU']  # the module is compileable only on TPU

    def f(x):
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[np.float32],
                             Sout=[()],
                             platforms=platforms)

    # Only trace the graph; nothing is executed, so the platform restriction
    # does not apply.
    tf_graph = def_function.function(f).get_concrete_function(x).graph
    self.assertIn('XlaCallModule', str(tf_graph.as_graph_def()))

  def test_dynamic_reshape(self):
    """Flattens f32[b, 3] to f32[3 * b] with dynamic_reshape."""
    x = np.ones((4, 3), dtype=np.float32)
    res = x.reshape((-1,))

    def f(x):  # x: f32[b, 3]
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg1: tensor<?x3xf32>) -> tensor<?xf32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?x3xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x3xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x3xf32>) -> tensor<?xf32> {
%0 = stablehlo.constant dense<3> : tensor<i32>
%1 = stablehlo.multiply %arg0, %0 : tensor<i32>
%2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.dynamic_reshape %arg1, %2 : (tensor<?x3xf32>, tensor<1xi32>) -> tensor<?xf32>
return %3 : tensor<?xf32>
}
}
""")
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None,)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_dynamic_gather(self):
    """A dynamic_gather with slice sizes [b, 2] computed at runtime."""
    x = np.ones((3, 4), dtype=np.float32)
    res = np.ones((3, 2), dtype=np.float32)

    def f(x):  # x: f32[b, 4]
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg1: tensor<?x4xf32>) -> tensor<?x2xf32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?x4xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x4xf32>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>) -> tensor<?x2xf32> {
%0 = stablehlo.constant dense<0> : tensor<i64>
%1 = stablehlo.constant dense<0> : tensor<1xi64>
%2 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.constant dense<2> : tensor<1xi32>
%4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
%5 = "stablehlo.dynamic_gather"(%arg1, %1, %4) {dimension_numbers = #stablehlo.gather<offset_dims = [0, 1], start_index_map = [1]>, indices_are_sorted = true} : (tensor<?x4xf32>, tensor<1xi64>, tensor<2xi32>) -> tensor<?x2xf32>
return %5 : tensor<?x2xf32>
}
}
""")
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None, 2)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_real_dynamic_slice(self):
    """Selects the last row x[-1, :] with real_dynamic_slice."""
    x = np.ones((3, 4), dtype=np.float32)
    res = x[-1, :]

    def f(x):  # x: f32[b, 4]
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg1: tensor<?x4xf32>) -> tensor<4xf32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?x4xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>) -> tensor<4xf32> {
%0 = stablehlo.constant dense<-1> : tensor<i32>
%1 = stablehlo.add %arg0, %0 : tensor<i32>
%2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.constant dense<0> : tensor<1xi32>
%4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
%5 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%6 = stablehlo.constant dense<4> : tensor<1xi32>
%7 = stablehlo.concatenate %5, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
%10 = stablehlo.constant dense<1> : tensor<2xi32>
%11 = stablehlo.real_dynamic_slice %arg1, %4, %7, %10 : (tensor<?x4xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32>
%12 = stablehlo.reshape %11 : (tensor<1x4xf32>) -> tensor<4xf32>
return %12 : tensor<4xf32>
}
}
""")
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[x.dtype],
                             Sout=[(4,)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_dynamic_update_slice(self):
    """A dynamic_update_slice at a runtime-normalized negative index."""
    x = np.ones((3, 4), dtype=np.float32)
    idx = np.int32(-2)
    res = x  # The update should be a nop

    def f(x, idx):  # x: f32[b, 4] idx: i32
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg1: tensor<?x4xf32>, %arg2: tensor<i32>) -> tensor<?x4xf32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?x4xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1, %arg2) : (tensor<i32>, tensor<?x4xf32>, tensor<i32>) -> tensor<?x4xf32>
return %0 : tensor<?x4xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>, %arg2: tensor<i32>) -> tensor<?x4xf32> {
%0 = stablehlo.constant dense<0> : tensor<i32>
%1 = stablehlo.compare LT, %arg2, %0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = stablehlo.add %arg2, %arg0 : tensor<i32>
%3 = stablehlo.select %1, %2, %arg2 : tensor<i1>, tensor<i32>
%4 = stablehlo.constant dense<0> : tensor<i32>
%5 = stablehlo.dynamic_update_slice %arg1, %arg1, %3, %4 : (tensor<?x4xf32>, tensor<?x4xf32>, tensor<i32>, tensor<i32>) -> tensor<?x4xf32>
return %5 : tensor<?x4xf32>
}
}
""")
      return xla.call_module([x, idx], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None, 4)])

    self._assertOpOutputMatchesExpected(f, (x, idx), (res,))

  def test_dynamic_broadcast_in_dim(self):
    """Broadcasts f32[b, 4] to f32[2, b, 4] with dynamic_broadcast_in_dim."""
    x = np.ones((3, 4), dtype=np.float32)
    y = np.ones((2, 3, 4), dtype=np.float32)
    res = (np.broadcast_to(x, y.shape), x + y)

    def f(x, y):  # x: f32[b, 4] y: f32[2, b, 4]
      # return (np.broadcast_to(x, y.shape), x + y)
      module, version = serialize("""
module @jit_fun.0 {
func.func public @main(%arg1: tensor<?x4xf32>, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) {
%arg0_new = "stablehlo.get_dimension_size"(%arg2) {dimension = 1 : i64} : (tensor<2x?x4xf32>) -> tensor<i32>
%0, %1 = call @dyn_main(%arg0_new, %arg1, %arg2) : (tensor<i32>, tensor<?x4xf32>, tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>)
return %0, %1 : tensor<2x?x4xf32>, tensor<2x?x4xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) {
%0 = stablehlo.constant dense<2> : tensor<1xi32>
%2 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.constant dense<4> : tensor<1xi32>
%4 = "stablehlo.concatenate"(%0, %2, %3) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
%5 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x4xf32>, tensor<3xi32>) -> tensor<2x?x4xf32>
%6 = stablehlo.add %5, %arg2 : (tensor<2x?x4xf32>, tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
return %5, %6 : tensor<2x?x4xf32>, tensor<2x?x4xf32>
}
}
""")
      return xla.call_module([x, y], version=version,
                             module=module,
                             Tout=[res[0].dtype, res[1].dtype],
                             Sout=[(2, None, 4), (2, None, 4)])

    self._assertOpOutputMatchesExpected(f, (x, y), res)

  @unittest.skip('TODO(necula): test is flaky')
  def test_reduce(self):
    """Sum-reduces i32[b] and multiplies the sum by b."""
    x = np.arange(5, dtype=np.int32)
    # np.sum promotes int32 to the platform integer (int64 on most 64-bit
    # platforms). Cast back so that `Tout` matches the module's i32 result.
    res = np.int32(np.sum(x) * x.shape[0])

    def f(x):  # x: i32[b]
      module, version = serialize("""
module @jit_fun{
func.func public @main(%arg1: tensor<?xi32>) -> tensor<i32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?xi32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xi32>) -> tensor<i32>
return %0 : tensor<i32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xi32>) -> tensor<i32> {
%0 = stablehlo.constant dense<0> : tensor<i32>
%1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [0] : (tensor<?xi32>, tensor<i32>) -> tensor<i32>
reducer(%arg2: tensor<i32>, %arg3: tensor<i32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<i32>
"stablehlo.return"(%4) : (tensor<i32>) -> ()
}
%2 = stablehlo.multiply %1, %arg0 : tensor<i32>
return %2 : tensor<i32>
}
}
""")
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[res.shape])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_reduce_broadcast(self):
    """Row sums of f32[b, 5], broadcast back to f32[b, 1]."""
    x = np.broadcast_to(np.arange(3, dtype=np.float32).reshape(3, 1), (3, 5))
    res = np.arange(3, dtype=np.float32).reshape(3, 1) * 5

    def f(x):  # x: f32[b, 5]
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg1: tensor<?x5xf32>) -> tensor<?x1xf32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?x5xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x5xf32>) -> tensor<?x1xf32>
return %0 : tensor<?x1xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x5xf32>) -> tensor<?x1xf32> {
%0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [1] : (tensor<?x5xf32>, tensor<f32>) -> tensor<?xf32>
reducer(%arg2: tensor<f32>, %arg3: tensor<f32>) {
%6 = stablehlo.add %arg2, %arg3 : tensor<f32>
stablehlo.return %6 : tensor<f32>
}
%2 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.constant dense<1> : tensor<1xi32>
%4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
%5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [0] : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x1xf32>
return %5 : tensor<?x1xf32>
}
}
""")
      return xla.call_module([x,], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None, 1)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_call(self):
    """A chain of calls."""
    x = np.ones((5,), dtype=np.float32)
    res = np.arange(x.shape[0], dtype=np.int32)

    def f(x):  # x: f32[b]
      module, version = serialize("""
module @jit_fun_3 {
func.func public @main(%arg1: tensor<?xf32>) -> tensor<?xi32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xi32> {
%0 = call @f(%arg0, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
func.func private @f(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xi32> {
%0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
}
""")
      # The result is rank-1 with a dynamic size, as in test_dynamic_iota.
      return xla.call_module([x,], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None,)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_identity(self):
    """A dynamic-shape wrapper around an identity function."""
    x = np.ones((5,), dtype=np.float32)
    res = x

    def f(x):  # x: f32[b]
      module, version = serialize("""
module @jit_fun_3 {
func.func public @main(%arg1: tensor<?xf32>) -> tensor<?xf32> {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}
""")
      # The result is rank-1 with a dynamic size, as in test_dynamic_reshape.
      return xla.call_module([x], version=version,
                             module=module,
                             Tout=[res.dtype],
                             Sout=[(None,)])

    self._assertOpOutputMatchesExpected(f, (x,), (res,))

  def test_while(self):
    """A while loop with carried dynamic shapes."""
    x = np.ones((5,), dtype=np.float32)
    # Compute the result in Python first
    res0 = np.copy(x)
    for _ in range(5):
      res0 += np.arange(x.shape[0], dtype=np.float32)
    res1 = np.int64(5)

    def f(x):  # x: f32[b]
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg1: tensor<?xf32>) -> (tensor<?xf32>, tensor<i64>) {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<i64>)
return %0, %1 : tensor<?xf32>, tensor<i64>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>, tensor<i64>) {
%0 = stablehlo.constant dense<0> : tensor<i64>
%1:2 = "stablehlo.while"(%arg1, %0) ({
^bb0(%arg2: tensor<?xf32>, %arg3: tensor<i64>):
%2 = stablehlo.constant dense<5> : tensor<i64>
%3 = stablehlo.compare LT, %arg3, %2, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %3 : tensor<i1>
}, {
^bb0(%arg2: tensor<?xf32>, %arg3: tensor<i64>):
%2 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%3 = stablehlo.dynamic_iota %2, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.add %arg2, %3 : tensor<?xf32>
%5 = stablehlo.constant dense<1> : tensor<i64>
%6 = stablehlo.add %arg3, %5 : tensor<i64>
stablehlo.return %4, %6 : tensor<?xf32>, tensor<i64>
}) : (tensor<?xf32>, tensor<i64>) -> (tensor<?xf32>, tensor<i64>)
return %1#0, %1#1 : tensor<?xf32>, tensor<i64>
}
}
""")
      return xla.call_module([x,], version=version,
                             module=module,
                             Tout=[res0.dtype, res1.dtype],
                             Sout=[(None,), res1.shape])

    self._assertOpOutputMatchesExpected(f, (x,), (res0, res1))

  def test_tf_call_function(self):
    """A TensorFlow function call inside StableHLO."""
    x = np.int32(2)
    y = np.int32(3)
    res = x + y

    @function.Defun(dtypes.int32, dtypes.int32)
    def foo(x, y):
      return x + y

    def f(x, y):
      # `called_index` selects an entry of `function_list`.
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 0}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res.dtype],
          Sout=[res.shape],
          function_list=(foo,),
      )

    self._assertOpOutputMatchesExpected(f, (x, y), (res,))

  def test_tf_call_function_multiple_funcs(self):
    """Multiple TensorFlow function calls inside StableHLO."""
    x = np.int32(2)
    y = np.int32(3)
    # foo(x, y) = x + y; bar(x, y) = foo(x, y); result = bar(foo(x, y), bar(x, y)).
    res = (x + y) + (x + y)

    @function.Defun(dtypes.int32, dtypes.int32)
    def foo(x, y):
      return x + y

    @function.Defun(dtypes.int32, dtypes.int32)
    def bar(x, y):
      return foo(x, y)

    def f(x, y):
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 0}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%1 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 1}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2 = stablehlo.custom_call @tf.call_tf_function(%0, %1) {
tf.backend_config = {called_index = 1}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %2 : tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res.dtype],
          Sout=[res.shape],
          function_list=(foo, bar),
      )

    self._assertOpOutputMatchesExpected(f, (x, y), (res,))

  def test_shape_polymorphic_tf_call_function(self):
    """A TensorFlow function call with a dynamic-shape operand."""
    x = np.full((2,), 2, dtype=np.int32)
    y = np.full((2,), 3, dtype=np.int32)
    res = x + y

    @function.Defun(dtypes.int32, dtypes.int32)
    def foo(x, y):
      return x + y

    def f(x, y):
      # Operand 2 (the dimension size) is marked via
      # `indices_of_shape_operands` rather than passed as a function input.
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xi32>) -> tensor<i32>
%1 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1, %0) {
tf.backend_config = {called_index = 0},
indices_of_shape_operands = dense<[2]> : tensor<1xi64>
} : (tensor<?xi32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res.dtype],
          Sout=[res.shape],
          function_list=(foo,),
      )

    self._assertOpOutputMatchesExpected(f, (x, y), (res,))

  def test_tf_call_function_with_token(self):
    """A TensorFlow function call that threads a token through StableHLO."""
    x = np.int32(2)
    y = np.int32(3)
    res = x + y

    @function.Defun(dtypes.int32, dtypes.int32)
    def foo(x, y):
      return x + y

    def f(x, y):
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: !stablehlo.token, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (!stablehlo.token, tensor<i32>) {
%0:2 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1, %arg2) {
tf.backend_config = {called_index = 0, has_token_input_output = true}
} : (!stablehlo.token, tensor<i32>, tensor<i32>) -> (!stablehlo.token, tensor<i32>)
return %0#0, %0#1 : !stablehlo.token, tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res.dtype],
          Sout=[res.shape],
          function_list=(foo,),
          has_token_input_output=True,
      )

    self._assertOpOutputMatchesExpected(f, (x, y), (res,))

  def test_tf_call_function_nested(self):
    """Nested XlaCallModule inside TensorFlow function calls."""
    x = np.int32(2)
    y = np.int32(3)
    res = x + y

    @function.Defun(dtypes.int32, dtypes.int32)
    def add(x, y):
      return x + y

    @function.Defun(dtypes.int32, dtypes.int32)
    def nested_xla_call(x, y):
      # Inner XlaCallModule that calls back into the TF function `add`.
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 0}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res.dtype],
          Sout=[res.shape],
          function_list=(add,),
      )

    @function.Defun(dtypes.int32, dtypes.int32)
    def call(x, y):
      return nested_xla_call(x, y)

    def f(x, y):
      # Outer XlaCallModule -> TF `call` -> inner XlaCallModule -> TF `add`.
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 0}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res.dtype],
          Sout=[res.shape],
          function_list=(call,),
      )

    self._assertOpOutputMatchesExpected(f, (x, y), (res,))

  def test_tf_call_function_nested_func_renaming(self):
    """Multiple custom calls with identically named private functions."""
    x = np.int32(2)
    y = np.int32(3)
    res0 = x + y
    res1 = x - y

    # Verify that multiple inner TF function calls with the same private
    # functions are properly renamed during MHLO import. This test case is
    # carefully constructed such that one outer XlaCallModule op has two custom
    # calls, each of which has the same private "@call" function with different
    # body. This is to catch bugs in the func renaming logic.
    @function.Defun(dtypes.int32, dtypes.int32)
    def add(x, y):
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func private @call(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32>
return %0 : tensor<i32>
}
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = func.call @call(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res0.dtype],
          Sout=[res0.shape],
      )

    @function.Defun(dtypes.int32, dtypes.int32)
    def subtract(x, y):
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func private @call(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.subtract %arg0, %arg1 : tensor<i32>
return %0 : tensor<i32>
}
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = func.call @call(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res1.dtype],
          Sout=[res1.shape],
      )

    def f(x, y):
      module, version = serialize("""
module @jit_fun_flat_jax {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 0}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%1 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {
tf.backend_config = {called_index = 1}
} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %0, %1 : tensor<i32>, tensor<i32>
}
}
""")
      return xla.call_module(
          [x, y],
          version=version,
          module=module,
          Tout=[res0.dtype, res1.dtype],
          Sout=[res0.shape, res1.shape],
          function_list=(add, subtract),
      )

    self._assertOpOutputMatchesExpected(f, (x, y), (res0, res1))

  def test_op_backward_compatibility(self):
    """Test for ensuring XlaCallModuleOp backward compatibility."""
    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    def f(x):
      # sin(cos(x))
      module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.cosine %arg0 : tensor<3xf32>
%1 = stablehlo.sine %0 : tensor<3xf32>
return %1 : tensor<3xf32>
}
}
""")
      # Create the raw XlaCallModule op directly instead of calling
      # `xla.call_module`, which supplies default values for absent
      # attributes.
      return gen_xla_ops.xla_call_module(
          [x],
          version=version,
          module=module,
          Tout=[x.dtype],
          Sout=[x.shape],
      )

    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),))
if __name__ == '__main__':
  # The tests drive ops through TF1-style sessions (see
  # `_assertOpOutputMatchesExpected`), which do not work in eager mode.
  # Switch to graph mode before handing control to the test runner.
  ops.disable_eager_execution()
  googletest.main()