# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test DistributionStrategy in the zero batch case."""
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class NormalizationTest(test.TestCase, parameterized.TestCase):

@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
],
mode=["graph"],
fused=[True, False]))
def testBNWithZeroBatchInputGraph(self, distribution, fused):
distribution.extended.experimental_enable_get_next_as_optional = True
with distribution.scope(), self.cached_session() as sess:
bn_list = []
inputs = np.random.random((0, 4, 4, 3)) + 100
targets = np.random.random((0, 4, 4, 3))
inputs_placeholder = array_ops.placeholder(
dtype=dtypes.float32, shape=[None, 4, 4, 3])
targets_placeholder = array_ops.placeholder(
dtype=dtypes.float32, shape=[None, 4, 4, 3])
def step_fn(is_training, inputs, targets=None):
bn = normalization.BatchNormalization(
axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
bn_list.append(bn)
        outputs = bn(inputs, training=is_training)
if not is_training:
return outputs
loss = losses.mean_squared_error(targets, outputs)
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
with ops.control_dependencies([train_op]):
return array_ops.identity(loss)
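
      # `step_fn` returns the loss behind a control dependency on the train
      # op, so evaluating the returned tensor below runs one optimizer step
      # and yields the loss value.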
train_op = distribution.extended.call_for_each_replica(
step_fn, args=(True, inputs_placeholder, targets_placeholder))
predict_op = distribution.extended.call_for_each_replica(
step_fn, args=(False, inputs_placeholder))
bn = bn_list[0]
self.evaluate(variables.global_variables_initializer())
      # Check the initial values of the moving statistics and weights; with
      # axis=3 over 3 channels, each is a vector of shape (3,).
moving_mean, moving_var = self.evaluate(
[bn.moving_mean, bn.moving_variance])
self.assertAllEqual([0, 0, 0], moving_mean)
self.assertAllEqual([1, 1, 1], moving_var)
np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
self.assertAllEqual([1, 1, 1], np_gamma)
self.assertAllEqual([0, 0, 0], np_beta)
for _ in range(100):
np_output, _, _ = sess.run([train_op] + bn.updates, {
inputs_placeholder: inputs,
targets_placeholder: targets
})
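        # An empty batch contributes no per-element errors, and
        # `losses.mean_squared_error` guards its mean against a zero
        # denominator, so the loss is exactly 0.0 rather than NaN.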
self.assertEqual(0.0, np_output)
      # Verify that the statistics and weights are unchanged after training:
      # every batch was empty, so the statistic updates were no-ops and the
      # gradients were zero.
moving_mean, moving_var = self.evaluate(
[bn.moving_mean, bn.moving_variance])
self.assertAllEqual([0, 0, 0], moving_mean)
self.assertAllEqual([1, 1, 1], moving_var)
np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
self.assertAllEqual([1, 1, 1], np_gamma)
self.assertAllEqual([0, 0, 0], np_beta)
# Test inference.
np_output = sess.run(predict_op, {inputs_placeholder: inputs})
self.assertEqual([], np_output.tolist())

  @combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
],
mode=["eager"],
fused=[True, False]))
def testBNWithZeroBatchInput(self, distribution, fused):
distribution.extended.experimental_enable_get_next_as_optional = True
with distribution.scope():
inputs = np.random.random((0, 4, 4, 3)).astype(np.float32) + 100
targets = np.random.random((0, 4, 4, 3)).astype(np.float32)
bn = normalization.BatchNormalization(
axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
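      # With axis=3 over 3 channels, gamma, beta, and the moving statistics
      # are all vectors of shape (3,), matching the assertions below.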
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
@def_function.function
def train_step():
def step_fn(inputs, targets):
with backprop.GradientTape() as tape:
            outputs = bn(inputs, training=True)
loss = losses.mean_squared_error(targets, outputs)
grads = tape.gradient(loss, bn.variables)
optimizer.apply_gradients(zip(grads, bn.variables))
return loss
return distribution.run(step_fn, args=(inputs, targets))
for _ in range(100):
np_output = train_step().numpy()
self.assertEqual(0.0, np_output)
      # Verify that the statistics and weights are unchanged after training;
      # no non-empty batch ever reached the layer.
self.assertAllEqual([0, 0, 0], bn.moving_mean.numpy())
self.assertAllEqual([1, 1, 1], bn.moving_variance.numpy())
self.assertAllEqual([1, 1, 1], bn.gamma.numpy())
self.assertAllEqual([0, 0, 0], bn.beta.numpy())
@def_function.function
def test_step():
def step_fn(inputs):
          outputs = bn(inputs, training=False)
return outputs
return distribution.run(step_fn, args=(inputs,))
# Test inference.
self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
test_step().numpy())

  @combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
],
mode=["eager"],
fused=[True, False]))
def testBNWithDynamicBatchInputEager(self, distribution, fused):
distribution.extended.experimental_enable_get_next_as_optional = True
with distribution.scope():
      # Explicitly create a dataset with drop_remainder=False. With 11
      # examples and a batch size of 10, the final partial batch makes the
      # static batch size unknown.
inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
10, drop_remainder=False).repeat()
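      # A sketch of the resulting element shapes (illustration only; assumes
      # a TF version that exposes `element_spec`):
      #
      #   dataset.element_spec[0].shape  # -> TensorShape([None, 4, 4, 3])
      #
      # The repeated dataset alternates between batches of 10 and 1 examples.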
dataset_iterator = iter(
distribution.experimental_distribute_dataset(dataset))
bn = normalization.BatchNormalization(
axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
@def_function.function
def train_step(iterator):
def step_fn(inputs):
features, targets = inputs
with backprop.GradientTape() as tape:
outputs = bn(features, training=True)
loss = losses.mean_squared_error(targets, outputs)
grads = tape.gradient(loss, bn.variables)
optimizer.apply_gradients(zip(grads, bn.variables))
return loss
return distribution.run(step_fn, args=(next(iterator),))
for _ in range(100):
train_step(dataset_iterator).numpy()
# Verify that the statistics and weights are updated.
      self.assertNotAllEqual(np.zeros(3), bn.moving_mean.numpy())
      self.assertNotAllEqual(np.ones(3), bn.moving_variance.numpy())
      self.assertNotAllEqual(np.ones(3), bn.gamma.numpy())
      self.assertNotAllEqual(np.zeros(3), bn.beta.numpy())


if __name__ == "__main__":
test_util.main()