| # coding=utf-8 |
| # Copyright 2020 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Tests for generate_default_trace.""" |
| import json |
| import os |
| from unittest import mock |
| |
| from absl import flags |
| from absl.testing import absltest |
| from absl.testing import flagsaver |
| import gin |
| import tensorflow as tf |
| |
| # This is https://github.com/google/pytype/issues/764 |
| from google.protobuf import text_format # pytype: disable=pyi-error |
| from compiler_opt.rl import compilation_runner |
| from compiler_opt.tools import generate_default_trace |
| |
# absl raises a duplicate-flag error when a flag name is defined a second time
# unless `allow_override` is set on it. Mark the flags this test relies on as
# overridable; presumably another imported module also defines these names --
# TODO confirm which one.
for _flag_name in ('num_workers', 'gin_files', 'gin_bindings'):
  flags.FLAGS[_flag_name].allow_override = True
del _flag_name
| |
| |
class MockCompilationRunner(compilation_runner.CompilationRunner):
  """Fake CompilationRunner that hands back one canned result.

  It never compiles anything. Every call to `collect_data` builds an identical
  CompilationResult, so the tool under test can run without a real compiler
  or policy.
  """

  # A SequenceExample in protobuf text format: one feature list, keyed
  # "feature_0", containing two int64 features that each hold the value 1.
  _SEQUENCE_EXAMPLE_TEXT = """
      feature_lists {
        feature_list {
          key: "feature_0"
          value {
            feature { int64_list { value: 1 } }
            feature { int64_list { value: 1 } }
          }
        }
      }"""

  def collect_data(self, loaded_module_spec, policy, reward_stat):
    """Returns the canned CompilationResult and ignores every argument.

    Args:
      loaded_module_spec: Unused.
      policy: Unused.
      reward_stat: Unused.

    Returns:
      A CompilationResult holding one SequenceExample under the single key
      'default', with reward 1.2, policy reward 18 and a RewardStat of
      (default_reward=1, moving_average_reward=2).
    """
    del loaded_module_spec, policy, reward_stat  # Not consulted by the fake.
    parsed_example = text_format.Parse(self._SEQUENCE_EXAMPLE_TEXT,
                                       tf.train.SequenceExample())
    default_stat = compilation_runner.RewardStat(
        default_reward=1, moving_average_reward=2)
    return compilation_runner.CompilationResult(
        keys=['default'],
        sequence_examples=[parsed_example],
        rewards=[1.2],
        policy_rewards=[18],
        reward_stats={'default': default_stat})
| |
| |
class GenerateDefaultTraceTest(absltest.TestCase):
  """Tests for the generate_default_trace tool."""

  def setUp(self):
    """Loads the inlining gin config and schedules its removal after the test."""
    super().setUp()
    # The config path is relative, so this assumes the tests run from the
    # repository root -- TODO confirm against the test runner's cwd.
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
          bindings=None)
    # Gin bindings are process-global; clear them after each test so they do
    # not leak into later tests running in the same process.
    self.addCleanup(gin.clear_config)

  @mock.patch('compiler_opt.tools.generate_default_trace.get_runner')
  def test_api(self, mock_get_runner):
    """Runs `main` on a fake four-module corpus with a canned runner."""
    tmp_dir = self.create_tempdir()
    module_names = ['a', 'b', 'c', 'd']

    with tf.io.gfile.GFile(
        os.path.join(tmp_dir.full_path, 'corpus_description.json'), 'w') as f:
      json.dump({'modules': module_names, 'has_thinlto': False}, f)

    # Each module gets a placeholder '.bc' file (its own name as content) and
    # a '.cmd' file containing '-cc1'.
    for module_name in module_names:
      with tf.io.gfile.GFile(
          os.path.join(tmp_dir.full_path, module_name + '.bc'), 'w') as f:
        f.write(module_name)

      with tf.io.gfile.GFile(
          os.path.join(tmp_dir.full_path, module_name + '.cmd'), 'w') as f:
        f.write('-cc1')

    # Replace the real runner so no compiler is invoked.
    mock_get_runner.return_value = MockCompilationRunner()

    output_path = os.path.join(tmp_dir.full_path, 'output')
    output_performance_path = os.path.join(tmp_dir.full_path,
                                           'output_performance')
    with flagsaver.flagsaver(
        data_path=tmp_dir.full_path,
        num_workers=2,
        output_path=output_path,
        output_performance_path=output_performance_path,
    ):
      generate_default_trace.main(None)

    # Without these checks the test would only prove that `main` does not
    # raise; verify that it actually wrote the outputs it was pointed at.
    self.assertTrue(tf.io.gfile.exists(output_path))
    self.assertTrue(tf.io.gfile.exists(output_performance_path))

  def test_get_runner(self):
    """`get_runner` builds a CompilationRunner from the loaded gin config."""
    runner = generate_default_trace.get_runner()
    self.assertIsInstance(runner, compilation_runner.CompilationRunner)
| |
| |
# Run the absltest test cases in this module when executed directly.
if __name__ == '__main__':
  absltest.main()