Corpus: let module_filter be a function

This makes it easier to filter arbitrary modules (e.g. from an
allow/deny list), rather than being limited to a regular expression.
diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py
index 4d65146..4cb83a8 100644
--- a/compiler_opt/rl/corpus.py
+++ b/compiler_opt/rl/corpus.py
@@ -17,11 +17,10 @@
 import concurrent.futures
 import math
 import random
-import re
 
 from absl import logging
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import json
 import os
@@ -227,7 +226,7 @@
   def __init__(self,
                *,
                data_path: str,
-               module_filter: Optional[re.Pattern] = None,
+               module_filter: Optional[Callable[[str], bool]] = None,
                additional_flags: Tuple[str, ...] = (),
                delete_flags: Tuple[str, ...] = (),
                replace_flags: Optional[Dict[str, str]] = None,
@@ -309,9 +308,7 @@
       raise ValueError('do not use add/delete flags to replace')
 
     if module_filter:
-      module_paths = [
-          name for name in module_paths if module_filter.match(name)
-      ]
+      module_paths = [name for name in module_paths if module_filter(name)]
 
     def get_cmdline(name: str):
       if cmd_override_was_specified:
diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py
index d15f0f8..083d3de 100644
--- a/compiler_opt/rl/corpus_test.py
+++ b/compiler_opt/rl/corpus_test.py
@@ -230,7 +230,7 @@
             corpus.ModuleSpec(name='largest', size=500),
             corpus.ModuleSpec(name='small', size=100)
         ],
-        module_filter=re.compile(r'.+l'))
+        module_filter=lambda name: re.compile(r'.+l').match(name))
     sample = cps.sample(999, sort=True)
     self.assertLen(sample, 3)
     self.assertEqual(sample[0].name, 'middle')
diff --git a/compiler_opt/tools/generate_default_trace.py b/compiler_opt/tools/generate_default_trace.py
index 72fa1f8..3124480 100644
--- a/compiler_opt/tools/generate_default_trace.py
+++ b/compiler_opt/tools/generate_default_trace.py
@@ -146,7 +146,8 @@
 
   cps = corpus.Corpus(
       data_path=_DATA_PATH.value,
-      module_filter=module_filter,
+      module_filter=lambda name: True
+      if not module_filter else module_filter.match(name),
       additional_flags=config.flags_to_add(),
       delete_flags=config.flags_to_delete(),
       replace_flags=config.flags_to_replace())