blob: 423612e0d70957a53e28b5393b1fbd97f6c870c2 [file] [log] [blame]
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.env."""
import io
import contextlib
import ctypes
import multiprocessing
import time
from unittest import mock
import subprocess
import os
import tempfile
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
import numpy as np
from compiler_opt.rl import env, log_reader
from compiler_opt.rl import corpus
from compiler_opt.rl import log_reader_test
# Number of observation/outcome records the mock interactive clang emits per
# session before it stops writing.
_NUM_STEPS = 10

# Dummy compiler path; every test that launches "clang" patches
# subprocess.Popen, so this path is never executed.
_CLANG_PATH = '/test/clang/path'

# Minimal module spec: placeholder IR bytes plus two option/value pairs that the
# tests pass to MockTask.get_cmdline as the base arguments.
_MOCK_MODULE = corpus.LoadedModuleSpec(
    name='module',
    loaded_ir=b'asdf',
    orig_options=(
        '--opt_a', 'a',
        '--opt_b', 'b',
    ),
)
class MockTask(env.MLGOTask):
  """Minimal MLGOTask used by the tests in this file.

  The command line is the clang path followed by the base arguments, plus a
  single `--interactive=<base path>` flag when an interactive base path is
  given. Scores are a constant, independent of the working directory.
  """

  def get_cmdline(self, clang_path: str, base_args: list[str],
                  interactive_base_path: str | None,
                  working_dir: str) -> list[str]:
    """Builds the clang invocation; `working_dir` is ignored by the mock."""
    del working_dir  # Not needed by the mock command line.
    cmdline = [clang_path] + base_args
    if interactive_base_path:
      cmdline += [f'--interactive={interactive_base_path}']
    return cmdline

  def get_module_scores(self, working_dir: str) -> dict[str, float]:
    """Returns a fixed score of 47 under the 'default' key."""
    del working_dir  # Not needed by the mock scores.
    return dict(default=47)
# Replacement for subprocess.Popen (installed as a mock side_effect) that
# emulates a clang process, including the interactive-mode log protocol.
@contextlib.contextmanager
def mock_interactive_clang(cmdline, stderr, stdout):
  """Yields a fake clang process for `cmdline`.

  Without an `--interactive=<base>` argument, a process whose wait() and kill()
  do nothing is yielded. With one, `<base>.out` is opened for writing and
  `<base>.in` is opened and held open for the session. A log header describing
  one int64 feature ('times_called') and a float reward is written to `.out`.
  After that, each poll() on the yielded process appends one
  context/observation/outcome record. Once _NUM_STEPS records have been
  written, the next poll() closes `.out`. poll() always returns None.
  """
  del stderr, stdout  # The fake process never produces console output.

  flag = '--interactive='
  # First matching argument wins, mirroring a scan that stops at the first hit.
  fname = next((arg[len(flag):] for arg in cmdline if arg.startswith(flag)),
               None)

  class MockProcess:
    """Process stub whose wait() and kill() are no-ops."""

    def wait(self, timeout):
      pass

    def kill(self):
      pass

  if not fname:
    yield MockProcess()
    return

  log_header = {
      'features': [{
          'name': 'times_called',
          'port': 0,
          'shape': [1],
          'type': 'int64_t',
      },],
      'score': {
          'name': 'reward',
          'port': 0,
          'shape': [1],
          'type': 'float',
      },
  }

  # The env does not create these files itself; it expects the clang process
  # to open them, so the mock opens both ends here.
  with (io.FileIO(fname + '.out', 'wb+') as f_out,
        io.FileIO(fname + '.in', 'rb+') as f_in):
    del f_in  # Kept open by the `with`; never read by the mock.
    header_builder = log_reader_test.LogTestExampleBuilder(opened_file=f_out)
    header_builder.write_header(log_header)
    header_builder.write_newline()

    class MockInteractiveProcess(MockProcess):
      """Fake interactive clang that emits one log record per poll()."""

      def __init__(self):
        self._counter = 0

      # The env polls the process on every get_observation() call to check
      # that clang is still alive, so each poll() writes the next context.
      def poll(self):
        step = self._counter
        if step >= _NUM_STEPS:
          f_out.close()
        else:
          record = log_reader_test.LogTestExampleBuilder(opened_file=f_out)
          record.write_context_marker(f'context_{step}')
          record.write_observation_marker(0)
          record.write_buff([step], ctypes.c_int64)
          record.write_newline()
          record.write_outcome_marker(0)
          record.write_buff([3.14], ctypes.c_float)
          record.write_newline()
          self._counter = step + 1
        return None

    yield MockInteractiveProcess()
class ClangSessionTest(tf.test.TestCase):
  """Tests for env.clang_session in non-interactive and interactive modes."""

  @mock.patch('subprocess.Popen')
  def test_clang_session(self, mock_popen):
    """A non-interactive session launches clang once with piped output."""
    mock_task = MockTask()
    with env.clang_session(
        _CLANG_PATH, _MOCK_MODULE, MockTask,
        interactive=False) as clang_session:
      del clang_session
    cmdline = mock_task.get_cmdline(_CLANG_PATH,
                                    list(_MOCK_MODULE.orig_options), None,
                                    '/tmp/mock/tmp/file')
    mock_popen.assert_called_once_with(
        cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE)

  @mock.patch('subprocess.Popen')
  def test_interactive_clang_session(self, mock_popen):
    """Observations carry the counter and context written by the mock clang."""
    mock_popen.side_effect = mock_interactive_clang
    with env.clang_session(
        _CLANG_PATH, _MOCK_MODULE, MockTask, interactive=True) as clang_session:
      for idx in range(_NUM_STEPS):
        obs = clang_session.get_observation()
        # assertAllEqual compares arrays element-wise and reports the differing
        # values; plain assertEqual relies on the truth value of `a == b`,
        # which is only defined for single-element arrays.
        self.assertAllEqual(obs.obs['times_called'],
                            np.array([idx], dtype=np.int64))
        self.assertEqual(obs.context, f'context_{idx}')
    mock_popen.assert_called_once()

  @mock.patch('subprocess.Popen')
  def test_interactive_clang_temp_dir(self, mock_popen):
    """The working dir is removed after the session unless temps are kept.

    With the default settings the session's working directory exists while the
    session is open and is gone afterwards. When the explicit-temps-dir flag
    points at a directory, the working directory survives the session.
    """
    mock_popen.side_effect = mock_interactive_clang

    working_dir = None
    with env.clang_session(
        _CLANG_PATH, _MOCK_MODULE, MockTask, interactive=True) as clang_session:
      for _ in range(_NUM_STEPS):
        obs = clang_session.get_observation()
        working_dir = obs.working_dir
        self.assertTrue(os.path.exists(working_dir))
    # Guards against os.path.exists(None) if no observation was taken.
    self.assertIsNotNone(working_dir)
    self.assertFalse(os.path.exists(working_dir))

    with tempfile.TemporaryDirectory() as td:
      with flagsaver.flagsaver(
          (env.compilation_runner._EXPLICIT_TEMPS_DIR, td)):  # pylint: disable=protected-access
        with env.clang_session(
            _CLANG_PATH, _MOCK_MODULE, MockTask,
            interactive=True) as clang_session:
          for _ in range(_NUM_STEPS):
            obs = clang_session.get_observation()
            working_dir = obs.working_dir
            self.assertTrue(os.path.exists(working_dir))
        self.assertTrue(os.path.exists(working_dir))
class PipelineCommsTest(parameterized.TestCase):
  """Tests timeouts and failure modes of the named-pipe helpers in env.

  Each test spawns a child process that sleeps for an hour before setting an
  event. Asserting that the event is still unset shows the code under test
  returned or raised without waiting out the child's sleep.
  """

  def _start_process(self, target, args=()):
    """Starts `target(*args)` in a child process that is always reaped.

    multiprocessing joins non-daemonic children at interpreter exit. An
    hour-long child left behind by a failed assertion would therefore stall the
    whole test run. The child is made daemonic, and cleanups kill and then join
    it whatever the test outcome (addCleanup runs LIFO: kill, then join).

    Returns:
      The started multiprocessing.Process.
    """
    proc = multiprocessing.Process(target=target, args=args, daemon=True)
    proc.start()
    self.addCleanup(proc.join)
    self.addCleanup(proc.kill)
    return proc

  @staticmethod
  def _write_log_header(out):
    """Writes a log header (one int64 feature, float reward) to `out`.

    Returns:
      The LogTestExampleBuilder used, so callers can append records to `out`.
    """
    w = log_reader_test.LogTestExampleBuilder(opened_file=out)
    w.write_header({
        'features': [{
            'name': 'times_called',
            'port': 0,
            'shape': [1],
            'type': 'int64_t',
        },],
        'score': {
            'name': 'reward',
            'port': 0,
            'shape': [1],
            'type': 'float',
        }
    })
    w.write_newline()
    return w

  @parameterized.named_parameters(
      {
          'testcase_name': 'write',
          'method': 'open_write_pipe'
      }, {
          'testcase_name': 'read',
          'method': 'open_read_pipe'
      })
  def test_pipe_timeout_open(self, method):
    """Opening a FIFO that no peer opens raises TimeoutError."""
    slept = multiprocessing.Event()

    def _sleep():
      time.sleep(3600)
      slept.set()

    with tempfile.TemporaryDirectory() as td:
      fname = os.path.join(td, 'something')
      os.mkfifo(fname, 0o666)
      self._start_process(_sleep)
      with self.assertRaises(TimeoutError):
        with getattr(env, method)(fname, timeout=5):
          self.fail()
      self.assertFalse(slept.is_set())

  @parameterized.named_parameters({
      'testcase_name': 'read',
      'method': 'read'
  }, {
      'testcase_name': 'readline',
      'method': 'readline'
  })
  def test_read_pipe_timeout_read(self, method):
    """Reading from an open but silent FIFO raises TimeoutError."""
    slept = multiprocessing.Event()

    def _open_then_sleep(fname):
      with open(fname, 'wb'):
        time.sleep(3600)
        slept.set()

    with tempfile.TemporaryDirectory() as td:
      fname = os.path.join(td, 'something')
      os.mkfifo(fname, 0o666)
      self._start_process(_open_then_sleep, args=(fname,))
      with env.open_read_pipe(fname, timeout=5) as read_pipe:
        with self.assertRaises(TimeoutError):
          getattr(read_pipe, method)()
      self.assertFalse(slept.is_set())

  def test_write_pipeline_timeout_open(self):
    """Opening a FIFO for writing with no reader raises TimeoutError.

    Overlaps with the 'write' case of test_pipe_timeout_open.
    """
    slept = multiprocessing.Event()

    def _sleep():
      time.sleep(3600)
      slept.set()

    with tempfile.TemporaryDirectory() as td:
      fname = os.path.join(td, 'something')
      os.mkfifo(fname, 0o666)
      self._start_process(_sleep)
      with self.assertRaises(TimeoutError):
        with env.open_write_pipe(fname, timeout=5):
          self.fail()
      self.assertFalse(slept.is_set())

  def test_process_fails_to_open_writer(self):
    """interactive_session times out if the peer never opens either FIFO."""
    slept = multiprocessing.Event()

    def _sleep():
      time.sleep(3600)
      slept.set()

    with tempfile.TemporaryDirectory() as td:
      reader = os.path.join(td, 'reader')
      writer = os.path.join(td, 'writer')
      os.mkfifo(reader, 0o666)
      os.mkfifo(writer, 0o666)
      self._start_process(_sleep)
      with self.assertRaises(TimeoutError):
        with env.interactive_session(
            reader_name=reader, writer_name=writer, timeout=5):
          self.fail()
      self.assertFalse(slept.is_set())

  def test_process_fails_to_answer(self):
    """interactive_session times out if the peer opens only the writer FIFO."""
    post_sleep_event = multiprocessing.Event()
    opened_event = multiprocessing.Event()
    with tempfile.TemporaryDirectory() as td:
      reader = os.path.join(td, 'reader')
      writer = os.path.join(td, 'writer')
      os.mkfifo(reader, 0o666)
      os.mkfifo(writer, 0o666)

      def _the_process():
        # Opens the env's writer FIFO, but never opens the reader FIFO.
        with open(writer, 'rb'):
          opened_event.set()
          time.sleep(3600)
          post_sleep_event.set()

      self._start_process(_the_process)
      with self.assertRaises(TimeoutError):
        with env.interactive_session(
            reader_name=reader, writer_name=writer, timeout=5):
          pass
      self.assertTrue(opened_event.is_set())
      self.assertFalse(post_sleep_event.is_set())

  def test_process_quits_midway(self):
    """A peer that closes its pipes mid-record makes log reading raise."""
    post_sleep_event = multiprocessing.Event()
    opened_event = multiprocessing.Event()
    wrote_event = multiprocessing.Event()
    with tempfile.TemporaryDirectory() as td:
      reader = os.path.join(td, 'reader')
      writer = os.path.join(td, 'writer')
      os.mkfifo(reader, 0o666)
      os.mkfifo(writer, 0o666)

      def _the_process():
        with open(writer, 'rb'):
          with open(reader, 'wb') as out:
            opened_event.set()
            w = self._write_log_header(out)
            w.write_context_marker('hello')
            w.write_observation_marker(0)
            # Deliberately a c_int16 where the header declares int64_t, so the
            # observation record is left truncated.
            w.write_buff([1], ctypes.c_int16)
            out.flush()
            wrote_event.set()
        time.sleep(3600)
        post_sleep_event.set()

      self._start_process(_the_process)
      with env.interactive_session(
          reader_name=reader, writer_name=writer, timeout=10) as (read_pipe, _):
        # IOError is an alias of OSError in Python 3.
        with self.assertRaises(IOError):
          for _ in log_reader.read_log_from_file(read_pipe):
            self.fail()
      self.assertTrue(opened_event.is_set())
      self.assertTrue(wrote_event.is_set())
      self.assertFalse(post_sleep_event.is_set())

  def test_process_stops_talking_back(self):
    """After one complete record, a silent peer causes a TimeoutError."""
    post_sleep_event = multiprocessing.Event()
    opened_event = multiprocessing.Event()
    wrote_event = multiprocessing.Event()
    with tempfile.TemporaryDirectory() as td:
      reader = os.path.join(td, 'reader')
      writer = os.path.join(td, 'writer')
      os.mkfifo(reader, 0o666)
      os.mkfifo(writer, 0o666)

      def _the_process():
        with open(writer, 'rb'):
          with open(reader, 'wb') as out:
            opened_event.set()
            w = self._write_log_header(out)
            w.write_context_marker('hello')
            w.write_observation_marker(0)
            w.write_buff([1], ctypes.c_int64)
            w.write_newline()
            w.write_outcome_marker(0)
            w.write_buff([3.14], ctypes.c_float)
            w.write_newline()
            out.flush()
            wrote_event.set()
            # Keep both pipes open while writing nothing further.
            time.sleep(3600)
            post_sleep_event.set()

      self._start_process(_the_process)
      read_count = 0
      with self.assertRaises(TimeoutError):
        with env.interactive_session(
            reader_name=reader, writer_name=writer,
            timeout=10) as (read_pipe, _):
          for obs in log_reader.read_log_from_file(read_pipe):
            self.assertIsInstance(obs, log_reader.ObservationRecord)
            read_count += 1
      self.assertEqual(read_count, 1)
      self.assertTrue(opened_event.is_set())
      self.assertTrue(wrote_event.is_set())
      self.assertFalse(post_sleep_event.is_set())
class MLGOEnvironmentTest(tf.test.TestCase):
  """Tests the reset/step episode protocol of env.MLGOEnvironmentBase."""

  def _make_env(self, **kwargs):
    """Builds an environment around MockTask; kwargs are forwarded as-is."""
    return env.MLGOEnvironmentBase(
        clang_path=_CLANG_PATH,
        task_type=MockTask,
        obs_spec={},
        action_spec={},
        **kwargs,
    )

  def _run_episode(self, test_env):
    """Runs one episode on `test_env` and checks each step's type.

    reset() must return a FIRST step. The next _NUM_STEPS - 1 calls to step()
    must return MID steps, and the final step() a LAST step.
    """
    step = test_env.reset(_MOCK_MODULE)
    self.assertEqual(step.step_type, env.StepType.FIRST)
    for _ in range(_NUM_STEPS - 1):
      step = test_env.step(np.array([1], dtype=np.int64))
      self.assertEqual(step.step_type, env.StepType.MID)
    step = test_env.step(np.array([1], dtype=np.int64))
    self.assertEqual(step.step_type, env.StepType.LAST)

  @mock.patch('subprocess.Popen')
  def test_env(self, mock_popen):
    """By default, the env's _iclang and _clang sessions are distinct."""
    mock_popen.side_effect = mock_interactive_clang
    test_env = self._make_env()
    for _ in range(3):
      self._run_episode(test_env)
      self.assertNotEqual(test_env._iclang, test_env._clang)  # pylint: disable=protected-access

  @mock.patch('subprocess.Popen')
  def test_env_interactive_only(self, mock_popen):
    """With interactive_only=True, _iclang and _clang compare equal."""
    mock_popen.side_effect = mock_interactive_clang
    test_env = self._make_env(interactive_only=True)
    for _ in range(3):
      self._run_episode(test_env)
      self.assertEqual(test_env._iclang, test_env._clang)  # pylint: disable=protected-access
# Hand off to TensorFlow's test runner when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()