| # coding=utf-8 |
| # Copyright 2020 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Tests for compiler_opt.rl.local_data_collector.""" |
| |
import collections
import string
from unittest import mock

| |
| import tensorflow as tf |
| from tf_agents.system import system_multiprocessing as multiprocessing |
| |
| from compiler_opt.rl import compilation_runner |
| from compiler_opt.rl import data_collector |
| from compiler_opt.rl import local_data_collector |
| from google.protobuf import text_format |
| |
| |
| def _get_sequence_example(feature_value): |
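  """Returns a two-step SequenceExample with `feature_value` at each step."""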
| sequence_example_text = string.Template(""" |
| feature_lists { |
| feature_list { |
| key: "feature_0" |
| value { |
| feature { int64_list { value: $feature_value } } |
| feature { int64_list { value: $feature_value } } |
| } |
| } |
| }""").substitute(feature_value=feature_value) |
| return text_format.Parse(sequence_example_text, tf.train.SequenceExample()) |
| |
| |
| def mock_collect_data(file_paths, tf_policy_dir, reward_stat): |
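  """Mimics CompilationRunner.collect_data for the test.

  Verifies that the collector forwards the expected arguments, then returns
  a canned CompilationResult: feature_value=1 on the first call (no reward
  stats yet), feature_value=2 once reward stats have been established.
  """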
| assert file_paths == ('a', 'b') |
| assert tf_policy_dir == 'policy' |
| assert reward_stat is None or reward_stat == { |
| 'default': |
| compilation_runner.RewardStat( |
| default_reward=1, moving_average_reward=2) |
| } |
| if reward_stat is None: |
| return compilation_runner.CompilationResult( |
| sequence_examples=[_get_sequence_example(feature_value=1)], |
| reward_stats={ |
| 'default': |
| compilation_runner.RewardStat( |
| default_reward=1, moving_average_reward=2) |
| }, |
| rewards=[1.2], |
| keys=['default']) |
| else: |
| return compilation_runner.CompilationResult( |
| sequence_examples=[_get_sequence_example(feature_value=2)], |
| reward_stats={ |
| 'default': |
| compilation_runner.RewardStat( |
| default_reward=1, moving_average_reward=3) |
| }, |
| rewards=[3.4], |
| keys=['default']) |
| |
| |
| class Sleeper(compilation_runner.CompilationRunner): |
| """Test CompilationRunner that just sleeps.""" |
| |
| def collect_data(self, file_paths, tf_policy_path, reward_stat): |
| _ = file_paths, tf_policy_path, reward_stat |
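    # Block "forever"; the test expects the collector's cancellation manager
    # to kill this process, which surfaces as a ProcessKilledError.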
| compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600, |
| self._cancellation_manager) |
| return compilation_runner.CompilationResult( |
| sequence_examples=[], reward_stats={}, rewards=[], keys=[]) |
| |
| |
| class LocalDataCollectorTest(tf.test.TestCase): |
| |
| def test_local_data_collector(self): |
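    # End-to-end flow: the collector farms modules out to MyRunner, feeds the
    # serialized results through the parser, and reports per-round stats.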
| |
| def make_runner(): |
| |
| class MyRunner(compilation_runner.CompilationRunner): |
| |
| def collect_data(self, *args, **kwargs): |
| return mock_collect_data(*args, **kwargs) |
| |
| return MyRunner() |
| |
| def create_test_iterator_fn(): |
| |
| def _test_iterator_fn(data_list): |
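        # Both rounds serialize 9 identical examples; return a distinct
        # dataset per round so the test can tell which round was parsed.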
| assert data_list in ( |
| [_get_sequence_example(feature_value=1).SerializeToString()] * 9, |
| [_get_sequence_example(feature_value=2).SerializeToString()] * 9) |
| if data_list == [ |
| _get_sequence_example(feature_value=1).SerializeToString() |
| ] * 9: |
| return iter(tf.data.Dataset.from_tensor_slices([1, 2, 3])) |
| else: |
| return iter(tf.data.Dataset.from_tensor_slices([4, 5, 6])) |
| |
| return _test_iterator_fn |
| |
| collector = local_data_collector.LocalDataCollector( |
| file_paths=tuple([('a', 'b')] * 100), |
| num_workers=4, |
| num_modules=9, |
| parser=create_test_iterator_fn(), |
| worker_ctor=make_runner, |
| reward_stat_map=collections.defaultdict(lambda: None)) |
| |
| data_iterator, monitor_dict = collector.collect_data(policy_path='policy') |
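    # First round: reward_stat was None, so the mocked runner emitted
    # feature_value=1 examples; the parser maps those to [1, 2, 3].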
| data = list(data_iterator) |
| self.assertEqual([1, 2, 3], data) |
| expected_monitor_dict_subset = { |
| 'default': { |
| 'success_modules': 9, |
| 'total_trajectory_length': 18, |
| } |
| } |
| self.assertDictContainsSubset(expected_monitor_dict_subset, monitor_dict) |
| |
| data_iterator, monitor_dict = collector.collect_data(policy_path='policy') |
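    # Second round: reward stats now exist, so the runner emitted
    # feature_value=2 examples; the parser maps those to [4, 5, 6].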
| data = list(data_iterator) |
| self.assertEqual([4, 5, 6], data) |
    # All 9 modules succeed again, so the per-round stats are identical.
    self.assertDictContainsSubset(expected_monitor_dict_subset, monitor_dict)
| |
| collector.close_pool() |
| |
| def test_local_data_collector_task_management(self): |
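    # QuickExiter makes collection return immediately, so all in-flight
    # Sleeper jobs must be cancelled rather than awaited.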
| |
| def parser(_): |
| pass |
| |
| class QuickExiter(data_collector.EarlyExitChecker): |
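      """EarlyExitChecker that returns immediately instead of waiting."""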
| |
| def __init__(self, num_modules): |
| data_collector.EarlyExitChecker.__init__(self, num_modules=num_modules) |
| |
| def wait(self, _): |
| return False |
| |
| collector = local_data_collector.LocalDataCollector( |
| file_paths=tuple([('a', 'b')] * 200), |
| num_workers=4, |
| num_modules=4, |
| worker_ctor=Sleeper, |
| parser=parser, |
| reward_stat_map=collections.defaultdict(lambda: None), |
| exit_checker_ctor=QuickExiter) |
| collector.collect_data(policy_path='policy') |
| collector.join_pending_jobs() |
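    # Every pending job should have been killed; result() re-raises the
    # error from the cancelled `sleep` process.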
| killed = 0 |
| for _, w in collector.get_last_work(): |
| self.assertRaises(compilation_runner.ProcessKilledError, w.result) |
| killed += 1 |
    self.assertEqual(killed, 4)
| collector.close_pool() |
| |
| |
| if __name__ == '__main__': |
| multiprocessing.handle_test_main(tf.test.main) |