Allow specifying the corpus description json file
It is sometimes useful to have variations of the corpus description, e.g.
with swaths of modules removed. This patch allows multiple descriptions to
coexist under a corpus directory and lets callers select one by passing its
path; passing just the directory still uses the default description file.
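Renamed the `data_path` argument of `Corpus` to `location`, since "data path"
suggests a directory while "location" covers both a directory and a file. The
`--data_path` command-line flags keep their names.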
diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py
index 37b7f0d..e3dcfd2 100644
--- a/compiler_opt/rl/corpus.py
+++ b/compiler_opt/rl/corpus.py
@@ -32,6 +32,8 @@
# command line, where all the flags reference existing, local files.
FullyQualifiedCmdLine = Tuple[str, ...]
+DEFAULT_CORPUS_DESCRIPTION_FILENAME = 'corpus_description.json'
+
def _apply_cmdline_filters(
orig_options: Tuple[str, ...],
@@ -185,7 +187,7 @@
class Corpus:
"""Represents a corpus.
- A corpus is created from a corpus_description.json file, produced by
+ A corpus is created from a corpus description json file, produced by
extract_ir.py (for example).
To use the corpus:
@@ -225,7 +227,7 @@
def __init__(self,
*,
- data_path: str,
+ location: str,
module_filter: Optional[Callable[[str], bool]] = None,
additional_flags: Tuple[str, ...] = (),
delete_flags: Tuple[str, ...] = (),
@@ -238,7 +240,10 @@
output) and validated.
Args:
- data_path: corpus directory.
+ location: either the path to a corpus description json file in a corpus
+ directory, or the corpus directory itself, in which case
+ `DEFAULT_CORPUS_DESCRIPTION_FILENAME` is expected to exist in that
+ directory and is used as the corpus description.
additional_flags: list of flags to append to the command line
delete_flags: list of flags to remove (both `-flag=<value` and
`-flag <value>` are supported).
@@ -251,18 +256,24 @@
module_filter: a regular expression used to filter 'in' modules with names
matching it. None to include everything.
"""
- self._base_dir = data_path
+ if tf.io.gfile.isdir(location):
+ self._base_dir = location
+ corpus_description_path = os.path.join(
+ location, DEFAULT_CORPUS_DESCRIPTION_FILENAME)
+ else:
+ self._base_dir = os.path.dirname(location)
+ corpus_description_path = location

self._sampler = sampler
# TODO: (b/233935329) Per-corpus *fdo profile paths can be read into
# {additional|delete}_flags here
- with tf.io.gfile.GFile(
- os.path.join(data_path, 'corpus_description.json'), 'r') as f:
+ with tf.io.gfile.GFile(corpus_description_path, 'r') as f:
corpus_description: Dict[str, Any] = json.load(f)
module_paths = corpus_description['modules']
if len(module_paths) == 0:
raise ValueError(
- f'{data_path}\'s corpus_description contains no modules.')
+ f'Corpus description {corpus_description_path} contains no modules.')
has_thinlto: bool = corpus_description['has_thinlto']
@@ -273,7 +284,7 @@
if corpus_description[
'global_command_override'] == constant.UNSPECIFIED_OVERRIDE:
raise ValueError(
- 'global_command_override in corpus_description.json not filled.')
+ f'global_command_override in {corpus_description_path} not filled.')
cmd_override = tuple(corpus_description['global_command_override'])
if len(additional_flags) > 0:
logging.warning(
@@ -314,7 +325,8 @@
if cmd_override_was_specified:
ret = cmd_override
else:
- with tf.io.gfile.GFile(os.path.join(data_path, name + '.cmd')) as f:
+ with tf.io.gfile.GFile(os.path.join(self._base_dir,
+ name + '.cmd')) as f:
ret = tuple(f.read().replace(r'{', r'{{').replace(r'}',
r'}}').split('\0'))
# The options read from a .cmd file must be run with -cc1
@@ -331,8 +343,8 @@
contents = tp.map(
lambda name: ModuleSpec(
name=name,
- size=tf.io.gfile.GFile(os.path.join(data_path, name + '.bc')).
- size(),
+ size=tf.io.gfile.GFile(
+ os.path.join(self._base_dir, name + '.bc')).size(),
command_line=get_cmdline(name),
has_thinlto=has_thinlto), module_paths)
self._module_specs = tuple(
@@ -405,6 +417,6 @@
if cmdline_is_override:
corpus_description['global_command_override'] = cmdline
with tf.io.gfile.GFile(
- os.path.join(location, 'corpus_description.json'), 'w') as f:
+ os.path.join(location, DEFAULT_CORPUS_DESCRIPTION_FILENAME), 'w') as f:
f.write(json.dumps(corpus_description))
- return Corpus(data_path=location, **kwargs)
+ return Corpus(location=location, **kwargs)
diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py
index 9cabf4d..27ff674 100644
--- a/compiler_opt/rl/corpus_test.py
+++ b/compiler_opt/rl/corpus_test.py
@@ -105,6 +105,17 @@
has_thinlto=False),))
self.assertEqual(len(cps), 1)
+ def test_specific_path(self):
+ basedir = self.create_tempdir()
+ cps = corpus.create_corpus_for_testing(
+ location=basedir, elements=[corpus.ModuleSpec(name='1', size=1)])
+ new_corpus_desc = os.path.join(basedir, 'hi.json')
+ tf.io.gfile.rename(
+ os.path.join(basedir, corpus.DEFAULT_CORPUS_DESCRIPTION_FILENAME),
+ new_corpus_desc)
+ cps2 = corpus.Corpus(location=new_corpus_desc)
+ self.assertTupleEqual(cps.module_specs, cps2.module_specs)
+
def test_invalid_args(self):
with self.assertRaises(
ValueError, msg='-cc1 flag not present in .cmd file'):
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 5d819c0..8837cbb 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -45,8 +45,10 @@
flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
'Root directory for writing logs/summaries/checkpoints.')
-flags.DEFINE_string('data_path', None,
- 'Path to directory containing the corpus.')
+flags.DEFINE_string(
+ 'data_path', None,
+ 'Path to directory containing the corpus, or specific corpus description '
+ 'json file.')
flags.DEFINE_integer(
'num_workers', None,
'Number of parallel data collection workers. `None` for max available')
@@ -103,7 +105,7 @@
logging.info('Loading module specs from corpus at %s.', FLAGS.data_path)
cps = corpus.Corpus(
- data_path=FLAGS.data_path,
+ location=FLAGS.data_path,
additional_flags=problem_config.flags_to_add(),
delete_flags=problem_config.flags_to_delete(),
replace_flags=problem_config.flags_to_replace())
diff --git a/compiler_opt/tools/generate_default_trace.py b/compiler_opt/tools/generate_default_trace.py
index 3124480..b37e8dd 100644
--- a/compiler_opt/tools/generate_default_trace.py
+++ b/compiler_opt/tools/generate_default_trace.py
@@ -37,8 +37,10 @@
# see https://bugs.python.org/issue33315 - we do need these types, but must
# currently use them as string annotations
-_DATA_PATH = flags.DEFINE_string('data_path', None,
- 'Path to folder containing IR files.')
+_DATA_PATH = flags.DEFINE_string(
+ 'data_path', None,
+ 'Path to directory containing IR files, or path to description json file '
+ 'under such a directory.')
_POLICY_PATH = flags.DEFINE_string(
'policy_path', '', 'Path to the policy to generate trace with.')
_OUTPUT_PATH = flags.DEFINE_string(
@@ -145,7 +147,7 @@
_MODULE_FILTER.value) if _MODULE_FILTER.value else None
cps = corpus.Corpus(
- data_path=_DATA_PATH.value,
+ location=_DATA_PATH.value,
module_filter=lambda name: True
if not module_filter else module_filter.match(name),
additional_flags=config.flags_to_add(),