# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Tests for make_template used with MirroredStrategy.""" |
| from tensorflow.python.distribute import distribute_lib |
| from tensorflow.python.distribute import mirrored_strategy |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import template |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| |
| |
class TemplateMirroredStrategyTest(test.TestCase):
  """Checks that a `make_template` template runs under `MirroredStrategy`."""

  @test_util.disable_tfrt("Strategy not supported yet.")
  def test_merge_call(self):
    """Runs a template that calls `merge_call` between variable creations.

    The templated function creates `var1` (initialized to 21.0), issues a
    `merge_call` on the current replica context, creates `var2` (initialized
    to 2.0) and returns their product. The strategy is built over one CPU and
    one GPU device, and both local results are expected to equal 42.0.

    The test is skipped when no GPU is available.
    """
    # Everything is built in a fresh graph: the function under test is
    # v1-only.
    with ops.Graph().as_default():
      if not test.is_gpu_available():
        self.skipTest("No GPU available")

      def build_product():
        first = variable_scope.get_variable(
            "var1", shape=[], initializer=init_ops.constant_initializer(21.0))
        # The merge function ignores its argument and returns an empty tuple;
        # the call sits between the two variable creations.
        replica_context = distribute_lib.get_replica_context()
        replica_context.merge_call(lambda _: ())
        second = variable_scope.get_variable(
            "var2", shape=[], initializer=init_ops.constant_initializer(2.0))
        return first * second

      templated = template.make_template("my_template", build_product)

      devices = ["/cpu:0", "/gpu:0"]
      mirrored = mirrored_strategy.MirroredStrategy(devices)
      per_replica = mirrored.run(templated)
      local_results = mirrored.experimental_local_results(per_replica)

      init_op = variables.global_variables_initializer()
      self.evaluate(init_op)
      # 21.0 * 2.0 on each of the two devices.
      self.assertAllEqual([42.0, 42.0], self.evaluate(local_results))
| |
| |
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()