# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reading and writing variables."""
import re
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer

class VariableOpsTest(xla_test.XLATestCase):
"""Test cases for resource variable operators."""
def testWriteEmptyShape(self):
# Verifies that we can take an uninitialized variable with a zero-element
# shape ([3, 0]), assign it a value, and successfully read it back.
for dtype in self.numeric_types:
with self.session() as sess, self.test_scope():
zeros = np.zeros([3, 0], dtype=dtype)
v = resource_variable_ops.ResourceVariable(zeros)
p = array_ops.placeholder(dtype)
x = v.assign(p)
with ops.control_dependencies([x]):
y = v.read_value()
self.assertAllClose(zeros, sess.run(y, {p: zeros}))
def testOneWriteOneOutput(self):
# Regression test for a bug where computations with one non-constant
# output and one variable update were mishandled.
for dtype in self.numeric_types:
init = np.array([[1, 2j], [3, 4]]).astype(dtype)
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
p = array_ops.placeholder(dtype)
x = v.assign_add(p)
with ops.control_dependencies([x]):
y = v.read_value()
self.assertAllClose(
np.array([[2, 1 + 2j], [4, 5]]).astype(dtype),
sess.run(y, {p: [[1, 1], [1, 1]]}))
def testSparseRead0DIndices(self):
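# sparse_read with a scalar (0-D) index gathers a single row along axis 0.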
for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
11]]).astype(dtype)
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read(2)
self.assertAllClose(
np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x))
def testSparseRead1DIndices(self):
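# A 1-D index vector gathers one row per index, in the order given.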
for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
11]]).astype(dtype)
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read([2, 1])
self.assertAllClose(
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
self.evaluate(x))
def testSparseRead2DIndices(self):
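# 2-D indices of shape [2, 2] gather into a result of shape [2, 2, 4]:
# indices.shape followed by the variable's trailing row shape.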
for dtype in self.numeric_types:
init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
11]]).astype(dtype)
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read([[2, 1], [0, 2]])
self.assertAllClose(
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
self.evaluate(x))
def testSparseRead2DIndices3DTensor(self):
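# Gathering from a [4, 2, 3] variable with [2, 2] indices yields a
# [2, 2, 2, 3] result: indices.shape plus the variable's trailing dimensions.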
for dtype in self.numeric_types:
init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
[[20, 21, 22], [23, 24j, 25]],
[[30, 31, 32], [33, 34, 35]]]).astype(dtype)
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read([[2, 1], [3, 0]])
self.assertAllClose(
np.array(
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
]).astype(dtype), self.evaluate(x))
def testShape(self):
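# variable_shape reads the shape from the variable's handle; out_type
# selects the integer dtype of the result (int32 by default, or int64).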
for dtype in self.numeric_types:
init = np.ones([2, 3]).astype(dtype)
with self.session() as session, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
session.run(variables.variables_initializer([v]))
h = v.handle
s32, s64 = session.run([
resource_variable_ops.variable_shape(h),
resource_variable_ops.variable_shape(h, out_type=dtypes.int64)
])
self.assertEqual(s32.dtype, np.int32)
self.assertEqual(s64.dtype, np.int64)
self.assertAllEqual(s32, [2, 3])
self.assertAllEqual(s64, [2, 3])
def testInvalidShape(self):
pattern = re.compile("shapes must be equal", re.IGNORECASE)
# Test that a shape mismatch raises on assign_add in XLA.
with self.assertRaisesRegex(Exception, pattern):
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
sess.run(variables.variables_initializer([v]))
x = v.assign_add(1)
sess.run(x)
# Test that a shape mismatch raises on assign_sub in XLA.
with self.assertRaisesRegex(Exception, pattern):
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
sess.run(variables.variables_initializer([v]))
x = v.assign_sub(1)
sess.run(x)
def testReadWrite(self):
"""Tests initialization, reading, and writing a resource variable."""
for dtype in self.numeric_types:
with self.session() as session:
with self.test_scope():
with variable_scope.variable_scope("ascope", use_resource=True):
x = variable_scope.get_variable(
"x",
shape=[],
dtype=dtype,
initializer=init_ops.constant_initializer(2))
a = x.read_value()
with ops.control_dependencies([a]):
b = state_ops.assign(x, dtype(47))
with ops.control_dependencies([b]):
c = x.read_value()
with ops.control_dependencies([c]):
d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype))
with ops.control_dependencies([d]):
e = state_ops.assign_sub(x, dtype(3))
with ops.control_dependencies([e]):
f = x.read_value()
session.run(variables.global_variables_initializer())
v1, v2, v3 = session.run([a, c, f])
self.assertAllClose(dtype(2), v1)
self.assertAllClose(dtype(47), v2)
self.assertAllClose(np.array(50 + 2j).astype(dtype), v3)
def testTraining(self):
"""Tests a gradient descent step for a simple model."""
with self.session() as session:
with self.test_scope():
with variable_scope.variable_scope("ascope", use_resource=True):
w = variable_scope.get_variable(
"w",
shape=[4, 2],
dtype=dtypes.float32,
initializer=init_ops.constant_initializer(
np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)))
b = variable_scope.get_variable(
"b",
shape=[2],
dtype=dtypes.float32,
initializer=init_ops.constant_initializer(
np.array([2, 3], dtype=np.float32)))
x = array_ops.placeholder(dtypes.float32, shape=[1, 4])
y = math_ops.matmul(x, w) + b
loss = math_ops.reduce_sum(y)
optimizer = GradientDescentOptimizer(0.1)
train = optimizer.minimize(loss)
session.run(variables.global_variables_initializer())
session.run(train, {x: np.array([[7, 3, 5, 9]], dtype=np.float32)})
vw, vb = session.run([w, b])
self.assertAllClose(
np.array(
[[0.3, 1.3], [2.7, 3.7], [4.5, 5.5], [6.1, 7.1]],
dtype=np.float32),
vw,
rtol=1e-4)
self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
def testWriteOfAliasedTensor(self):
for dtype in self.numeric_types:
init = np.array([[1, 2j], [3, 4]]).astype(dtype)
update = np.array([[7, 1j], [2, 11]]).astype(dtype)
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
p = array_ops.placeholder(dtype)
q = array_ops.identity(p)
x = v.read_value()
# Writes the value of 'p' to 'v', but keeps a reference to the original
# value of 'v' so the variable update cannot reuse its buffer.
with ops.control_dependencies([x]):
y = v.assign(q)
result = sess.run([x, y, q], {p: update})
self.assertAllClose(init, result[0])
self.assertAllClose(update, result[1])
self.assertAllClose(update, result[2])
def testScatterAdd(self):
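# resource_scatter_add adds each update row to the variable at the given
# index: row 0 becomes 1 + 2 = 3, row 1 stays 7.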
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[2, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1], [7]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(read), [[3], [7]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterSub")
def testScatterSub(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[2, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[4], [1]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_sub(
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterMul")
def testScatterMul(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterDiv")
def testScatterDiv(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(read), [[2]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterMin")
def testScatterMin(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterMax")
def testScatterMax(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[6]])
def testScatterUpdate(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_update(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
def testScatterScalarUpdate(self):
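# A scalar update is broadcast across the indexed row: row 0 becomes [3].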
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_update(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
def testScatterAddScalarUpdate(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterSub")
def testScatterSubScalar(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_sub(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[-1]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterMul")
def testScatterMulScalar(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterDiv")
def testScatterDivScalar(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterMin")
def testScatterMinScalar(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterMax")
def testScatterMaxScalar(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
sess.run(
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[6]])
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterNdAdd")
def testScatterNdAddOps(self):
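# scatter_nd_add applies index/update pairs to a vector of ones:
# 1 + 11 at index 1, 1 + 10 at 3, 1 + 9 at 4, and 1 + 12 at 7.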
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.float32, shape=[8])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([1] * 8, dtype=dtypes.float32)))
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
self.assertAllClose(expected, self.evaluate(read))
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceScatterNdUpdate")
def testScatterNdUpdateAddOps(self):
with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.float32, shape=[8])
sess.run(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([1] * 8, dtype=dtypes.float32)))
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
expected = np.array([1, 11, 1, 10, 9, 1, 1, 12])
sess.run(
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
self.assertAllClose(expected, self.evaluate(read))

class StridedSliceAssignChecker(object):
"""Compares the results of a slice assignment using Tensorflow and numpy."""
def __init__(self, test, x, dtype):
self.dtype = dtype
self.test = test
self.x_np = np.array(x).astype(dtype)
# Randomly start on mode 0 or 1.
self.which_mode = np.random.randint(2, size=1)[0]
def __setitem__(self, index, value):
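# Alternate between the two assignment code paths on each call so that both
# var[index].assign(value) and state_ops.assign(var[index], value) are
# exercised, then compare against the equivalent NumPy slice assignment.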
self.which_mode = 1 - self.which_mode
value = np.array(value).astype(self.dtype)
with self.test.session() as sess, self.test.test_scope():
x = constant_op.constant(self.x_np, dtype=self.dtype)
var = resource_variable_ops.ResourceVariable(x)
sess.run(variables.variables_initializer([var]))
if self.which_mode == 0:
val = sess.run(var[index].assign(value))
else:
assert self.which_mode == 1
val = sess.run(state_ops.assign(var[index], value))
valnp = np.copy(self.x_np)
valnp[index] = np.array(value)
self.test.assertAllEqual(val, valnp)

class SliceAssignTest(xla_test.XLATestCase):
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support ResourceStridedSliceAssign")
def testSliceAssign(self):
for dtype in self.numeric_types:
checker = StridedSliceAssignChecker(
self, [[1, 2, 3], [4, 5, 6]], dtype=dtype)
# Assignment through a trivial slice that covers the whole tensor
checker[:] = [[10, 20, 30], [40, 50, 60]]
# Checks assignment of a trivial (1, 1) shape tensor
checker[1:2, 1:2] = [[66]]
# shrink shape changes
checker[1:2, 1] = [66]
checker[1, 1:2] = [66]
if dtype != dtypes.bfloat16.as_numpy_dtype:
# TODO(b/68813416): valnp call above results in an ndarray and not a
# number for bfloat16s.
checker[1, 1] = 66
# newaxis shape changes
checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
# shrink and newaxis
checker[None, None, 0, 0:1] = [[[99]]]
# Non-unit strides
checker[::1, 1::-1] = [[3, 33], [4, 44]]
# degenerate interval
checker[8:10, 0] = []
checker[8:10, 8:10] = [[]]
# Assign vector to scalar (rank-0) using newaxis
checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
if dtype != dtypes.bfloat16.as_numpy_dtype:
# TODO(b/68813416): valnp call above results in an ndarray and not a
# number for bfloat16s.
checker2[()] = 6 # no indices
checker2[...] = 6 # ellipsis
checker2[None] = [6] # new axis
@test_util.disable_mlir_bridge("TODO: MLIR bridge does not yet "
"support uninitialized resource variable")
def testUninitialized(self):
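# Assigning through a slice of a variable whose initializer never ran
# should fail with FailedPreconditionError.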
with self.assertRaisesRegex(errors.FailedPreconditionError,
"uninitialized"):
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable([1, 2])
sess.run(v[:].assign([1, 2]))

if __name__ == "__main__":
googletest.main()