| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for reading and writing variables.""" |
| import re |
| |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_state_ops |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import googletest |
| from tensorflow.python.training.gradient_descent import GradientDescentOptimizer |
| |
| |
class VariableOpsTest(xla_test.XLATestCase):
  """Test cases for resource variable operators."""

  def _scatterAndRead(self, scatter_op, dtype, initial, indices, updates):
    """Applies one scatter op to a fresh resource variable and reads it back.

    Args:
      scatter_op: Callable invoked as `scatter_op(handle, indices, updates)`
        that returns the scatter operation to run.
      dtype: Element type used for the variable, `initial` and `updates`.
      initial: Nested list assigned to the variable before scattering; its
        shape is also used as the shape of the variable handle.
      indices: Nested list of indices, converted to an int32 constant.
      updates: Scalar or nested list of update values.

    Returns:
      The value of the variable after the scatter op has run.
    """
    with self.session() as sess, self.test_scope():
      handle = resource_variable_ops.var_handle_op(
          dtype=dtype, shape=list(np.shape(initial)))
      sess.run(
          resource_variable_ops.assign_variable_op(
              handle, constant_op.constant(initial, dtype=dtype)))
      sess.run(
          scatter_op(handle,
                     constant_op.constant(indices, dtype=dtypes.int32),
                     constant_op.constant(updates, dtype=dtype)))
      read = resource_variable_ops.read_variable_op(handle, dtype=dtype)
      return self.evaluate(read)

  def testWriteEmptyShape(self):
    """Assigns and reads back a variable whose shape has a zero dimension."""
    # Verifies that we can pass an uninitialized variable with an empty shape,
    # assign it a value, and successfully return it.
    for dtype in self.numeric_types:
      with self.session() as sess, self.test_scope():
        zeros = np.zeros([3, 0], dtype=dtype)
        v = resource_variable_ops.ResourceVariable(zeros)
        p = array_ops.placeholder(dtype)
        x = v.assign(p)
        # Order the read after the assignment.
        with ops.control_dependencies([x]):
          y = v.read_value()
        self.assertAllClose(zeros, sess.run(y, {p: zeros}))

  def testOneWriteOneOutput(self):
    """Runs a computation with one variable update and one fetched output."""
    # Regression test for a bug where computations with one non-constant
    # output and one variable update were mishandled.
    for dtype in self.numeric_types:
      init = np.array([[1, 2j], [3, 4]]).astype(dtype)
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        sess.run(variables.variables_initializer([v]))
        p = array_ops.placeholder(dtype)
        x = v.assign_add(p)
        # Order the read after the update.
        with ops.control_dependencies([x]):
          y = v.read_value()
        self.assertAllClose(
            np.array([[2, 1 + 2j], [4, 5]]).astype(dtype),
            sess.run(y, {p: [[1, 1], [1, 1]]}))

  def testSparseRead0DIndices(self):
    """Checks `sparse_read` with a scalar index on a rank-2 variable."""
    for dtype in self.numeric_types:
      init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
                                                    11]]).astype(dtype)
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        sess.run(variables.variables_initializer([v]))
        x = v.sparse_read(2)
        self.assertAllClose(
            np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x))

  def testSparseRead1DIndices(self):
    """Checks `sparse_read` with a vector of indices on a rank-2 variable."""
    for dtype in self.numeric_types:
      init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
                                                     11]]).astype(dtype)
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        sess.run(variables.variables_initializer([v]))
        x = v.sparse_read([2, 1])
        self.assertAllClose(
            np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
            self.evaluate(x))

  def testSparseRead2DIndices(self):
    """Checks `sparse_read` with a matrix of indices on a rank-2 variable."""
    for dtype in self.numeric_types:
      init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
                                                     11]]).astype(dtype)
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        sess.run(variables.variables_initializer([v]))
        x = v.sparse_read([[2, 1], [0, 2]])
        self.assertAllClose(
            np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
                      [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
            self.evaluate(x))

  def testSparseRead2DIndices3DTensor(self):
    """Checks `sparse_read` with a matrix of indices on a rank-3 variable."""
    for dtype in self.numeric_types:
      init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
                       [[20, 21, 22], [23, 24j, 25]],
                       [[30, 31, 32], [33, 34, 35]]]).astype(dtype)
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        sess.run(variables.variables_initializer([v]))
        x = v.sparse_read([[2, 1], [3, 0]])
        self.assertAllClose(
            np.array(
                [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
                 [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
                ],).astype(dtype), self.evaluate(x))

  def testShape(self):
    """Checks `variable_shape` for both int32 and int64 output types."""
    for dtype in self.numeric_types:
      init = np.ones([2, 3]).astype(dtype)
      with self.session() as session, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        session.run(variables.variables_initializer([v]))
        h = v.handle
        s32, s64 = session.run([
            resource_variable_ops.variable_shape(h),
            resource_variable_ops.variable_shape(h, out_type=dtypes.int64)
        ])
        self.assertEqual(s32.dtype, np.int32)
        self.assertEqual(s64.dtype, np.int64)
        self.assertAllEqual(s32, [2, 3])
        self.assertAllEqual(s64, [2, 3])

  def testInvalidShape(self):
    """Checks that a scalar update of a vector variable is rejected."""
    pattern = re.compile("shapes must be equal", re.IGNORECASE)
    # test invalid shape on assign_add in XLA
    with self.assertRaisesRegex(Exception, pattern):
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
        sess.run(variables.variables_initializer([v]))
        x = v.assign_add(1)
        sess.run(x)

    # test invalid shape raised on assign_sub in XLA
    with self.assertRaisesRegex(Exception, pattern):
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
        sess.run(variables.variables_initializer([v]))
        x = v.assign_sub(1)
        sess.run(x)

  def testReadWrite(self):
    """Tests initialization, reading, and writing a resource variable."""
    for dtype in self.numeric_types:
      with self.session() as session:
        with self.test_scope():
          with variable_scope.variable_scope("ascope", use_resource=True):
            x = variable_scope.get_variable(
                "x",
                shape=[],
                dtype=dtype,
                initializer=init_ops.constant_initializer(2))
            # Chain the reads and writes with control dependencies so they
            # run in program order: read, assign, read, add, sub, read.
            a = x.read_value()
            with ops.control_dependencies([a]):
              b = state_ops.assign(x, dtype(47))
            with ops.control_dependencies([b]):
              c = x.read_value()
            with ops.control_dependencies([c]):
              d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype))
            with ops.control_dependencies([d]):
              e = state_ops.assign_sub(x, dtype(3))
            with ops.control_dependencies([e]):
              f = x.read_value()

        session.run(variables.global_variables_initializer())
        v1, v2, v3 = session.run([a, c, f])
        self.assertAllClose(dtype(2), v1)
        self.assertAllClose(dtype(47), v2)
        # 47 + (6 + 2j) - 3
        self.assertAllClose(np.array(50 + 2j).astype(dtype), v3)

  def testTraining(self):
    """Tests a gradient descent step for a simple model."""
    with self.session() as session:
      with self.test_scope():
        with variable_scope.variable_scope("ascope", use_resource=True):
          w = variable_scope.get_variable(
              "w",
              shape=[4, 2],
              dtype=dtypes.float32,
              initializer=init_ops.constant_initializer(
                  np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)))
          b = variable_scope.get_variable(
              "b",
              shape=[2],
              dtype=dtypes.float32,
              initializer=init_ops.constant_initializer(
                  np.array([2, 3], dtype=np.float32)))

          x = array_ops.placeholder(dtypes.float32, shape=[1, 4])
          y = math_ops.matmul(x, w) + b
          loss = math_ops.reduce_sum(y)
          optimizer = GradientDescentOptimizer(0.1)
          train = optimizer.minimize(loss)

      session.run(variables.global_variables_initializer())
      session.run(train, {x: np.array([[7, 3, 5, 9]], dtype=np.float32)})
      vw, vb = session.run([w, b])
      # One step with learning rate 0.1: row i of w decreases by 0.1 * x[i]
      # and every entry of b decreases by 0.1.
      self.assertAllClose(
          np.array(
              [[0.3, 1.3], [2.7, 3.7], [4.5, 5.5], [6.1, 7.1]],
              dtype=np.float32),
          vw,
          rtol=1e-4)
      self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)

  def testWriteOfAliasedTensor(self):
    """Overwrites a variable while its previous value is still fetched."""
    for dtype in self.numeric_types:
      init = np.array([[1, 2j], [3, 4]]).astype(dtype)
      update = np.array([[7, 1j], [2, 11]]).astype(dtype)
      with self.session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable(init)
        sess.run(variables.variables_initializer([v]))
        p = array_ops.placeholder(dtype)
        q = array_ops.identity(p)
        x = v.read_value()
        # Writes the value of 'p' to 'v', but keeps a reference to the original
        # value of 'v' so the variable update cannot reuse its buffer.
        with ops.control_dependencies([x]):
          y = v.assign(q)
        result = sess.run([x, y, q], {p: update})
        self.assertAllClose(init, result[0])
        self.assertAllClose(update, result[1])
        self.assertAllClose(update, result[2])

  def testScatterAdd(self):
    """Adds a rank-2 update to row 0 of a [2, 1] variable."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_add, dtypes.int32,
        [[1], [7]], [0], [[2]])
    self.assertAllEqual(result, [[3], [7]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterSub")
  def testScatterSub(self):
    """Subtracts a rank-2 update from row 1 of a [2, 1] variable."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_sub, dtypes.int32,
        [[4], [1]], [1], [[2]])
    self.assertAllEqual(result, [[4], [-1]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterMul")
  def testScatterMul(self):
    """Multiplies the only row of a [1, 1] variable by a rank-2 update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_mul, dtypes.int32,
        [[1]], [0], [[5]])
    self.assertAllEqual(result, [[5]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterDiv")
  def testScatterDiv(self):
    """Divides the only row of a [1, 1] variable by a rank-2 update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_div, dtypes.int32,
        [[6]], [0], [[3]])
    self.assertAllEqual(result, [[2]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterMin")
  def testScatterMin(self):
    """Takes the minimum of a [1, 1] variable and a rank-2 update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_min, dtypes.int32,
        [[6]], [0], [[3]])
    self.assertAllEqual(result, [[3]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterMax")
  def testScatterMax(self):
    """Takes the maximum of a [1, 1] variable and a rank-2 update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_max, dtypes.int32,
        [[6]], [0], [[3]])
    self.assertAllEqual(result, [[6]])

  def testScatterUpdate(self):
    """Overwrites the only row of a [1, 1] variable with a rank-2 update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_update, dtypes.int32,
        [[6]], [0], [[3]])
    self.assertAllEqual(result, [[3]])

  def testScatterScalarUpdate(self):
    """Overwrites the only row of a [1, 1] variable with a scalar update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_update, dtypes.int32,
        [[6]], [0], 3)
    self.assertAllEqual(result, [[3]])

  def testScatterAddScalarUpdate(self):
    """Adds a scalar update to the only row of a [1, 1] variable."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_add, dtypes.int32,
        [[1]], [0], 2)
    self.assertAllEqual(result, [[3]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterSub")
  def testScatterSubScalar(self):
    """Subtracts a scalar update from the only row of a [1, 1] variable."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_sub, dtypes.int32,
        [[1]], [0], 2)
    self.assertAllEqual(result, [[-1]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterMul")
  def testScatterMulScalar(self):
    """Multiplies the only row of a [1, 1] variable by a scalar update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_mul, dtypes.int32,
        [[1]], [0], 5)
    self.assertAllEqual(result, [[5]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterDiv")
  def testScatterDivScalar(self):
    """Divides the only row of a [1, 1] variable by a scalar update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_div, dtypes.int32,
        [[6]], [0], 3)
    self.assertAllEqual(result, [[2]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterMin")
  def testScatterMinScalar(self):
    """Takes the minimum of a [1, 1] variable and a scalar update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_min, dtypes.int32,
        [[6]], [0], 3)
    self.assertAllEqual(result, [[3]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterMax")
  def testScatterMaxScalar(self):
    """Takes the maximum of a [1, 1] variable and a scalar update."""
    result = self._scatterAndRead(
        resource_variable_ops.resource_scatter_max, dtypes.int32,
        [[6]], [0], 3)
    self.assertAllEqual(result, [[6]])

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterNdAdd")
  def testScatterNdAddOps(self):
    """Adds four updates to distinct positions of a length-8 vector of ones."""
    expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
    result = self._scatterAndRead(
        gen_state_ops.resource_scatter_nd_add, dtypes.float32,
        [1] * 8, [[4], [3], [1], [7]], [9, 10, 11, 12])
    self.assertAllClose(expected, result)

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceScatterNdUpdate")
  def testScatterNdUpdateAddOps(self):
    """Overwrites four distinct positions of a length-8 vector of ones."""
    expected = np.array([1, 11, 1, 10, 9, 1, 1, 12])
    result = self._scatterAndRead(
        gen_state_ops.resource_scatter_nd_update, dtypes.float32,
        [1] * 8, [[4], [3], [1], [7]], [9, 10, 11, 12])
    self.assertAllClose(expected, result)
| |
| |
class StridedSliceAssignChecker(object):
  """Checks a TensorFlow slice assignment against the same numpy assignment.

  Each `checker[index] = value` statement builds a resource variable holding
  the reference array, assigns `value` to `var[index]`, and asserts that the
  result equals what numpy produces for the identical indexing expression.
  """

  def __init__(self, test, x, dtype):
    """Stores the test case, the element type and the reference array.

    Args:
      test: Test case supplying `session`, `test_scope` and `assertAllEqual`.
      x: Initial value of the variable; converted to a numpy array of `dtype`.
      dtype: Numpy dtype used for the variable and the assigned values.
    """
    self.test = test
    self.dtype = dtype
    self.x_np = np.array(x).astype(dtype)
    # Randomly start on mode 0 or 1.
    self.which_mode = np.random.randint(2, size=1)[0]

  def __setitem__(self, index, value):
    """Assigns `value` at `index` through TensorFlow and compares with numpy."""
    # Flip the mode so consecutive assignments alternate between the two
    # spellings of the slice assignment.
    self.which_mode = 1 - self.which_mode
    value = np.array(value).astype(self.dtype)

    with self.test.session() as sess, self.test.test_scope():
      initial = constant_op.constant(self.x_np, dtype=self.dtype)
      var = resource_variable_ops.ResourceVariable(initial)
      sess.run(variables.variables_initializer([var]))

      if self.which_mode == 0:
        assign_op = var[index].assign(value)
      else:
        assert self.which_mode == 1
        assign_op = state_ops.assign(var[index], value)
      actual = sess.run(assign_op)

      # Build the expected result by applying the same assignment in numpy.
      expected = np.copy(self.x_np)
      expected[index] = np.array(value)
      self.test.assertAllEqual(actual, expected)
| |
| |
class SliceAssignTest(xla_test.XLATestCase):
  """Tests strided-slice assignment to resource variables."""

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support ResourceStridedSliceAssign")
  def testSliceAssign(self):
    """Compares a range of slice assignments against numpy for every dtype."""
    for dtype in self.numeric_types:
      checker = StridedSliceAssignChecker(
          self, [[1, 2, 3], [4, 5, 6]], dtype=dtype)
      # Full-slice assignment that overwrites every element
      checker[:] = [[10, 20, 30], [40, 50, 60]]
      # Checks trivial (1,1) shape tensor
      checker[1:2, 1:2] = [[66]]
      # shrink shape changes
      checker[1:2, 1] = [66]
      checker[1, 1:2] = [66]
      if dtype != dtypes.bfloat16.as_numpy_dtype:
        # TODO(b/68813416): valnp call above results in an ndarray and not a
        # number for bfloat16s.
        checker[1, 1] = 66
      # newaxis shape changes
      checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
      # shrink and newaxis
      checker[None, None, 0, 0:1] = [[[99]]]
      # Non unit strides
      checker[::1, 1::-1] = [[3, 33], [4, 44]]
      # degenerate interval
      checker[8:10, 0] = []
      checker[8:10, 8:10] = [[]]

      # Assign vector to scalar (rank-0) using newaxis
      checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
      if dtype != dtypes.bfloat16.as_numpy_dtype:
        # TODO(b/68813416): valnp call above results in an ndarray and not a
        # number for bfloat16s.
        checker2[()] = 6  # no indices
        checker2[...] = 6  # ellipsis
      checker2[None] = [6]  # new axis

  @test_util.disable_mlir_bridge(
      "TODO: MLIR bridge does not yet support uninitialized resource variable")
  def testUninitialized(self):
    """Checks that slice-assigning to an uninitialized variable fails."""
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "uninitialized"):
      with self.session() as sess, self.test_scope():
        # The variable's initializer is deliberately never run.
        v = resource_variable_ops.ResourceVariable([1, 2])
        sess.run(v[:].assign([1, 2]))
| |
| |
# Run the tests in this file when it is executed as a script.
if __name__ == "__main__":
  googletest.main()