Corpus: let module_filter be a function
Taking a callable (str -> bool) instead of a compiled regex makes it easier
to filter arbitrary modules (e.g. from an allow/deny list).
diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py
index 4d65146..4cb83a8 100644
--- a/compiler_opt/rl/corpus.py
+++ b/compiler_opt/rl/corpus.py
@@ -17,11 +17,10 @@
import concurrent.futures
import math
import random
-import re
from absl import logging
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import json
import os
@@ -227,7 +226,7 @@
def __init__(self,
*,
data_path: str,
- module_filter: Optional[re.Pattern] = None,
+ module_filter: Optional[Callable[[str], bool]] = None,
additional_flags: Tuple[str, ...] = (),
delete_flags: Tuple[str, ...] = (),
replace_flags: Optional[Dict[str, str]] = None,
@@ -309,9 +308,7 @@
raise ValueError('do not use add/delete flags to replace')
if module_filter:
- module_paths = [
- name for name in module_paths if module_filter.match(name)
- ]
+ module_paths = [name for name in module_paths if module_filter(name)]
def get_cmdline(name: str):
if cmd_override_was_specified:
diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py
index d15f0f8..083d3de 100644
--- a/compiler_opt/rl/corpus_test.py
+++ b/compiler_opt/rl/corpus_test.py
@@ -230,7 +230,7 @@
corpus.ModuleSpec(name='largest', size=500),
corpus.ModuleSpec(name='small', size=100)
],
- module_filter=re.compile(r'.+l'))
+ module_filter=lambda name: re.compile(r'.+l').match(name))
sample = cps.sample(999, sort=True)
self.assertLen(sample, 3)
self.assertEqual(sample[0].name, 'middle')
diff --git a/compiler_opt/tools/generate_default_trace.py b/compiler_opt/tools/generate_default_trace.py
index 72fa1f8..3124480 100644
--- a/compiler_opt/tools/generate_default_trace.py
+++ b/compiler_opt/tools/generate_default_trace.py
@@ -146,7 +146,8 @@
cps = corpus.Corpus(
data_path=_DATA_PATH.value,
- module_filter=module_filter,
+ module_filter=lambda name: True
+ if not module_filter else module_filter.match(name),
additional_flags=config.flags_to_add(),
delete_flags=config.flags_to_delete(),
replace_flags=config.flags_to_replace())