Merge pull request #123 from EricRahm/remove_subproc

Call parsing and generation functions directly in `embossc`
diff --git a/.bazelrc b/.bazelrc
index 85399b9..f3470b4 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -1,3 +1,5 @@
+# Emboss doesn't (yet) support bzlmod.
+common --noenable_bzlmod
 # Emboss (at least notionally) supports C++11 until (at least) 2027.  However,
 # Googletest requires C++14, which means that all of the Emboss unit tests
 # require C++14 to run.
diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py
index 5f67d7c..b43ca0e 100644
--- a/compiler/front_end/constraints.py
+++ b/compiler/front_end/constraints.py
@@ -537,6 +537,8 @@
   # errors are just returned, rather than appended to a shared list.
   errors += _integer_bounds_errors_for_expression(expression, source_file_name)
 
+def _attribute_in_attribute_action(a):
+  return {"in_attribute": a}
 
 def check_constraints(ir):
   """Checks miscellaneous validity constraints in ir.
@@ -597,7 +599,7 @@
       parameters={"errors": errors})
   traverse_ir.fast_traverse_ir_top_down(
       ir, [ir_pb2.Expression], _check_bounds_on_runtime_integer_expressions,
-      incidental_actions={ir_pb2.Attribute: lambda a: {"in_attribute": a}},
+      incidental_actions={ir_pb2.Attribute: _attribute_in_attribute_action},
       skip_descendants_of={ir_pb2.EnumValue, ir_pb2.Expression},
       parameters={"errors": errors, "in_attribute": None})
   traverse_ir.fast_traverse_ir_top_down(
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 54bcdcc..f4fb581 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -468,6 +468,8 @@
       "visible_scopes": (field.name.canonical_name,) + visible_scopes,
   }
 
+def _module_source_from_table_action(m, table):
+  return {"module": table[m.source_file_name]}
 
 def _resolve_symbols_from_table(ir, table):
   """Resolves all references in the given IR, given the constructed table."""
@@ -477,7 +479,7 @@
   traverse_ir.fast_traverse_ir_top_down(
       ir, [ir_pb2.Import], _add_import_to_scope,
       incidental_actions={
-          ir_pb2.Module: lambda m, table: {"module": table[m.source_file_name]},
+          ir_pb2.Module: _module_source_from_table_action,
       },
       parameters={"errors": errors, "table": table})
   if errors:
@@ -490,7 +492,6 @@
       incidental_actions={
           ir_pb2.TypeDefinition: _set_visible_scopes_for_type_definition,
           ir_pb2.Module: _set_visible_scopes_for_module,
-          ir_pb2.Field: lambda f: {"field": f},
           ir_pb2.Attribute: _set_visible_scopes_for_attribute,
       },
       parameters={"table": table, "errors": errors, "field": None})
@@ -500,7 +501,6 @@
       incidental_actions={
           ir_pb2.TypeDefinition: _set_visible_scopes_for_type_definition,
           ir_pb2.Module: _set_visible_scopes_for_module,
-          ir_pb2.Field: lambda f: {"field": f},
           ir_pb2.Attribute: _set_visible_scopes_for_attribute,
       },
       parameters={"table": table, "errors": errors, "field": None})
@@ -515,7 +515,6 @@
       incidental_actions={
           ir_pb2.TypeDefinition: _set_visible_scopes_for_type_definition,
           ir_pb2.Module: _set_visible_scopes_for_module,
-          ir_pb2.Field: lambda f: {"field": f},
           ir_pb2.Attribute: _set_visible_scopes_for_attribute,
       },
       parameters={"errors": errors, "field": None})
diff --git a/compiler/util/traverse_ir.py b/compiler/util/traverse_ir.py
index fc04ba8..3bd95c3 100644
--- a/compiler/util/traverse_ir.py
+++ b/compiler/util/traverse_ir.py
@@ -17,27 +17,98 @@
 import inspect
 
 from compiler.util import ir_pb2
+from compiler.util import simple_memoizer
+
+
+class _FunctionCaller:
+  """Provides a template for setting up a generic call to a function.
+
+  The function parameters are inspected at run-time to build up a set of valid
+  and required arguments. When invoking the function, unnecessary parameters
+  will be trimmed out. If arguments are missing an assertion will be triggered.
+
+  This is currently limited to functions that have at least one positional
+  parameter.
+
+  Example usage:
+  ```
+  def func_1(a, b, c=2): pass
+  def func_2(a, d): pass
+  caller_1 = _FunctionCaller(func_1)
+  caller_2 = _FunctionCaller(func_2)
+  generic_params = {"b": 2, "c": 3, "d": 4}
+
+  # Equivalent of: func_1(a, b=2, c=3)
+  caller_1.invoke(a, generic_params)
+
+  # Equivalent of: func_2(a, d=4)
+  caller_2.invoke(a, generic_params)
+  """
+
+  def __init__(self, function):
+    self.function = function
+    self.needs_filtering = True
+    self.valid_arg_names = set()
+    self.required_arg_names = set()
+
+    argspec = inspect.getfullargspec(function)
+    if argspec.varkw:
+      # If the function accepts a kwargs parameter, then it will accept all
+      # arguments.
+      # Note: this isn't technically true if one of the keyword arguments has the
+      # same name as one of the positional arguments.
+      self.needs_filtering = False
+    else:
+      # argspec.args is a list of all parameter names excluding keyword only
+      # args. The first element is our required positional_arg and should be
+      # ignored.
+      args = argspec.args[1:]
+      self.valid_arg_names.update(args)
+
+      # argspec.kwonlyargs gives us the list of keyword only args which are
+      # also valid.
+      self.valid_arg_names.update(argspec.kwonlyargs)
+
+      # Required args are positional arguments that don't have defaults.
+      # Keyword only args are treated as optional and are ignored. Args with
+      # defaults are the last elements of the argspec.args list and should
+      # be ignored.
+      if argspec.defaults:
+        # Trim the arguments with defaults.
+        args = args[: -len(argspec.defaults)]
+      self.required_arg_names.update(args)
+
+  def invoke(self, positional_arg, keyword_args):
+    """Invokes the function with the given args."""
+    if self.needs_filtering:
+      # Trim to just recognized args.
+      matched_args = {
+          k: v for k, v in keyword_args.items() if k in self.valid_arg_names
+      }
+      # Check if any required args are missing.
+      missing_args = self.required_arg_names.difference(matched_args.keys())
+      assert not missing_args, (
+          f"Attempting to call '{self.function.__name__}'; "
+          f"missing {missing_args} (have {set(keyword_args.keys())})"
+      )
+      keyword_args = matched_args
+
+    return self.function(positional_arg, **keyword_args)
+
+
+@simple_memoizer.memoize
+def _memoized_caller(function):
+  default_lambda_name = (lambda: None).__name__
+  assert (
+      callable(function) and not function.__name__ == default_lambda_name
+  ), "For performance reasons actions must be defined as static functions"
+  return _FunctionCaller(function)
 
 
 def _call_with_optional_args(function, positional_arg, keyword_args):
   """Calls function with whatever keyword_args it will accept."""
-  argspec = inspect.getfullargspec(function)
-  if argspec.varkw:
-    # If the function accepts a kwargs parameter, then it will accept all
-    # arguments.
-    # Note: this isn't technically true if one of the keyword arguments has the
-    # same name as one of the positional arguments.
-    return function(positional_arg, **keyword_args)
-  else:
-    ok_arguments = {}
-    for name in keyword_args:
-      if name in argspec.args[1:] or name in argspec.kwonlyargs:
-        ok_arguments[name] = keyword_args[name]
-    for name in argspec.args[1:len(argspec.args) - len(argspec.defaults or [])]:
-      assert name in ok_arguments, (
-          "Attempting to call '{}'; missing '{}' (have '{!r}')".format(
-              function.__name__, name, list(keyword_args.keys())))
-    return function(positional_arg, **ok_arguments)
+  caller = _memoized_caller(function)
+  return caller.invoke(positional_arg, keyword_args)
 
 
 def _fast_traverse_proto_top_down(proto, incidental_actions, pattern,
@@ -181,6 +252,17 @@
 
 _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET = _fields_to_scan_by_current_and_target()
 
+def _emboss_ir_action(ir):
+  return {"ir": ir}
+
+def _module_action(m):
+  return {"source_file_name": m.source_file_name}
+
+def _type_definition_action(t):
+  return {"type_definition": t}
+
+def _field_action(f):
+  return {"field": f}
 
 def fast_traverse_ir_top_down(ir, pattern, action, incidental_actions=None,
                               skip_descendants_of=(), parameters=None):
@@ -269,10 +351,10 @@
     None
   """
   all_incidental_actions = {
-      ir_pb2.EmbossIr: [lambda ir: {"ir": ir}],
-      ir_pb2.Module: [lambda m: {"source_file_name": m.source_file_name}],
-      ir_pb2.TypeDefinition: [lambda t: {"type_definition": t}],
-      ir_pb2.Field: [lambda f: {"field": f}],
+      ir_pb2.EmbossIr: [_emboss_ir_action],
+      ir_pb2.Module: [_module_action],
+      ir_pb2.TypeDefinition: [_type_definition_action],
+      ir_pb2.Field: [_field_action],
   }
   if incidental_actions:
     for key, incidental_action in incidental_actions.items():
diff --git a/runtime/cpp/test/emboss_memory_util_test.cc b/runtime/cpp/test/emboss_memory_util_test.cc
index 3cfca8b..c4f9ed0 100644
--- a/runtime/cpp/test/emboss_memory_util_test.cc
+++ b/runtime/cpp/test/emboss_memory_util_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <string>
 #if __cplusplus >= 201703L
 #include <string_view>
 #endif  // __cplusplus >= 201703L
@@ -174,13 +175,13 @@
 // std::basic_string<> with non-default trailing template parameters.
 template <class T>
 struct NonstandardAllocator {
-  using value_type = typename ::std::allocator<T>::value_type;
-  using pointer = typename ::std::allocator<T>::pointer;
-  using const_pointer = typename ::std::allocator<T>::const_pointer;
-  using reference = typename ::std::allocator<T>::reference;
-  using const_reference = typename ::std::allocator<T>::const_reference;
-  using size_type = typename ::std::allocator<T>::size_type;
-  using difference_type = typename ::std::allocator<T>::difference_type;
+  using value_type = typename ::std::allocator_traits<::std::allocator<T>>::value_type;
+  using pointer = typename ::std::allocator_traits<::std::allocator<T>>::pointer;
+  using const_pointer = typename ::std::allocator_traits<::std::allocator<T>>::const_pointer;
+  using reference = typename ::std::allocator<T>::value_type &;
+  using const_reference = const typename ::std::allocator_traits<::std::allocator<T>>::value_type &;
+  using size_type = typename ::std::allocator_traits<::std::allocator<T>>::size_type;
+  using difference_type = typename ::std::allocator_traits<::std::allocator<T>>::difference_type;
 
   template <class U>
   struct rebind {
@@ -221,7 +222,7 @@
 typedef ::testing::Types<
     /**/ ::std::vector<char>, ::std::array<char, 8>,
     ::std::vector<unsigned char>, ::std::vector<signed char>, ::std::string,
-    ::std::basic_string<signed char>, ::std::basic_string<unsigned char>,
+    ::std::basic_string<char>,
     ::std::vector<unsigned char, NonstandardAllocator<unsigned char>>,
     ::std::basic_string<char, ::std::char_traits<char>,
                         NonstandardAllocator<char>>>