| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Test cases for ftrl ("follow the regularized leader") operations.""" |
| |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.platform import googletest |
| from tensorflow.python.training import training_ops |
| |
| |
| class ResourceApplyFtrlTest(xla_test.XLATestCase): |
| """Test cases for ftrl ops.""" |
| |
| def setUp(self): |
| super().setUp() |
| self.rewrite_ops_for_tpu = ("TPU" in self.device and |
| test_util.is_mlir_bridge_enabled()) |
| |
| def _eval(self, var, accum, linear, grad, lr, l1, l2, l2_shrinkage=0, |
| lr_power=1, multiply_linear_by_lr=False): |
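    """Runs one FTRL update step and returns the new (var, accum, linear).

    Uses ResourceApplyFtrlV2 when l2_shrinkage is nonzero, else
    ResourceApplyFtrl. For reference, the documented update is roughly:

      accum_new = accum + grad^2
      linear   += grad - (accum_new^-lr_power - accum^-lr_power) / lr * var
      quadratic = accum_new^-lr_power / lr + 2 * l2
      var       = (sign(linear) * l1 - linear) / quadratic
                  if |linear| > l1 else 0

    The V2 op additionally adds 2 * l2_shrinkage * var to the gradient used
    for the linear (but not the accum) update; multiply_linear_by_lr makes
    the op store linear pre-scaled by lr.
    """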
| dtype = np.float32 |
| var = np.array(var, dtype=dtype) |
| accum = np.array(accum, dtype=dtype) |
| linear = np.array(linear, dtype=dtype) |
| grad = np.array(grad, dtype=dtype) |
| use_v2 = bool(l2_shrinkage) |
| with self.session() as session: |
| with self.test_scope(): |
| lr = constant_op.constant(lr, dtype=dtype) |
| l1 = constant_op.constant(l1, dtype=dtype) |
| l2 = constant_op.constant(l2, dtype=dtype) |
| l2_shrinkage = constant_op.constant(l2_shrinkage, dtype=dtype) |
| lr_power = constant_op.constant(lr_power, dtype=dtype) |
| v_var = resource_variable_ops.ResourceVariable(var, dtype=dtype) |
| v_accum = resource_variable_ops.ResourceVariable(accum, dtype=dtype) |
| v_linear = resource_variable_ops.ResourceVariable(linear, dtype=dtype) |
| session.run(v_var.create) |
| session.run(v_accum.create) |
| session.run(v_linear.create) |
| assert not (use_v2 and multiply_linear_by_lr) |
| if use_v2: |
| session.run(training_ops.resource_apply_ftrl_v2( |
| v_var.handle, v_accum.handle, v_linear.handle, |
| grad, lr, l1, l2, l2_shrinkage, lr_power, |
| multiply_linear_by_lr=multiply_linear_by_lr)) |
| else: |
| session.run(training_ops.resource_apply_ftrl( |
| v_var.handle, v_accum.handle, v_linear.handle, |
| grad, lr, l1, l2, lr_power, |
| multiply_linear_by_lr=multiply_linear_by_lr)) |
| return (v_var.read_value().eval().reshape(var.shape), |
| v_accum.read_value().eval().reshape(accum.shape), |
| v_linear.read_value().eval().reshape(linear.shape)) |
| |
| def testAccum(self): |
| """Test that accum is updated with grad^2.""" |
| accum = np.array([[[1, 3], [2, 5], [6, 8]]]) |
| grad = np.array([[[1, 3], [2, 5], [6, 8]]]) |
| _, new_accum, _ = self._eval( |
| var=np.zeros((1, 3, 2)), |
| accum=accum, |
| linear=np.zeros((1, 3, 2)), |
| grad=grad, |
| lr=7, l1=3, l2=7, lr_power=2) |
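    # Per the FTRL update sketched in _eval, only grad feeds the accumulator
    # (accum_new = accum + grad^2); lr, l1, l2 and lr_power play no part.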
| self.assertAllClose(accum + grad*grad, new_accum) |
| |
| def testLinearNoGradient(self): |
| """Test that if accum_new == accum, linear doesn't change.""" |
| _, _, linear = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=[[[1, 3], [2, 5], [6, 8]]], |
| linear=[[[1, 2], [3, 4], [5, 6]]], |
        grad=np.zeros((1, 3, 2)),  # make accum_new == accum
| lr=1, l1=3, l2=7, lr_power=2) |
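    # grad == 0 keeps accum_new == accum, so both the grad term and the
    # accum-ratio term of the linear update vanish: linear is unchanged.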
| self.assertAllClose([[[1, 2], [3, 4], [5, 6]]], linear) |
| |
| def testLinear(self): |
| """Test the linear update for new_linear=2 and linear=1.""" |
| _, _, linear = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=1, l1=3, l2=7, lr_power=2) |
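    # accum_new = 1 + 1^2 = 2, so per the update sketched in _eval:
    # linear += 1 - (2^-2 - 1^-2) / 1 * 1 = 1 - (0.25 - 1) = 1.75.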
| self.assertAllClose(1.75 * np.ones((1, 3, 2)), linear) |
| |
| def testLR(self): |
| """Test that the linear update is divided by lr.""" |
| _, _, linear = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=5, l1=3, l2=7, lr_power=-1) |
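    # With lr_power=-1 the exponents flip sign, and the accum-ratio term is
    # divided by lr: linear += 1 - (2^1 - 1^1) / 5 * 1 = 1 - 0.2 = 0.8.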
| self.assertAllClose(0.8 * np.ones((1, 3, 2)), linear) |
| |
| def testVar(self): |
| """Test computation of var with linear=1.5, quadratic=1.""" |
| var, _, _ = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=1, l1=1, l2=0.25, lr_power=1) |
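    # linear reaches 1 - (2^-1 - 1^-1) = 1.5 and
    # quadratic = 2^-1 / 1 + 2 * 0.25 = 1, so
    # var = (sign(linear) * l1 - linear) / quadratic = (1 - 1.5) / 1 = -0.5.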
| self.assertAllClose(-0.5 * np.ones((1, 3, 2)), var) |
| |
| def testVarClipped(self): |
| """Test that var becomes 0 if |linear| < l1.""" |
| var, _, _ = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=1, l1=1.6, l2=0.25, lr_power=1) |
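    # linear reaches 1.5 exactly as in testVar, but |1.5| < l1 = 1.6, so the
    # sparsity-inducing l1 clamp sets var to 0.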
| self.assertAllClose(np.zeros((1, 3, 2)), var) |
| |
| def testQuadratic(self): |
| """Test that quadratic (here: -2) is the divisor of var.""" |
| var, _, _ = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=1, l1=1, l2=-1.25, lr_power=1) |
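    # linear reaches 1.5 as in testVar, while
    # quadratic = 2^-1 / 1 + 2 * (-1.25) = -2, so var = (1 - 1.5) / -2 = 0.25.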
| self.assertAllClose(0.25 * np.ones((1, 3, 2)), var) |
| |
| def testL2Shrinkage(self): |
| """Test that 2 * l2_shrinkage * var is *not* added to the gradient.""" |
| _, accum, _ = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.zeros((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.zeros((1, 3, 2)), |
| lr=7, l1=3, l2=7, lr_power=2, l2_shrinkage=0.5) |
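    # Although grad_with_shrinkage = 0 + 2 * 0.5 * 1 = 1, the accumulator is
    # updated with the raw grad (here 0), so accum stays 0.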
| self.assertAllClose(np.zeros((1, 3, 2)), accum) |
| |
| def testL2ShrinkageOnLinear(self): |
| """Test that 2 * l2_shrinkage * var is added to linear.""" |
| _, _, linear = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.zeros((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.zeros((1, 3, 2)), |
| lr=2, l1=3, l2=7, lr_power=0, l2_shrinkage=11) |
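    # grad == 0 leaves accum_new == accum, so the accum-ratio term vanishes
    # and linear += grad_with_shrinkage = 0 + 2 * 11 * 1 = 22.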
| self.assertAllClose(22 * np.ones((1, 3, 2)), linear) |
| |
| def testMultiplyLinearByLR(self): |
| """Test multiply_linear_by_lr = true for the linear variable.""" |
| _, _, linear = self._eval( |
| var=np.zeros((1, 3, 2)), |
| accum=np.zeros((1, 3, 2)), |
| linear=np.ones((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=6, l1=1, l2=-1.25, lr_power=0, |
| multiply_linear_by_lr=True) |
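    # With multiply_linear_by_lr, linear accumulates grad * lr; var == 0
    # zeroes the accum-ratio term, so linear = 1 + 1 * 6 = 7.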
| self.assertAllClose(7 * np.ones((1, 3, 2)), linear) |
| |
| def testMultiplyLinearByLRClipping(self): |
| """Test that multiply_linear_by_lr = true scales the clip margins.""" |
| var, _, _ = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=3, l1=1.0, l2=0.25, lr_power=1, |
| multiply_linear_by_lr=True) |
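    # The stored linear is scaled by lr: linear = 1 * 3 - (2^-1 - 1^-1) * 1 =
    # 3.5, the l1 threshold scales to 1 * 3 = 3, and
    # var = (3 - 3.5) / (2^-1 + 2 * 0.25 * 3) = -0.5 / 2 = -0.25.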
| self.assertAllClose(-0.25 * np.ones((1, 3, 2)), var) |
| |
| def testMultiplyLinearByLRClipZero(self): |
| """Test that multiply_linear_by_lr = true still clips to 0.""" |
| var, _, _ = self._eval( |
| var=np.ones((1, 3, 2)), |
| accum=np.ones((1, 3, 2)), |
| linear=np.zeros((1, 3, 2)), |
| grad=np.ones((1, 3, 2)), |
| lr=3, l1=1.2, l2=0.25, lr_power=1, |
| multiply_linear_by_lr=True) |
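    # Scaled linear again reaches 3.5, but the scaled threshold is now
    # l1 * lr = 1.2 * 3 = 3.6 > 3.5, so var is still clamped to 0.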
| self.assertAllClose(np.zeros((1, 3, 2)), var) |
| |
| |
| if __name__ == "__main__": |
| googletest.main() |