blob: f0a71a4b920907704b5223404a0c55e3db94a981 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for loss scaling utilities in tensorflow.ops.nn."""
from absl.testing import parameterized
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import test as test_lib
class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
  """Tests for the loss scaling utilities in `nn_impl`.

  Covers `compute_average_loss` (explicit and default global batch size,
  sample weights, dtype preservation, input validation) and
  `scale_regularization_loss`, both without a strategy and under the
  two-CPU mirrored strategy from `strategy_combinations`.
  """

  def testComputeAverageLossGlobalBatchSize(self):
    """The loss sum is divided by an explicitly passed global batch size."""
    per_example_loss = [1, 2, 3, 4, 5]
    loss = nn_impl.compute_average_loss(per_example_loss, global_batch_size=10)
    # (1 + 2 + 3 + 4 + 5) / 10 == 1.5
    self.assertEqual(self.evaluate(loss), 1.5)

  def testComputeAverageLossGlobalBatchSize_BatchSizeNonScalar(self):
    """A non-scalar `global_batch_size` raises ValueError."""
    per_example_loss = [1, 2, 3, 4, 5]
    with self.assertRaisesWithPredicateMatch(
        ValueError, "global_batch_size must be scalar"):
      nn_impl.compute_average_loss(per_example_loss, global_batch_size=[10])

  def testComputeAverageLossGlobalBatchSize_BatchSizeFloat(self):
    """A floating-point `global_batch_size` raises TypeError."""
    per_example_loss = [1, 2, 3, 4, 5]
    with self.assertRaisesWithPredicateMatch(
        TypeError, "global_batch_size must be an int"):
      nn_impl.compute_average_loss(per_example_loss, global_batch_size=10.0)

  def testComputeAverageLossGlobalBatchSize_BatchSizeNegative(self):
    """A negative `global_batch_size` raises InvalidArgumentError."""
    per_example_loss = [1, 2, 3, 4, 5]
    with self.assertRaisesWithPredicateMatch(
        errors_impl.InvalidArgumentError,
        "global_batch_size must be non-negative"):
      nn_impl.compute_average_loss(per_example_loss, global_batch_size=-1)

  def testComputeAverageLossGlobalBatchSize_BatchSizeZero(self):
    """A zero `global_batch_size` yields a zero loss instead of an error."""
    per_example_loss = [1, 2, 3, 4, 5]
    loss = nn_impl.compute_average_loss(per_example_loss, global_batch_size=0)
    self.assertEqual(self.evaluate(loss), 0.0)

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testComputeAverageLossDefaultGlobalBatchSize(self, distribution):
    """Without `global_batch_size`, the divisor accounts for all replicas."""
    per_example_loss = constant_op.constant([2.5, 6.2, 5.])
    expected_mean = (2.5 + 6.2 + 5.) / 3

    # Without a strategy there is a single replica, so this is a plain mean.
    loss = nn_impl.compute_average_loss(per_example_loss)
    self.assertAllClose(self.evaluate(loss), expected_mean)

    # With the strategy there are 2 replicas, each given the same 3 examples.
    # The SUM of the per-replica results must still equal the plain mean,
    # i.e. each replica divides by 3 * 2 = 6 examples.
    with distribution.scope():
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss, args=(per_example_loss,))
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), expected_mean)

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testComputeAverageLossDefaultGlobalBatchSizeEmptyBatch(self,
                                                             distribution):
    """An empty batch produces a zero loss with and without a strategy."""
    per_example_loss = constant_op.constant([], dtypes.float32)
    loss = nn_impl.compute_average_loss(per_example_loss)
    self.assertEqual(self.evaluate(loss), 0.0)

    with distribution.scope():
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss, args=(per_example_loss,))
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), 0.0)

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testComputeAverageLossSampleWeights(self, distribution):
    """Scalar, per-example and per-time-step sample weights scale the loss."""
    with distribution.scope():
      # Scalar sample weight: every example's loss is multiplied by 2.
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss,
          args=([2., 4., 6.],),
          kwargs={"sample_weight": 2})
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), (2. + 4. + 6.) * 2. / 3)

      # Per-example sample weight: one weight for each of the 3 examples.
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss,
          args=([2., 4., 6.],),
          kwargs={"sample_weight": [0.3, 0.5, 0.2]})
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(
          self.evaluate(loss), (2. * 0.3 + 4. * 0.5 + 6. * 0.2) / 3)

      # Time-step sample weight: the weights match the full [2, 2] loss.
      # The expected value divides by 2 (rows), not 4 (elements), so the
      # second dimension does not add to the example count.
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss,
          args=([[2., 0.5], [4., 1.]],),
          kwargs={"sample_weight": [[0.3, 0.7], [0.2, 0.8]]})
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(
          self.evaluate(loss), (2. * 0.3 + 0.5 * 0.7 + 4. * 0.2 + 1. * 0.8) / 2)

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testComputeAverageLossSampleWeightsEmptyBatch(self, distribution):
    """Sample weights applied to an empty batch still produce a zero loss."""
    # A rank-1 tensor holding zero examples. It is not a rank-0 scalar, which
    # `compute_average_loss` rejects (see testComputeAverageLossInvalidRank).
    empty_batch = constant_op.constant([], dtypes.float32)
    with distribution.scope():
      # Scalar sample weight.
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss,
          args=(empty_batch,),
          kwargs={"sample_weight": 2})
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), 0.0)

      # Per-example sample weight (equally empty).
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss,
          args=(empty_batch,),
          kwargs={"sample_weight": empty_batch})
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), 0.0)

  def testComputeAverageLossInvalidSampleWeights(self):
    """Sample weights whose shape is incompatible with the loss are rejected."""
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors_impl.InvalidArgumentError)):
      # 3 per-example losses vs. only 2 weights.
      nn_impl.compute_average_loss([2.5, 6.2, 5.],
                                   sample_weight=[0.2, 0.8],
                                   global_batch_size=10)

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testComputeAverageLossDtype(self, distribution):
    """A float64 loss stays float64 and keeps its expected value."""
    with distribution.scope():
      per_example_loss = constant_op.constant([2., 4., 6.],
                                              dtype=dtypes.float64)
      per_replica_losses = distribution.run(
          nn_impl.compute_average_loss,
          args=(per_example_loss,),
          kwargs={"sample_weight": 2})
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertEqual(loss.dtype, dtypes.float64)
      # A dtype check alone would also pass for a wrong float64 result; the
      # expectation is the same as the float32 scalar-weight case.
      self.assertAllClose(self.evaluate(loss), (2. + 4. + 6.) * 2. / 3)

  def testComputeAverageLossInvalidRank(self):
    """A rank-0 loss is rejected, whether its rank is static or dynamic."""
    per_example_loss = constant_op.constant(2.)

    # Static rank: rejected as soon as `compute_average_loss` is called.
    with self.assertRaisesRegex(
        ValueError, "Invalid value passed for `per_example_loss`. "
        "Expected a tensor with at least rank 1."):
      nn_impl.compute_average_loss(per_example_loss)

    with context.graph_mode():
      # Dynamic rank: a placeholder without a shape lets the op be built, so
      # the check only fires when a scalar is fed in at run time.
      unknown_rank_loss = array_ops.placeholder(dtype=dtypes.float32)
      loss = nn_impl.compute_average_loss(unknown_rank_loss)
      with self.cached_session() as sess:
        with self.assertRaisesRegex(
            errors.InvalidArgumentError,
            "Invalid value passed for `per_example_loss`. "
            "Expected a tensor with at least rank 1."):
          sess.run(loss, {unknown_rank_loss: 2})

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testComputeAverageLossInCrossReplicaContext(self, distribution):
    """Calling `compute_average_loss` directly inside the scope raises."""
    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError,
          "You are calling `compute_average_loss` in cross replica context"):
        nn_impl.compute_average_loss([2, 3])

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testScaleRegularizationLoss(self, distribution):
    """Regularization losses sum to the same total with or without replicas."""
    reg_losses = constant_op.constant([2.5, 6.2, 5.])
    expected_sum = 2.5 + 6.2 + 5.

    # Without a strategy there is a single replica: the losses are summed.
    loss = nn_impl.scale_regularization_loss(reg_losses)
    self.assertAllClose(self.evaluate(loss), expected_sum)

    # With 2 replicas, each replica's result is scaled so that the SUM across
    # replicas equals the single-replica total.
    with distribution.scope():
      per_replica_losses = distribution.run(
          nn_impl.scale_regularization_loss, args=(reg_losses,))
      loss = distribution.reduce("SUM", per_replica_losses, axis=None)
      self.assertAllClose(self.evaluate(loss), expected_sum)

  @combinations.generate(
      combinations.combine(
          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],
          mode=["graph", "eager"],
      )
  )
  def testScaleRegularizationLossInCrossReplicaContext(self, distribution):
    """Calling `scale_regularization_loss` directly in the scope raises."""
    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError, "You are calling `scale_regularization_loss` in "
          "cross replica context"):
        nn_impl.scale_regularization_loss([2, 3])
if __name__ == "__main__":
  # Run this module's tests through the TensorFlow test runner (`test_lib`).
  test_lib.main()