Merge pull request #143 from EricRahm/add_builder

Add `ir_data_utils.builder`
diff --git a/compiler/back_end/cpp/header_generator_test.py b/compiler/back_end/cpp/header_generator_test.py
index daea7ea..d58f798 100644
--- a/compiler/back_end/cpp/header_generator_test.py
+++ b/compiler/back_end/cpp/header_generator_test.py
@@ -19,6 +19,7 @@
 from compiler.front_end import glue
 from compiler.util import error
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import test_util
 
 def _make_ir_from_emb(emb_text, name="m.emb"):
@@ -95,6 +96,7 @@
     attr = ir.module[0].type[0].attribute[0]
 
     bad_case_source_location = ir_data.Location()
+    bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
     bad_case_source_location.CopyFrom(attr.value.source_location)
     # Location of SHORTY_CASE in the attribute line.
     bad_case_source_location.start.column = 30
@@ -114,6 +116,7 @@
     attr = ir.module[0].type[0].attribute[0]
 
     bad_case_source_location = ir_data.Location()
+    bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
     bad_case_source_location.CopyFrom(attr.value.source_location)
     # Location of bad_CASE in the attribute line.
     bad_case_source_location.start.column = 43
@@ -133,6 +136,7 @@
     attr = ir.module[0].type[0].attribute[0]
 
     bad_case_source_location = ir_data.Location()
+    bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
     bad_case_source_location.CopyFrom(attr.value.source_location)
     # Location of BAD_case in the attribute line.
     bad_case_source_location.start.column = 55
@@ -152,6 +156,7 @@
     attr = ir.module[0].type[0].attribute[0]
 
     bad_case_source_location = ir_data.Location()
+    bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
     bad_case_source_location.CopyFrom(attr.value.source_location)
     # Location of the second SHOUTY_CASE in the attribute line.
     bad_case_source_location.start.column = 43
@@ -172,6 +177,7 @@
     attr = ir.module[0].type[0].attribute[0]
 
     bad_case_source_location = ir_data.Location()
+    bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
     bad_case_source_location.CopyFrom(attr.value.source_location)
     # Location of excess comma.
     bad_case_source_location.start.column = 42
diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py
index e9b423b..cd6a851 100644
--- a/compiler/front_end/expression_bounds.py
+++ b/compiler/front_end/expression_bounds.py
@@ -19,6 +19,7 @@
 import operator
 
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import ir_util
 from compiler.util import traverse_ir
 
@@ -65,6 +66,7 @@
 def _compute_constant_value_of_constant_reference(expression, ir):
   referred_object = ir_util.find_object(
       expression.constant_reference.canonical_name, ir)
+  expression = ir_data_utils.builder(expression)
   if isinstance(referred_object, ir_data.EnumValue):
     compute_constraints_of_expression(referred_object.value, ir)
     assert ir_util.is_constant(referred_object.value)
@@ -111,7 +113,7 @@
   field_path = expression.function.args[0].field_reference.path[-1]
   field = ir_util.find_object(field_path, ir)
   compute_constraints_of_expression(field.existence_condition, ir)
-  expression.type.CopyFrom(field.existence_condition.type)
+  ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type)
 
 
 def _compute_constraints_of_field_reference(expression, ir):
@@ -122,7 +124,7 @@
     # References to virtual fields should have the virtual field's constraints
     # copied over.
     compute_constraints_of_expression(field.read_transform, ir)
-    expression.type.CopyFrom(field.read_transform.type)
+    ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
     return
   # Non-virtual non-integer fields do not (yet) have constraints.
   if expression.type.WhichOneof("type") == "integer":
@@ -633,6 +635,7 @@
 def _compute_constraints_of_choice_operator(expression):
   """Computes the constraints of a choice operation '?:'."""
   condition, if_true, if_false = expression.function.args
+  expression = ir_data_utils.builder(expression)
   if condition.type.boolean.HasField("value"):
     # The generated expressions for $size_in_bits and $size_in_bytes look like
     #
diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py
index ed27a1e..163bf6e 100644
--- a/compiler/front_end/module_ir.py
+++ b/compiler/front_end/module_ir.py
@@ -26,6 +26,7 @@
 import sys
 
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import name_conversion
 from compiler.util import parser_types
 
@@ -793,9 +794,9 @@
 def _structure(struct, name, parameters, colon, comment, newline, struct_body):
   """Composes the top-level IR for an Emboss structure."""
   del colon, comment, newline  # Unused.
-  struct_body.structure.source_location.start.CopyFrom(
+  ir_data_utils.builder(struct_body.structure).source_location.start.CopyFrom(
       struct.source_location.start)
-  struct_body.structure.source_location.end.CopyFrom(
+  ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom(
       struct_body.source_location.end)
   struct_body.name.CopyFrom(name)
   if parameters.list:
@@ -879,11 +880,11 @@
           '    unconditional-anonymous-bits-field anonymous-bits-field-block')
 def _unconditional_block_plus_field_block(field, block):
   """Prepends an unconditional field to block."""
-  field.field.existence_condition.source_location.CopyFrom(
+  ir_data_utils.builder(field.field).existence_condition.source_location.CopyFrom(
       field.source_location)
-  field.field.existence_condition.boolean_constant.source_location.CopyFrom(
+  ir_data_utils.builder(field.field).existence_condition.boolean_constant.source_location.CopyFrom(
       field.source_location)
-  field.field.existence_condition.boolean_constant.value = True
+  ir_data_utils.builder(field.field).existence_condition.boolean_constant.value = True
   return _List([field] + block.list)
 
 
@@ -929,7 +930,7 @@
   """Applies an existence_condition to each element of fields."""
   del if_keyword, newline, colon, comment, indent, dedent  # Unused.
   for field in fields.list:
-    condition = field.field.existence_condition
+    condition = ir_data_utils.builder(field.field).existence_condition
     condition.CopyFrom(expression)
     condition.source_location.is_disjoint_from_parent = True
   return fields
@@ -967,11 +968,12 @@
            newline, field_body):
   """Constructs an ir_data.Field from the given components."""
   del comment  # Unused
-  field = ir_data.Field(location=location,
+  field_ir = ir_data.Field(location=location,
                        type=field_type,
                        name=name,
                        attribute=attributes.list,
                        documentation=doc.list)
+  field = ir_data_utils.builder(field_ir)
   if field_body.list:
     field.attribute.extend(field_body.list[0].attribute)
     field.documentation.extend(field_body.list[0].documentation)
@@ -982,7 +984,7 @@
     field.source_location.end.CopyFrom(field_body.source_location.end)
   else:
     field.source_location.end.CopyFrom(newline.source_location.end)
-  return _FieldWithType(field=field)
+  return _FieldWithType(field=field_ir)
 
 
 # A "virtual field" is:
@@ -996,7 +998,8 @@
 def _virtual_field(let, name, equals, value, comment, newline, field_body):
   """Constructs an ir_data.Field from the given components."""
   del equals, comment  # Unused
-  field = ir_data.Field(read_transform=value, name=name)
+  field_ir = ir_data.Field(read_transform=value, name=name)
+  field = ir_data_utils.builder(field_ir)
   if field_body.list:
     field.attribute.extend(field_body.list[0].attribute)
     field.documentation.extend(field_body.list[0].documentation)
@@ -1005,7 +1008,7 @@
     field.source_location.end.CopyFrom(field_body.source_location.end)
   else:
     field.source_location.end.CopyFrom(newline.source_location.end)
-  return _FieldWithType(field=field)
+  return _FieldWithType(field=field_ir)
 
 
 # An inline enum is:
@@ -1047,17 +1050,18 @@
 
 def _inline_type_field(location, name, abbreviation, body):
   """Shared implementation of _inline_enum_field and _anonymous_bit_field."""
-  field = ir_data.Field(location=location,
+  field_ir = ir_data.Field(location=location,
                        name=name,
                        attribute=body.attribute,
                        documentation=body.documentation)
+  field = ir_data_utils.builder(field_ir)
   # All attributes should be attached to the field, not the type definition: if
   # the user wants to use type attributes, they should create a separate type
   # definition and reference it.
   del body.attribute[:]
   type_name = ir_data.NameDefinition()
   type_name.CopyFrom(name)
-  type_name.name.text = name_conversion.snake_to_camel(type_name.name.text)
+  ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text)
   field.type.atomic_type.reference.source_name.extend([type_name.name])
   field.type.atomic_type.reference.source_location.CopyFrom(
       type_name.source_location)
@@ -1067,17 +1071,17 @@
   if abbreviation.list:
     field.abbreviation.CopyFrom(abbreviation.list[0])
   field.source_location.start.CopyFrom(location.source_location.start)
-  body.source_location.start.CopyFrom(location.source_location.start)
+  ir_data_utils.builder(body.source_location).start.CopyFrom(location.source_location.start)
   if body.HasField('enumeration'):
-    body.enumeration.source_location.CopyFrom(body.source_location)
+    ir_data_utils.builder(body.enumeration).source_location.CopyFrom(body.source_location)
   else:
     assert body.HasField('structure')
-    body.structure.source_location.CopyFrom(body.source_location)
-  body.name.CopyFrom(type_name)
+    ir_data_utils.builder(body.structure).source_location.CopyFrom(body.source_location)
+  ir_data_utils.builder(body).name.CopyFrom(type_name)
   field.source_location.end.CopyFrom(body.source_location.end)
   subtypes = [body] + list(body.subtype)
   del body.subtype[:]
-  return _FieldWithType(field=field, subtypes=subtypes)
+  return _FieldWithType(field=field_ir, subtypes=subtypes)
 
 
 @_handles('anonymous-bits-field-definition ->'
@@ -1113,11 +1117,11 @@
 @_handles('enum -> "enum" type-name ":" Comment? eol enum-body')
 def _enum(enum, name, colon, comment, newline, enum_body):
   del colon, comment, newline  # Unused.
-  enum_body.enumeration.source_location.start.CopyFrom(
+  ir_data_utils.builder(enum_body.enumeration).source_location.start.CopyFrom(
       enum.source_location.start)
-  enum_body.enumeration.source_location.end.CopyFrom(
+  ir_data_utils.builder(enum_body.enumeration).source_location.end.CopyFrom(
       enum_body.source_location.end)
-  enum_body.name.CopyFrom(name)
+  ir_data_utils.builder(enum_body).name.CopyFrom(name)
   return enum_body
 
 
@@ -1161,7 +1165,7 @@
 @_handles('external -> "external" type-name ":" Comment? eol external-body')
 def _external(external, name, colon, comment, newline, external_body):
   del colon, comment, newline  # Unused.
-  external_body.source_location.start.CopyFrom(external.source_location.start)
+  ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start)
   external_body.name.CopyFrom(name)
   return external_body
 
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 6f7c030..09675b8 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -22,6 +22,7 @@
 
 from compiler.util import error
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import ir_util
 from compiler.util import traverse_ir
 
@@ -133,7 +134,7 @@
   """Adds the given name_ir to scope and sets its canonical_name."""
   name = name_ir.name.text
   canonical_name = _nested_name(scope.canonical_name, name)
-  name_ir.canonical_name.CopyFrom(canonical_name)
+  ir_data_utils.builder(name_ir).canonical_name.CopyFrom(canonical_name)
   return _add_name_to_scope(name_ir.name, scope, canonical_name, visibility,
                             errors)
 
@@ -282,7 +283,7 @@
                                      visible_scopes, source_file_name, errors)
   if target is not None:
     assert not target.alias
-    reference.canonical_name.CopyFrom(target.canonical_name)
+    ir_data_utils.builder(reference).canonical_name.CopyFrom(target.canonical_name)
 
 
 def _find_target_of_reference(reference, table, current_scope, visible_scopes,
@@ -419,7 +420,7 @@
     member_name = ir_data.CanonicalName()
     member_name.CopyFrom(
         previous_field.type.atomic_type.reference.canonical_name)
-    member_name.object_path.extend([ref.source_name[0].text])
+    ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text])
     previous_field = ir_util.find_object_or_none(member_name, ir)
     if previous_field is None:
       errors.append(
@@ -427,7 +428,7 @@
                              ref.source_name[0].source_location,
                              ref.source_name[0].text))
       return
-    ref.canonical_name.CopyFrom(member_name)
+    ir_data_utils.builder(ref).canonical_name.CopyFrom(member_name)
     previous_reference = ref
 
 
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index 7f6aabb..ac79c47 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -18,6 +18,7 @@
 from compiler.util import error
 from compiler.util import expression_parser
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import ir_util
 from compiler.util import traverse_ir
 
@@ -27,7 +28,7 @@
   if not isinstance(proto, ir_data.Message):
     return
   if hasattr(proto, "source_location"):
-    proto.source_location.is_synthetic = True
+    ir_data_utils.builder(proto).source_location.is_synthetic = True
   for name, value in proto.raw_fields.items():
     if name != "source_location":
       if isinstance(value, ir_data.TypedScopedList):
@@ -112,7 +113,7 @@
       )
       new_existence_condition = ir_data.Expression()
       new_existence_condition.CopyFrom(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
-      existence_clauses = new_existence_condition.function.args
+      existence_clauses = ir_data_utils.builder(new_existence_condition).function.args
       existence_clauses[0].function.args[0].field_reference.CopyFrom(
           anonymous_field_reference)
       existence_clauses[1].function.args[0].field_reference.CopyFrom(
@@ -129,7 +130,7 @@
           existence_condition=new_existence_condition,
           name=subfield.name)
       if subfield.HasField("abbreviation"):
-        new_alias.abbreviation.CopyFrom(subfield.abbreviation)
+        ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation)
       _mark_as_synthetic(new_alias.existence_condition)
       _mark_as_synthetic(new_alias.read_transform)
       new_fields.append(new_alias)
@@ -196,6 +197,7 @@
       continue
     size_clause = ir_data.Expression()
     size_clause.CopyFrom(_SIZE_CLAUSE_SKELETON)
+    size_clause = ir_data_utils.builder(size_clause)
     # Copy the appropriate clauses into `existence_condition ? start + size : 0`
     size_clause.function.args[0].CopyFrom(field.existence_condition)
     size_clause.function.args[1].function.args[0].CopyFrom(field.location.start)
@@ -221,20 +223,21 @@
 _NEXT_KEYWORD_REPLACEMENT_EXPRESSION = expression_parser.parse("x + y")
 
 
-def _maybe_replace_next_keyword_in_expression(expression, last_location,
+def _maybe_replace_next_keyword_in_expression(expression_ir, last_location,
                                               source_file_name, errors):
-  if not expression.HasField("builtin_reference"):
+  if not expression_ir.HasField("builtin_reference"):
     return
-  if expression.builtin_reference.canonical_name.object_path[0] != "$next":
+  if expression_ir.builtin_reference.canonical_name.object_path[0] != "$next":
     return
   if not last_location:
     errors.append([
-      error.error(source_file_name, expression.source_location,
+      error.error(source_file_name, expression_ir.source_location,
                   "`$next` may not be used in the first physical field of a " +
                   "structure; perhaps you meant `0`?")
     ])
     return
-  original_location = expression.source_location
+  original_location = expression_ir.source_location
+  expression = ir_data_utils.builder(expression_ir)
   expression.CopyFrom(_NEXT_KEYWORD_REPLACEMENT_EXPRESSION)
   expression.function.args[0].CopyFrom(last_location.start)
   expression.function.args[1].CopyFrom(last_location.size)
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index 727989f..d8a226a 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -17,6 +17,7 @@
 from compiler.front_end import attributes
 from compiler.util import error
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import ir_util
 from compiler.util import traverse_ir
 
@@ -44,11 +45,11 @@
 
 
 def _annotate_as_integer(expression):
-  expression.type.integer.CopyFrom(ir_data.IntegerType())
+  ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
 
 
 def _annotate_as_boolean(expression):
-  expression.type.boolean.CopyFrom(ir_data.BooleanType())
+  ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
 
 
 def _type_check(expression, source_file_name, errors, type_oneof, type_name,
@@ -88,7 +89,7 @@
   referred_name = expression.constant_reference.canonical_name
   referred_object = ir_util.find_object(referred_name, ir)
   if isinstance(referred_object, ir_data.EnumValue):
-    expression.type.enumeration.name.CopyFrom(expression.constant_reference)
+    ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(expression.constant_reference)
     del expression.type.enumeration.name.canonical_name.object_path[-1]
   elif isinstance(referred_object, ir_data.Field):
     if not ir_util.field_is_virtual(referred_object):
@@ -102,7 +103,7 @@
       return
     _type_check_expression(referred_object.read_transform,
                            referred_name.module_file, ir, errors)
-    expression.type.CopyFrom(referred_object.read_transform.type)
+    ir_data_utils.builder(expression).type.CopyFrom(referred_object.read_transform.type)
   else:
     assert False, "Unexpected constant reference type."
 
@@ -189,10 +190,10 @@
   if ir_util.field_is_virtual(field):
     _type_check_expression(field.read_transform,
                            expression.field_reference.path[0], ir, errors)
-    expression.type.CopyFrom(field.read_transform.type)
+    ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
     return
   if not field.type.HasField("atomic_type"):
-    expression.type.opaque.CopyFrom(ir_data.OpaqueType())
+    ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType())
   else:
     _set_expression_type_from_physical_type_reference(
         expression, field.type.atomic_type.reference, ir)
@@ -232,7 +233,7 @@
   """Sets the type of an expression to match a physical type."""
   field_type = ir_util.find_object(type_reference, ir)
   assert field_type, "Field type should be non-None after name resolution."
-  expression.type.CopyFrom(
+  ir_data_utils.builder(expression).type.CopyFrom(
       unbounded_expression_type_for_physical_type(field_type))
 
 
@@ -323,7 +324,7 @@
   elif if_true.type.WhichOneof("type") == "boolean":
     _annotate_as_boolean(expression)
   elif if_true.type.WhichOneof("type") == "enumeration":
-    expression.type.enumeration.name.CopyFrom(if_true.type.enumeration.name)
+    ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(if_true.type.enumeration.name)
   else:
     assert False, "Unexpected type for if_true."
 
diff --git a/compiler/front_end/write_inference.py b/compiler/front_end/write_inference.py
index ac58b34..db555bc 100644
--- a/compiler/front_end/write_inference.py
+++ b/compiler/front_end/write_inference.py
@@ -17,6 +17,7 @@
 from compiler.front_end import attributes
 from compiler.front_end import expression_bounds
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import ir_util
 from compiler.util import traverse_ir
 
@@ -216,9 +217,11 @@
 
   if not ir_util.field_is_virtual(field):
     # If the field is not virtual, writes are physical.
-    field.write_method.physical = True
+    ir_data_utils.builder(field).write_method.physical = True
     return
 
+  field_builder = ir_data_utils.builder(field)
+
   # A virtual field cannot be a direct alias if it has an additional
   # requirement.
   requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
@@ -235,17 +238,17 @@
         _add_write_method(referenced_field, ir)
         reference_is_read_only = referenced_field.write_method.read_only
       if not reference_is_read_only:
-        field.write_method.transform.destination.CopyFrom(
+        field_builder.write_method.transform.destination.CopyFrom(
             field_reference.field_reference)
-        field.write_method.transform.function_body.CopyFrom(function_body)
+        field_builder.write_method.transform.function_body.CopyFrom(function_body)
       else:
         # If the virtual field's expression is invertible, but its target field
         # is read-only, it is also read-only.
-        field.write_method.read_only = True
+        field_builder.write_method.read_only = True
     else:
       # If the virtual field's expression is not invertible, it is
       # read-only.
-      field.write_method.read_only = True
+      field_builder.write_method.read_only = True
     return
 
   referenced_field = ir_util.find_object(
@@ -253,17 +256,17 @@
   if not isinstance(referenced_field, ir_data.Field):
     # If the virtual field aliases a non-field (i.e., a parameter), it is
     # read-only.
-    field.write_method.read_only = True
+    field_builder.write_method.read_only = True
     return
 
   _add_write_method(referenced_field, ir)
   if referenced_field.write_method.read_only:
     # If the virtual field directly aliases a read-only field, it is read-only.
-    field.write_method.read_only = True
+    field_builder.write_method.read_only = True
     return
 
   # Otherwise, it can be written as a direct alias.
-  field.write_method.alias.CopyFrom(
+  field_builder.write_method.alias.CopyFrom(
       field.read_transform.field_reference)
 
 
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index 63b55b7..ed158b9 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -29,3 +29,11 @@
   def from_json(data_cls: type[ir_data.Message], data):
     """Constructs an IR data class from the given JSON string"""
     return data_cls.from_json(data)
+
+
+def builder(target: ir_data.Message) -> ir_data.Message:
+  """Provides a wrapper for building up IR data classes.
+
+  For now this is a no-op: `target` is returned unchanged, purely as annotation.
+  """
+  return target
diff --git a/compiler/util/ir_util_test.py b/compiler/util/ir_util_test.py
index b92ffb9..8e0b37a 100644
--- a/compiler/util/ir_util_test.py
+++ b/compiler/util/ir_util_test.py
@@ -32,14 +32,14 @@
     self.assertTrue(ir_util.is_constant(_parse_expression("6")))
     expression = _parse_expression("12")
     # The type information should be ignored for constants like this one.
-    expression.type.integer.CopyFrom(ir_data.IntegerType())
+    ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
     self.assertTrue(ir_util.is_constant(expression))
 
   def test_is_constant_boolean(self):
     self.assertTrue(ir_util.is_constant(_parse_expression("true")))
     expression = _parse_expression("true")
     # The type information should be ignored for constants like this one.
-    expression.type.boolean.CopyFrom(ir_data.BooleanType())
+    ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
     self.assertTrue(ir_util.is_constant(expression))
 
   def test_is_constant_enum(self):