Allow specifying the corpus description json file

I have sometimes found it useful to have variations of the corpus
description, e.g. with swaths of modules removed. This patch allows multiple
descriptions to co-exist under one corpus directory, and lets the user select
the desired one by passing its path in place of the corpus directory.

Renamed the `data_path` argument of `Corpus` to `location`: "data path"
suggests a directory rather than a file, whereas "location" covers both.
The command-line flags keep the `data_path` name.
diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py
index 37b7f0d..e3dcfd2 100644
--- a/compiler_opt/rl/corpus.py
+++ b/compiler_opt/rl/corpus.py
@@ -32,6 +32,8 @@
 # command line, where all the flags reference existing, local files.
 FullyQualifiedCmdLine = Tuple[str, ...]
 
+DEFAULT_CORPUS_DESCRIPTION_FILENAME = 'corpus_description.json'
+
 
 def _apply_cmdline_filters(
     orig_options: Tuple[str, ...],
@@ -185,7 +187,7 @@
 class Corpus:
   """Represents a corpus.
 
-  A corpus is created from a corpus_description.json file, produced by
+  A corpus is created from a corpus description json file, produced by
   extract_ir.py (for example).
 
   To use the corpus:
@@ -225,7 +227,7 @@
 
   def __init__(self,
                *,
-               data_path: str,
+               location: str,
                module_filter: Optional[Callable[[str], bool]] = None,
                additional_flags: Tuple[str, ...] = (),
                delete_flags: Tuple[str, ...] = (),
@@ -238,7 +240,10 @@
     output) and validated.
 
     Args:
-      data_path: corpus directory.
+      location: either a corpus directory, in which case
+        `DEFAULT_CORPUS_DESCRIPTION_FILENAME` in it is used as the corpus
+        description, or the path to a specific corpus description json file,
+        in which case its parent directory is used as the corpus directory.
       additional_flags: list of flags to append to the command line
       delete_flags: list of flags to remove (both `-flag=<value` and
         `-flag <value>` are supported).
@@ -251,18 +256,24 @@
       module_filter: a regular expression used to filter 'in' modules with names
         matching it. None to include everything.
     """
-    self._base_dir = data_path
+    if tf.io.gfile.isdir(location):
+      self._base_dir = location
+      description_path = os.path.join(location,
+                                      DEFAULT_CORPUS_DESCRIPTION_FILENAME)
+    else:
+      self._base_dir = os.path.dirname(location)
+      description_path = location
+
     self._sampler = sampler
     # TODO: (b/233935329) Per-corpus *fdo profile paths can be read into
     # {additional|delete}_flags here
-    with tf.io.gfile.GFile(
-        os.path.join(data_path, 'corpus_description.json'), 'r') as f:
+    with tf.io.gfile.GFile(description_path, 'r') as f:
       corpus_description: Dict[str, Any] = json.load(f)
 
     module_paths = corpus_description['modules']
     if len(module_paths) == 0:
       raise ValueError(
-          f'{data_path}\'s corpus_description contains no modules.')
+          f'Corpus description {description_path} contains no modules.')
 
     has_thinlto: bool = corpus_description['has_thinlto']
 
@@ -273,7 +284,7 @@
       if corpus_description[
           'global_command_override'] == constant.UNSPECIFIED_OVERRIDE:
         raise ValueError(
-            'global_command_override in corpus_description.json not filled.')
+            f'global_command_override in {description_path} not filled.')
       cmd_override = tuple(corpus_description['global_command_override'])
       if len(additional_flags) > 0:
         logging.warning(
@@ -314,7 +325,8 @@
       if cmd_override_was_specified:
         ret = cmd_override
       else:
-        with tf.io.gfile.GFile(os.path.join(data_path, name + '.cmd')) as f:
+        with tf.io.gfile.GFile(os.path.join(self._base_dir,
+                                            name + '.cmd')) as f:
           ret = tuple(f.read().replace(r'{', r'{{').replace(r'}',
                                                             r'}}').split('\0'))
           # The options read from a .cmd file must be run with -cc1
@@ -331,8 +343,8 @@
       contents = tp.map(
           lambda name: ModuleSpec(
               name=name,
-              size=tf.io.gfile.GFile(os.path.join(data_path, name + '.bc')).
-              size(),
+              size=tf.io.gfile.GFile(
+                  os.path.join(self._base_dir, name + '.bc')).size(),
               command_line=get_cmdline(name),
               has_thinlto=has_thinlto), module_paths)
     self._module_specs = tuple(
@@ -405,6 +417,6 @@
   if cmdline_is_override:
     corpus_description['global_command_override'] = cmdline
   with tf.io.gfile.GFile(
-      os.path.join(location, 'corpus_description.json'), 'w') as f:
+      os.path.join(location, DEFAULT_CORPUS_DESCRIPTION_FILENAME), 'w') as f:
     f.write(json.dumps(corpus_description))
-  return Corpus(data_path=location, **kwargs)
+  return Corpus(location=location, **kwargs)
diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py
index 9cabf4d..27ff674 100644
--- a/compiler_opt/rl/corpus_test.py
+++ b/compiler_opt/rl/corpus_test.py
@@ -105,6 +105,19 @@
         has_thinlto=False),))
     self.assertEqual(len(cps), 1)
 
+  def test_specific_path(self):
+    basedir = self.create_tempdir()
+    cps = corpus.create_corpus_for_testing(
+        location=basedir, elements=[corpus.ModuleSpec(name='1', size=1)])
+    new_corpus_desc = os.path.join(basedir, 'hi.json')
+    tf.io.gfile.rename(
+        os.path.join(basedir, corpus.DEFAULT_CORPUS_DESCRIPTION_FILENAME),
+        new_corpus_desc)
+    # Moving the default description away ensures cps2 can only have been
+    # loaded from the explicitly named file.
+    cps2 = corpus.Corpus(location=new_corpus_desc)
+    self.assertTupleEqual(cps.module_specs, cps2.module_specs)
+
   def test_invalid_args(self):
     with self.assertRaises(
         ValueError, msg='-cc1 flag not present in .cmd file'):
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 5d819c0..8837cbb 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -45,8 +45,10 @@
 
 flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                     'Root directory for writing logs/summaries/checkpoints.')
-flags.DEFINE_string('data_path', None,
-                    'Path to directory containing the corpus.')
+flags.DEFINE_string(
+    'data_path', None,
+    'Path to directory containing the corpus, or specific corpus description '
+    'json file.')
 flags.DEFINE_integer(
     'num_workers', None,
     'Number of parallel data collection workers. `None` for max available')
@@ -103,7 +105,7 @@
 
   logging.info('Loading module specs from corpus at %s.', FLAGS.data_path)
   cps = corpus.Corpus(
-      data_path=FLAGS.data_path,
+      location=FLAGS.data_path,
       additional_flags=problem_config.flags_to_add(),
       delete_flags=problem_config.flags_to_delete(),
       replace_flags=problem_config.flags_to_replace())
diff --git a/compiler_opt/tools/generate_default_trace.py b/compiler_opt/tools/generate_default_trace.py
index 3124480..b37e8dd 100644
--- a/compiler_opt/tools/generate_default_trace.py
+++ b/compiler_opt/tools/generate_default_trace.py
@@ -37,8 +37,10 @@
 # see https://bugs.python.org/issue33315 - we do need these types, but must
 # currently use them as string annotations
 
-_DATA_PATH = flags.DEFINE_string('data_path', None,
-                                 'Path to folder containing IR files.')
+_DATA_PATH = flags.DEFINE_string(
+    'data_path', None,
+    'Path to directory containing IR files, or path to description json file '
+    'under such a directory.')
 _POLICY_PATH = flags.DEFINE_string(
     'policy_path', '', 'Path to the policy to generate trace with.')
 _OUTPUT_PATH = flags.DEFINE_string(
@@ -145,7 +147,7 @@
       _MODULE_FILTER.value) if _MODULE_FILTER.value else None
 
   cps = corpus.Corpus(
-      data_path=_DATA_PATH.value,
+      location=_DATA_PATH.value,
       module_filter=lambda name: True
       if not module_filter else module_filter.match(name),
       additional_flags=config.flags_to_add(),