blob: a70aaf2133ac88696bb80f903c3f220e861209b2 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.local_data_collector."""
import time
import tensorflow as tf
from tf_agents.system import system_multiprocessing as multiprocessing
from compiler_opt.rl import local_data_collector
class LocalDataCollectorTest(tf.test.TestCase):
  """Tests for local_data_collector.LocalDataCollector."""

  def test_local_data_collector(self):
    """Collected data is parsed per-round and the monitor dict reports success."""

    def mock_collect_data(file_paths, tf_policy_dir, default_policy_size,
                          moving_average_size):
      # Stand-in for the per-module runner. The first tuple element encodes
      # which policy was exercised: 1 when default_policy_size is None
      # (first round), 2 otherwise (second round).
      del moving_average_size
      assert file_paths == ('a', 'b')
      assert tf_policy_dir == 'policy'
      assert default_policy_size is None or default_policy_size == 1
      if default_policy_size is None:
        return 1, 1, 2, 3
      return 2, 1, 2, 3

    def create_test_iterator_fn():

      def _test_iterator_fn(data_list):
        # mock_collect_data yields all-1s on the first round and all-2s on
        # the second (9 modules each); map each round to a distinct dataset
        # so the assertions below can tell the rounds apart.
        assert data_list in ([1] * 9, [2] * 9)
        if data_list == [1] * 9:
          return iter(tf.data.Dataset.from_tensor_slices([1, 2, 3]))
        return iter(tf.data.Dataset.from_tensor_slices([4, 5, 6]))

      return _test_iterator_fn

    collector = local_data_collector.LocalDataCollector(
        file_paths=tuple([('a', 'b')] * 100),
        num_workers=4,
        num_modules=9,
        runner=mock_collect_data,
        parser=create_test_iterator_fn())

    expected_monitor_dict_subset = {'default': {'success_modules': 9}}

    # First collection round: default policy path.
    data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
    self.assertEqual([1, 2, 3], list(data_iterator))
    # assertDictContainsSubset was deprecated in Python 3.2 and removed in
    # 3.12; merging the expected subset into the actual dict and comparing
    # has identical top-level subset semantics.
    self.assertEqual(monitor_dict,
                     {**monitor_dict, **expected_monitor_dict_subset})

    # Second collection round: trained policy path.
    data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
    self.assertEqual([4, 5, 6], list(data_iterator))
    self.assertEqual(monitor_dict,
                     {**monitor_dict, **expected_monitor_dict_subset})

    collector.close_pool()

  def test_local_data_collector_task_management(self):
    """The overload handler fires once unfinished work exceeds the threshold."""

    class OverloadHandler:
      """Records every unfinished-task count the collector reports."""

      def __init__(self):
        self.counts = []

      def reset(self):
        self.counts.clear()

      def handler(self, count):
        self.counts.append(count)

    def long_running_collector(file_path, *_):
      # File names encode the sleep duration: 'wait_<seconds>'.
      _, t = file_path.split('_')
      t = int(t)
      time.sleep(t)
      return file_path, t, t, t

    def parser(data_list):
      assert data_list

    overload_handler = OverloadHandler()
    # Set the max_unfinished_tasks so we may schedule first some very long
    # running work that occupies some, but not all the worker processes of the
    # pool. This ensures there are workers able to pick up the short-running
    # work and clear it.
    collector = local_data_collector.LocalDataCollector(
        file_paths=['wait_0'] * 200,
        num_workers=4,
        num_modules=4,
        runner=long_running_collector,
        parser=parser,
        max_unfinished_tasks=2,
        overload_handler=overload_handler.handler)  # pytype: disable=wrong-arg-types
    collector.collect_data(policy_path='policy')
    # Drain the fast 'wait_0' jobs so only injected work remains pending.
    while [r for _, r in collector._unfinished_work if not r.ready()]:
      time.sleep(1)
    # Inject 3 long-running jobs — one more than max_unfinished_tasks.
    collector.inject_unfinished_work_for_test([
        ('policy', r) for r in collector._schedule_jobs(
            'policy', ['wait_5', 'wait_5', 'wait_10'])
    ])
    collector.collect_data(policy_path='policy')
    self.assertNotEmpty(overload_handler.counts)
    # We set the overload threshold (_max_unfinished_tasks) at 2, so the
    # overload handler should have seen a '3' after the short running tasks
    # have cleared.
    self.assertIn(3, overload_handler.counts)
    # The really long running task ('wait_10') would not have cleared yet.
    # NOTE(review): the original read the public `unfinished_work` here but
    # the private `_unfinished_work` above; normalized to the private name
    # used elsewhere in this test — confirm no `unfinished_work` property.
    self.assertLen(
        [r for _, r in collector._unfinished_work if not r.ready()], 1)
    collector.close_pool()
if __name__ == '__main__':
  # Run the TF test main under tf_agents' multiprocessing harness, which
  # presumably initializes worker-process state before tests spawn a pool —
  # required here because LocalDataCollector uses worker processes.
  multiprocessing.handle_test_main(tf.test.main)