blob: 3c9b29b1835c794152719a4b26f9e6cb2ef80f95 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops which manipulate lists of tensors via the XLA bridge."""
# pylint: disable=g-bad-name
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase):
  """Exercises TensorList ops when they are run inside an XLA test scope.

  Each test builds its graph under `self.session()` and `self.test_scope()`
  (both inherited from `xla_test.XLATestCase`; presumably the scope places the
  ops on the XLA device under test -- confirm in xla_test) and then evaluates
  the result.
  """

  def testElementShape(self):
    """Element shape is reported for both int32 and int64 shape types."""
    with self.session() as sess, self.test_scope():
      # The leading element dimension is only supplied at run time.
      dim = array_ops.placeholder(dtypes.int32)
      l = list_ops.empty_tensor_list(
          element_shape=(dim, 15),
          element_dtype=dtypes.float32,
          max_num_elements=20)
      e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
      e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64)
      self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15))
      self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15))

  def testPushPop(self):
    """PopBack returns elements in reverse push order."""
    with self.session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=(7, 15),
          element_dtype=dtypes.float32,
          max_num_elements=10)
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(1.0, shape=(7, 15)))
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(2.0, shape=(7, 15)))
      l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15)))
      self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15)))

  def testDoNotConstantFoldVariants(self):
    """Pushing a placeholder-derived value must still produce correct output."""
    with self.session() as sess, self.test_scope():
      val = array_ops.placeholder(dtype=dtypes.float32)
      l = list_ops.empty_tensor_list(
          element_shape=(7, 15),
          element_dtype=dtypes.float32,
          max_num_elements=10)
      # Note: Pushing a Placeholder will force the constant folding code
      # to build a Const node with a DT_VARIANT output. This tests that XLA
      # passes a cf_consider_fn which prevents folding such nodes.
      l = list_ops.tensor_list_push_back(
          l, array_ops.fill(value=val, dims=(7, 15)))
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(2.0, shape=(7, 15)))
      l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15)))
      self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))

  def testPushPopSeparateLists(self):
    """Pushes onto lists derived from a shared parent do not affect each other."""
    with self.session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=[],
          element_dtype=dtypes.float32,
          max_num_elements=20)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
      # l2 and l3 both branch off the same one-element list `l`.
      l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
      l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
      _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
      l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
      l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
      l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
      result = sess.run([e11, [e21, e22], [e31, e32]])
      self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])

  def testEmptyTensorListNoMax(self):
    """An empty list without max_num_elements fails with a hint to set it."""
    with self.session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=(7, 15), element_dtype=dtypes.float32)
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(1.0, shape=(7, 15)))
      _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Set the max number of elements"):
        # Running the graph is expected to raise, so there is no value to
        # compare afterwards.
        sess.run(e)

  def testEmptyTensorListMax(self):
    """A list with max_num_elements set supports push then pop."""
    with self.session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=(10, 15), element_dtype=dtypes.float32,
          max_num_elements=2)
      l = list_ops.tensor_list_push_back(
          l, array_ops.fill(value=3.0, dims=(10, 15)))
      _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15)))

  def testListFromTensor(self):
    """A list built from a tensor supports GetItem and PopBack."""
    with self.session(), self.test_scope():
      t = constant_op.constant([1.0, 2.0])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e, 1.0)
      l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(e0, 2.0)
      l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(e1, 1.0)
      # Even after both pops the length is expected to be 2. This suggests the
      # XLA lowering reports the list's allocated leading dimension rather
      # than the number of live elements. NOTE(review): confirm against the
      # tf2xla TensorListLength kernel.
      self.assertAllEqual(list_ops.tensor_list_length(l), 2)

  def testGetSet(self):
    """SetItem overwrites an existing element of a tensor-backed list."""
    with self.session(), self.test_scope():
      t = constant_op.constant([1.0, 2.0])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e0, 1.0)
      l = list_ops.tensor_list_set_item(l, 0, 3.0)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [3.0, 2.0])

  def testSetDoesNotUpdatePushIndex(self):
    """SetItem leaves the position used by the next PushBack unchanged."""
    with self.session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
      # SetItem should not change the push index.
      l = list_ops.tensor_list_set_item(l, 1, 3.)
      l = list_ops.tensor_list_push_back(l, 5.)
      l = list_ops.tensor_list_push_back(l, 7.)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [5., 7.])

  def testGetSetReserved(self):
    """Reserved slots read as zero until they are set."""
    with self.session(), self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=[], num_elements=2)
      e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e0, 0.0)
      l = list_ops.tensor_list_set_item(l, 0, 3.0)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [3.0, 0.0])

  def testSetStackReservedUnknownElementShape(self):
    """Unset slots of a shape-less reserved list stack as zeros shaped like the set one."""
    with self.session(), self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=None, num_elements=2)
      l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0])
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]])

  def testPushInEmptyListWithUnknownElementShape(self):
    """After the first push, pushing a differently shaped element fails."""
    with self.session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
      l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
      # Pushing an element with a different shape should raise an error.
      with self.assertRaisesRegex(errors.InternalError, "shape"):
        l = list_ops.tensor_list_push_back(l, 5.)
        self.evaluate(
            list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))

  def testGetSetReservedNonScalar(self):
    """SetItem/GetItem on a reserved list of (7, 15) elements."""
    with self.session() as sess, self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32,
          element_shape=(7, 15),
          num_elements=2)
      l = list_ops.tensor_list_set_item(
          l, 0, constant_op.constant(1.0, shape=(7, 15)))
      e1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      e2 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e1), np.ones((7, 15)))
      self.assertAllEqual(sess.run(e2), np.zeros((7, 15)))

  def testStack(self):
    """Stacking pushed scalars gives a rank-1 tensor with unknown static length."""
    with self.session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[],
          max_num_elements=2)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
      e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e, 1.0)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t.shape.as_list(), [None])
      self.assertAllEqual(t, [1.0, 2.0])

  @parameterized.named_parameters(
      ("FlatList", [1.0, 2.0, 3.0], [], [0, 2], [1.0, 3.0]),
      ("NestedList", [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2], [1],
       [[3.0, 4.0]]),
      ("EmptyIndices", [1.0, 2.0, 3.0], [], [], []),
  )
  def testGather(self, input_list, element_shape, indices, output):
    """Gather selects the requested elements of a tensor-backed list.

    Args:
      input_list: Values used to build the list via tensor_list_from_tensor.
      element_shape: Element shape passed to tensor_list_from_tensor.
      indices: List positions to gather.
      output: Expected gathered values.
    """
    with self.session(), self.test_scope():
      tensor_list = list_ops.tensor_list_from_tensor(
          input_list, element_shape=element_shape)
      gather_t = list_ops.tensor_list_gather(
          tensor_list, indices, element_dtype=dtypes.float32)
      self.assertAllEqual(gather_t, output)

  def testStackWithUninitializedTensors(self):
    """Stacking a freshly reserved list yields zeros."""
    with self.session(), self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=[], num_elements=3)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [0., 0., 0.])

  def testZerosLikeForTensorList(self):
    """zeros_like on a TensorList produces a list that stacks to zeros."""
    with self.session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[],
          max_num_elements=2)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
      z = array_ops.zeros_like(l)
      z = list_ops.tensor_list_stack(z, element_dtype=dtypes.float32)
      self.assertAllEqual(z.shape.as_list(), [None])
      # Two zeros are expected after a single push. This presumably reflects
      # max_num_elements=2 sizing the underlying buffer -- confirm.
      self.assertAllEqual(z, [0.0, 0.0])

  def testInvalidSplitLength(self):
    """A zero entry in `lengths` is rejected as unimplemented."""
    with self.session(), self.test_scope():
      tensor_list_split = list_ops.tensor_list_split(
          tensor=[1], element_shape=[-1], lengths=[0]
      )
      with self.assertRaisesRegex(
          errors.UnimplementedError, "All lengths must be positive"
      ):
        self.evaluate(tensor_list_split)
if __name__ == "__main__":
  # Put our minimum-cluster-size flag in front of any TF_XLA_FLAGS the caller
  # already exported, keeping theirs intact. Presumably this lets XLA
  # auto-clustering form clusters from as few as two ops -- confirm against
  # the TF_XLA_FLAGS documentation.
  inherited_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
  os.environ["TF_XLA_FLAGS"] = (
      "--tf_xla_min_cluster_size=2 " + inherited_xla_flags)
  test.main()