# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class BatchNormalizationTest(test.TestCase):
def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
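    # NumPy reference: y = (x - m) / sqrt(v + epsilon), then optionally scaled
    # by gamma and shifted by beta depending on the two flags.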
y = (x - m) / np.sqrt(v + epsilon)
y = y * gamma if scale_after_normalization else y
return y + beta if shift_after_normalization else y
def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
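    # Same formula as _npBatchNorm, built from individual TF ops; this is the
    # "composed ops" baseline used by the tests below.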
y = (x - m) * math_ops.rsqrt(v + epsilon)
if scale_after_normalization:
y = gamma * y
return y + beta if shift_after_normalization else y
def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization):
"""Original implementation."""
test_util.set_producer_version(ops.get_default_graph(), 8)
return gen_nn_ops._batch_norm_with_global_normalization(
x, m, v, beta, gamma, epsilon, scale_after_normalization)
def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization):
"""Re-implementation of the original kernel for backward compatibility."""
return nn_impl.batch_norm_with_global_normalization(
x, m, v, beta, gamma, epsilon, scale_after_normalization)
def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
"""New implementation."""
    return nn_impl.batch_normalization(
        x, m, v,
        beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None,
        epsilon)
@test_util.run_deprecated_v1
def testBatchNorm(self):
x_shape = [3, 5, 4, 2]
param_shape = [2]
x_val = np.random.random_sample(x_shape).astype(np.float32)
m_val = np.random.random_sample(param_shape).astype(np.float32)
v_val = np.random.random_sample(param_shape).astype(np.float32)
beta_val = np.random.random_sample(param_shape).astype(np.float32)
gamma_val = np.random.random_sample(param_shape).astype(np.float32)
for use_gpu in [True, False]:
with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name="x")
m = constant_op.constant(m_val, name="m")
v = constant_op.constant(v_val, name="v")
beta = constant_op.constant(beta_val, name="beta")
gamma = constant_op.constant(gamma_val, name="gamma")
epsilon = 0.001
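        # Compare the v2 API, the backward-compatible wrapper, the original
        # fused kernel and the composed-ops graph against the NumPy reference
        # for every combination of the scale/shift flags.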
for scale_after_normalization in [True, False]:
for shift_after_normalization in [True, False]:
bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
scale_after_normalization,
shift_after_normalization)
bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
scale_after_normalization)
bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
scale_after_normalization)
on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
scale_after_normalization,
shift_after_normalization)
np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
epsilon, scale_after_normalization,
shift_after_normalization)
tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
[bn2, bn1bw, bn1, on])
self.assertAllClose(np_bn, ops_bn, atol=0.00001)
self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
# shift_after_normalization=False is not supported in v1.
if shift_after_normalization:
self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)
def _testBatchNormGradient(self,
param_index,
tag,
scale_after_normalization,
shift_after_normalization,
version,
err_tolerance=1e-11):
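    # Compares the symbolic gradient of `output` with respect to the input
    # selected by `param_index` (from [x, m, v, beta, gamma]) against a
    # numeric finite-difference estimate from gradient_checker.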
x_shape = [3, 5, 4, 5]
param_shape = [5]
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
m_val = np.random.random_sample(param_shape).astype(np.float64)
v_val = np.random.random_sample(param_shape).astype(np.float64)
beta_val = np.random.random_sample(param_shape).astype(np.float64)
gamma_val = np.random.random_sample(param_shape).astype(np.float64)
with self.cached_session():
x = constant_op.constant(x_val, name="x")
m = constant_op.constant(m_val, name="m")
v = constant_op.constant(v_val, name="v")
beta = constant_op.constant(beta_val, name="beta")
gamma = constant_op.constant(gamma_val, name="gamma")
epsilon = 0.001
if version == 1:
output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
scale_after_normalization)
elif version == 2:
output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
scale_after_normalization,
shift_after_normalization)
      else:
        raise ValueError("Invalid batch normalization version: %r" % version)
all_params = [x, m, v, beta, gamma]
all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
err = gradient_checker.compute_gradient_error(all_params[param_index],
all_shapes[param_index],
output, x_shape)
print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
(version, tag, "with" if scale_after_normalization else "without",
"with" if shift_after_normalization else "without"), err)
self.assertLess(err, err_tolerance)
def _testBatchNormGradientInAllNeedConfigs(self,
param_index,
tag,
err_tolerance=1e-11):
for scale_after_normalization in [True, False]:
for shift_after_normalization in [True, False]:
# shift_after_normalization=False is not supported in version 1.
for v in ([1, 2] if shift_after_normalization else [2]):
self._testBatchNormGradient(param_index, tag,
scale_after_normalization,
shift_after_normalization, v,
err_tolerance)
@test_util.run_deprecated_v1
def testBatchNormInputGradient(self):
self._testBatchNormGradientInAllNeedConfigs(0, "x")
@test_util.run_deprecated_v1
def testBatchNormMeanGradient(self):
self._testBatchNormGradientInAllNeedConfigs(1, "mean")
@test_util.run_deprecated_v1
def testBatchNormVarianceGradient(self):
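    # A looser tolerance is used here, presumably because the gradient with
    # respect to the variance flows through rsqrt(v + epsilon), which makes
    # the numeric estimate less accurate.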
self._testBatchNormGradientInAllNeedConfigs(
2, "variance", err_tolerance=1e-03)
@test_util.run_deprecated_v1
def testBatchNormBetaGradient(self):
    # Beta is the shift term and is only applied when
    # shift_after_normalization=True, so that flag is fixed to True here while
    # scale_after_normalization varies.
for scale_after_normalization in [True, False]:
for v in [1, 2]:
self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
v)
@test_util.run_deprecated_v1
def testBatchNormGammaGradient(self):
# If scale_after_normalization is False, backprop for gamma in v1
# will be 0. In version 2 of the API, if scale_after_normalization is False,
# gamma is not used at all, and the gradient is None, which displeases the
# gradient checker.
for scale_after_normalization in [True, False]:
self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
1)
for shift_after_normalization in [True, False]:
self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
2)
@test_util.run_deprecated_v1
def testBatchNormGradImpl(self):
x_shape = [7, 5, 4, 6]
param_shape = [6]
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float32)
m_val = np.random.random_sample(param_shape).astype(np.float32)
v_val = np.random.random_sample(param_shape).astype(np.float32)
beta_val = np.random.random_sample(param_shape).astype(np.float32)
gamma_val = np.random.random_sample(param_shape).astype(np.float32)
backprop_val = np.random.random_sample(x_shape).astype(np.float32)
for use_gpu in [False, True]:
with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name="x")
m = constant_op.constant(m_val, name="m")
v = constant_op.constant(v_val, name="v")
beta = constant_op.constant(beta_val, name="beta")
gamma = constant_op.constant(gamma_val, name="gamma")
backprop = constant_op.constant(backprop_val, name="backprop")
epsilon = 0.001
for scale_after_normalization in [True, False]:
# _batch_norm_with_global_normalization_grad is deprecated in v9
test_util.set_producer_version(ops.get_default_graph(), 8)
grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
x, m, v, gamma, backprop, epsilon, scale_after_normalization)
dx, dm, dv, db, dg = grad
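          # The generated grad op's outputs are accessible both by position
          # and by the named fields checked below.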
self.assertEqual(grad.dx, dx)
self.assertEqual(grad.dm, dm)
self.assertEqual(grad.dv, dv)
self.assertEqual(grad.db, db)
self.assertEqual(grad.dg, dg)
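          # Build the same normalization from composed ops and differentiate
          # it with tf.gradients as a reference for the fused gradient kernel.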
on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
scale_after_normalization, True)
odx, odm, odv, odb, odg = gradients_impl.gradients(
[on], [x, m, v, beta, gamma], [backprop])
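          # When scale_after_normalization is False, gamma does not feed into
          # the composed-ops output, so its gradient is left out.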
if scale_after_normalization:
all_grads = self.evaluate(
[dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
to_check = ["dx", "dm", "dv", "db", "dg"]
else:
all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
to_check = ["dx", "dm", "dv", "db"]
for i, _ in enumerate(to_check):
self.assertAllClose(
all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
@test_util.run_deprecated_v1
def testBatchNormKeepDims(self):
"""Test for tf.nn.moments(..., keep_dims=True / False).
Make sure that parameters with shape (1, 1, 1, depth) yield the same
result as parameters with shape (depth)
"""
x_shape = (3, 5, 4, 2)
    param_shape = (2,)
keep_dims_param_shape = (1, 1, 1, 2)
x_val = np.random.random_sample(x_shape).astype(np.float32)
m_val = np.random.random_sample(param_shape).astype(np.float32)
v_val = np.random.random_sample(param_shape).astype(np.float32)
beta_val = np.random.random_sample(param_shape).astype(np.float32)
gamma_val = np.random.random_sample(param_shape).astype(np.float32)
for use_gpu in [True, False]:
with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name="x")
m = constant_op.constant(m_val, name="m")
v = constant_op.constant(v_val, name="v")
beta = constant_op.constant(beta_val, name="beta")
gamma = constant_op.constant(gamma_val, name="gamma")
keep_dims_m = array_ops.reshape(
m, keep_dims_param_shape, name="keep_dims_m")
keep_dims_v = array_ops.reshape(
v, keep_dims_param_shape, name="keep_dims_v")
keep_dims_beta = array_ops.reshape(
beta, keep_dims_param_shape, name="keep_dims_beta")
keep_dims_gamma = array_ops.reshape(
gamma, keep_dims_param_shape, name="keep_dims_gamma")
epsilon = 0.001
for scale_after_normalization in [True, False]:
for shift_after_normalization in [True, False]:
bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
scale_after_normalization,
shift_after_normalization)
keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
keep_dims_beta, keep_dims_gamma,
epsilon,
scale_after_normalization,
shift_after_normalization)
tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
[bn, keep_dims_bn])
self.assertEqual(x_shape, tf_batch_norm.shape)
self.assertEqual(x_shape, keep_dims_tf_batch_norm.shape)
self.assertAllClose(
tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)
def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001,
dtype=dtypes.float32,
param_dtype=dtypes.float32):
numpy_dtype = dtype.as_numpy_dtype
numpy_param_dtype = param_dtype.as_numpy_dtype
x_val = np.random.random_sample(x_shape).astype(numpy_dtype)
m_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
v_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
beta_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
gamma_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
for use_gpu in [True, False]:
with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name="x")
m = constant_op.constant(m_val, name="m")
v = constant_op.constant(v_val, name="v")
beta = constant_op.constant(beta_val, name="beta")
gamma = constant_op.constant(gamma_val, name="gamma")
epsilon = 0.001
for scale_after_normalization in [True, False]:
for shift_after_normalization in [True, False]:
bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
scale_after_normalization,
shift_after_normalization)
np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
gamma_val, epsilon,
scale_after_normalization,
shift_after_normalization)
[tf_batch_norm] = self.evaluate([bn])
self.assertEqual(x_shape, np_batch_norm.shape)
self.assertEqual(x_shape, tf_batch_norm.shape)
self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
def testBatchNormArbitraryShapes(self):
"""Test for a variety of shapes and moments.
Batch normalization is expected to work regardless of the position and
dimensionality of the 'depth' axis/axes.
"""
self._testBatchNormArbitraryShapes((3, 3), (1, 3))
self._testBatchNormArbitraryShapes((3, 3), (3, 1))
self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
self._testBatchNormArbitraryShapes(
(2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)
def testBatchNormMixedPrecision(self):
self._testBatchNormArbitraryShapes((3, 3), (1, 3), dtype=dtypes.float16,
param_dtype=dtypes.float32, atol=0.001)
class SufficientStatisticsTest(test.TestCase):
def _npSuffStats(self, x, axes, shift, keep_dims):
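    # NumPy reference for sufficient statistics: the element count over
    # `axes`, the sum of the (optionally shifted) values, and the sum of
    # their squares. normalize_moments turns these into mean and variance.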
axis = tuple(axes)
if shift is not None:
m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
else:
m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
count = 1.0
for d in range(x.ndim):
if d in set(axes):
count *= x.shape[d]
if not keep_dims:
shift = np.asarray(shift)
return count, m_ss, v_ss, shift
def _opSuffStats(self, x, axes, shift, keep_dims):
return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)
def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
x_val = np.random.random_sample(x_shape).astype(np.float32)
np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
for use_gpu in [True, False]:
with self.cached_session(use_gpu=use_gpu) as sess:
if has_shape:
x = constant_op.constant(x_val, name="x")
x.set_shape(x_shape)
op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
if shift:
tf_c, tf_m, tf_v, tf_s = self.evaluate([op_c, op_m, op_v, op_s])
else:
tf_c, tf_m, tf_v = self.evaluate([op_c, op_m, op_v])
else:
x = array_ops.placeholder(
dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
if shift:
tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
feed_dict={x: x_val})
else:
tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
feed_dict={x: x_val})
self.assertAllClose(np_c, tf_c, atol=0.000001)
self.assertAllClose(np_m, tf_m, atol=0.000001)
self.assertAllClose(np_v, tf_v, atol=0.000001)
if shift:
self.assertAllClose(np_s, tf_s, atol=0.000001)
@test_util.run_deprecated_v1
def testSuffStats(self):
for has_shape in [True, False]:
for keep_dims in [True, False]:
for shift in [None, 1.0]:
self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
class NormalizeMomentsTest(test.TestCase):
def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
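    # Recover the moments of the original data from statistics of the shifted
    # data: E[x] = E[x - shift] + shift and
    # Var[x] = E[(x - shift)^2] - E[x - shift]^2.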
mean = mean_ss / counts
variance = variance_ss / counts - mean * mean
if shift is not None:
mean += shift
return mean, variance
def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)
def _testNormalizeMoments(self, shape, shift):
counts = np.ones([1]).astype(np.float32)
mean_ss = np.random.random_sample(shape).astype(np.float32)
variance_ss = np.random.random_sample(shape).astype(np.float32)
variance_ss *= variance_ss
if shift:
shift_v = np.random.random_sample(shape).astype(np.float32)
else:
shift_v = None
npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
for use_gpu in [True, False]:
with self.cached_session(use_gpu=use_gpu) as sess:
tf_counts = constant_op.constant(counts, name="counts")
tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
if shift:
tf_shift_v = constant_op.constant(shift_v, name="shift")
else:
tf_shift_v = None
opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
tf_variance_ss, tf_shift_v)
tfm, tfv = self.evaluate([opm, opv])
self.assertAllClose(npm, tfm, atol=0.000001)
self.assertAllClose(npv, tfv, atol=0.000001)
def testNormalizeMoments(self):
for shift in [None, 4.0]:
self._testNormalizeMoments([3], shift)
self._testNormalizeMoments([2, 3], shift)
class MomentsTest(test.TestCase):
def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
# Method to compute moments of `x` wrt `axes`.
#
# This is exposed so WeightedMomentsTest can inherit the tests and
# assertions from MomentsTest; the extra_out_grads argument allows
# its inherited gradient tests to assert gradients against the
# weights as well as the input values.
return nn_impl.moments(x, axes, keep_dims=keep_dims)
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
with self.cached_session():
# shape = [batch, width, height, depth]
assert len(shape) == 4
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = array_ops.placeholder(dtype, shape=[None] * len(shape))
mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
num_elements = np.prod([shape[i] for i in axes])
ax = tuple(axes)
expected_mean = np.sum(x_numpy, axis=ax,
keepdims=keep_dims) / num_elements
expected_mean_squared = np.multiply(expected_mean, expected_mean)
expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
axis=ax,
keepdims=keep_dims) / num_elements
expected_variance = expected_x_squared - expected_mean_squared
# Check that the moments are correct.
self.assertAllCloseAccordingToType(
expected_mean, mean.eval(feed_dict={x: x_numpy}))
self.assertAllCloseAccordingToType(
expected_variance, var.eval(feed_dict={x: x_numpy}))
def RunMomentTest(self, shape, axes, keep_dims, dtype):
with self.cached_session():
# shape = [batch, width, height, depth]
assert len(shape) == 4
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)
# Compute the expected values at high precision since the method
# is prone to catastrophic cancellation:
x_numpy = x_numpy.astype(np.float128)
mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
num_elements = np.prod([shape[i] for i in axes])
ax = tuple(axes)
expected_mean = np.sum(x_numpy, axis=ax,
keepdims=keep_dims) / num_elements
expected_mean_squared = np.multiply(expected_mean, expected_mean)
expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
axis=ax,
keepdims=keep_dims) / num_elements
expected_variance = expected_x_squared - expected_mean_squared
# Check that the moments are correct.
self.assertAllCloseAccordingToType(expected_mean, self.evaluate(mean))
self.assertAllCloseAccordingToType(expected_variance, self.evaluate(var))
@test_util.run_deprecated_v1
def testBasic(self):
for keep_dims in [False, True]:
for dtype in [dtypes.float32, dtypes.float16]:
self.RunMomentTest(
shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
self.RunMomentTestWithDynamicShape(
shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
@test_util.run_deprecated_v1
def testGlobalNormalization(self):
for keep_dims in [False, True]:
for dtype in [dtypes.float32, dtypes.float16]:
self.RunMomentTest(
shape=[2, 3, 5, 4],
axes=[0, 1, 2],
keep_dims=keep_dims,
dtype=dtype)
self.RunMomentTestWithDynamicShape(
shape=[2, 3, 5, 4],
axes=[0, 1, 2],
keep_dims=keep_dims,
dtype=dtype)
@test_util.run_deprecated_v1
def testAxes(self):
for keep_dims in [False, True]:
for dtype in [dtypes.float32, dtypes.float16]:
self.RunMomentTest(
shape=[2, 3, 5, 4],
axes=[1, 2, 3],
keep_dims=keep_dims,
dtype=dtype)
self.RunMomentTestWithDynamicShape(
shape=[2, 3, 5, 4],
axes=[1, 2, 3],
keep_dims=keep_dims,
dtype=dtype)
def _testGlobalGradient(self, from_y="mean"):
with self.cached_session():
x_shape = [3, 5, 4, 2]
x_val = np.random.random_sample(x_shape).astype(np.float64)
x = constant_op.constant(x_val)
x.set_shape(x_shape)
axes = [0, 1, 2]
y_shape = [2] # Depth of x
inputs_to_compute_gradients_for = [x]
out_mean, out_var = self._unweighted_moments(
x, axes, extra_out_grads=inputs_to_compute_gradients_for)
if from_y == "mean":
y = out_mean
elif from_y == "var":
y = out_var
for (i, v) in enumerate(inputs_to_compute_gradients_for):
err = gradient_checker.compute_gradient_error(v,
v.get_shape().as_list(),
y, y_shape)
print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
self.assertLess(err, 1e-11)
@test_util.run_deprecated_v1
def testMeanGlobalGradient(self):
self._testGlobalGradient(from_y="mean")
@test_util.run_deprecated_v1
def testVarGlobalGradient(self):
self._testGlobalGradient(from_y="var")
class WeightedMomentsTest(MomentsTest):
"""Tests for nn.weighted_moments.
Note that this test inherits from MomentsTest, inheriting all its
test methods!
It modifies MomentsTest in two ways:
a) By overriding _unweighted_moments, all the codepaths in
MomentsTest are executed, but with calls to tf.nn.moments()
replaced by calls to tf.nn.weighted_moments() with a constant
weight of 1.
b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
this test adds multiple additional calls to
RunWeightedMomentsTest() to exercise correctness with
non-constant weights and varying broadcasting situations. (It
also continues to call MomentsTest.Run(Weighted)?MomentsTest as
well.)
"""
def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
weights = constant_op.constant(1, dtype=x.dtype)
if extra_out_grads is not None:
# We want to assert gradients WRT weights as well as X!
extra_out_grads.append(weights)
return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)
def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
if not dynshapes:
super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
dtype)
else:
super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(shape,
axes,
keep_dims,
dtype)
# 1:1 weights and inputs
self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)
# Various broadcasting combinations
for idx in range(len(shape)):
# try broadcasting weights in all positions
weight_shape = [1] * len(shape)
weight_shape[idx] = shape[idx]
self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
# Also try broadcasting with a suffix of length n
weight_shape = shape[-(idx + 1):]
self.RunWeightedMomentTest(
shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)
def RunWeightedMomentTest(self,
shape,
weights_shape,
axes,
keep_dims,
dtype,
dynshapes=False):
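    # Expected values computed in NumPy below: mean = sum(w * x) / sum(w) and
    # variance = sum(w * x**2) / sum(w) - mean**2.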
with self.cached_session() as s:
x_numpy = np.random.normal(size=shape).astype(np.float32)
weights_numpy = np.absolute( # weights must be positive
np.random.normal(
size=weights_shape, loc=1.0).astype(np.float32))
# Expand the numpy version to higher precision
x_numpy = x_numpy.astype(np.float128)
weights_numpy = weights_numpy.astype(np.float128)
x_shape = [None] * len(shape) if dynshapes else shape
weights_shape = ([None] * len(weights_shape) if dynshapes else
weights_shape)
x = array_ops.placeholder(dtype, shape=x_shape)
weights = array_ops.placeholder(dtype, shape=weights_shape)
mean, var = nn_impl.weighted_moments(
x, axes, weights, keep_dims=keep_dims)
ax = tuple(axes)
def _np_weighted_sum(v):
return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)
weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
expected_mean = _np_weighted_sum(x_numpy) / weight_sum
expected_mean_squared = np.multiply(expected_mean, expected_mean)
expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
weight_sum)
expected_variance = expected_x_squared - expected_mean_squared
mean_v, var_v = s.run([mean, var],
feed_dict={x: x_numpy,
weights: weights_numpy})
self.assertAllCloseAccordingToType(expected_mean, mean_v)
self.assertAllCloseAccordingToType(expected_variance, var_v)
def testAllZeroMasks(self):
x = np.random.normal(size=[8, 3, 4]).astype(np.float32)
weights = np.zeros(shape=[8, 3, 1]).astype(np.float32)
axes = (0, 1)
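    # With all-zero weights, the expected behavior encoded here is zero mean
    # and variance (rather than NaNs from dividing by the zero weight sum).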
mean, var = nn_impl.weighted_moments(
x, axes, weights, keep_dims=False)
self.assertAllClose(mean, np.zeros(shape=[4]))
self.assertAllClose(var, np.zeros(shape=[4]))
if __name__ == "__main__":
test.main()