Merge pull request #145 from EricRahm/add_copy_and_update

Add `ir_data_utils.copy` and `ir_data_utils.update`, and use them in place of direct `CopyFrom` calls on IR data.
diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py
index 55c0a6e..11fcd17 100644
--- a/compiler/back_end/cpp/header_generator.py
+++ b/compiler/back_end/cpp/header_generator.py
@@ -1465,8 +1465,7 @@
   Offset should be a tuple of (start, end), which are the offsets relative to
   source_location.start.column to set the new start.column and end.column."""
 
-  new_location = ir_data.Location()
-  new_location.CopyFrom(source_location)
+  new_location = ir_data_utils.copy(source_location)
   new_location.start.column = source_location.start.column + offset[0]
   new_location.end.column = source_location.start.column + offset[1]
 
diff --git a/compiler/front_end/glue.py b/compiler/front_end/glue.py
index 7724da9..a1e1a5b 100644
--- a/compiler/front_end/glue.py
+++ b/compiler/front_end/glue.py
@@ -143,8 +143,7 @@
   # need to re-parse the prelude for every test .emb.
   if (source_code, file_name) in _cached_modules:
     debug_info = _cached_modules[source_code, file_name]
-    ir = ir_data.Module()
-    ir.CopyFrom(debug_info.ir)
+    ir = ir_data_utils.copy(debug_info.ir)
   else:
     debug_info = ModuleDebugInfo(file_name)
     debug_info.source_code = source_code
@@ -163,8 +162,7 @@
     ir = module_ir.build_ir(parse_result.parse_tree, used_productions)
     ir.source_text = source_code
     debug_info.used_productions = used_productions
-    debug_info.ir = ir_data.Module()
-    debug_info.ir.CopyFrom(ir)
+    debug_info.ir = ir_data_utils.copy(ir)
     _cached_modules[source_code, file_name] = debug_info
   ir.source_file_name = file_name
   return _IrDebugInfo(ir, debug_info, [])
diff --git a/compiler/front_end/glue_test.py b/compiler/front_end/glue_test.py
index 10613d7..2f2ddc5 100644
--- a/compiler/front_end/glue_test.py
+++ b/compiler/front_end/glue_test.py
@@ -141,8 +141,7 @@
     self.assertFalse(ir)
 
   def test_ir_from_parse_module(self):
-    log_file_path_ir = ir_data.Module()
-    log_file_path_ir.CopyFrom(_SPAN_SE_LOG_FILE_IR)
+    log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR)
     log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH
     self.assertEqual(log_file_path_ir, glue.parse_module(
         _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir)
diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py
index 163bf6e..c9ba765 100644
--- a/compiler/front_end/module_ir.py
+++ b/compiler/front_end/module_ir.py
@@ -150,7 +150,10 @@
     used_productions.add(parse_tree.production)
     result = _handlers[parse_tree.production](*parsed_children)
     if parse_tree.source_location is not None:
-      result.source_location.CopyFrom(parse_tree.source_location)
+      if result.source_location:
+        ir_data_utils.update(result.source_location, parse_tree.source_location)
+      else:
+        result.source_location = ir_data_utils.copy(parse_tree.source_location)
     return result
   else:
     # For leaf nodes, the temporary "IR" is just the token.  Higher-level rules
@@ -798,7 +801,10 @@
       struct.source_location.start)
   ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom(
       struct_body.source_location.end)
-  struct_body.name.CopyFrom(name)
+  if struct_body.name:
+    ir_data_utils.update(struct_body.name, name)
+  else:
+    struct_body.name = ir_data_utils.copy(name)
   if parameters.list:
     struct_body.runtime_parameter.extend(parameters.list[0].list)
   return struct_body
@@ -1059,8 +1065,7 @@
   # the user wants to use type attributes, they should create a separate type
   # definition and reference it.
   del body.attribute[:]
-  type_name = ir_data.NameDefinition()
-  type_name.CopyFrom(name)
+  type_name = ir_data_utils.copy(name)
   ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text)
   field.type.atomic_type.reference.source_name.extend([type_name.name])
   field.type.atomic_type.reference.source_location.CopyFrom(
@@ -1166,7 +1171,10 @@
 def _external(external, name, colon, comment, newline, external_body):
   del colon, comment, newline  # Unused.
   ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start)
-  external_body.name.CopyFrom(name)
+  if external_body.name:
+    ir_data_utils.update(external_body.name, name)
+  else:
+    external_body.name = ir_data_utils.copy(name)
   return external_body
 
 
@@ -1218,7 +1226,7 @@
       atomic_type_source_location_end)
   t = ir_data.Type(
       atomic_type=ir_data.AtomicType(
-          reference=reference,
+          reference=ir_data_utils.copy(reference),
           source_location=atomic_type_location,
           runtime_parameter=parameters.list[0].list if parameters.list else []),
       size_in_bits=size.list[0] if size.list else None,
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 09675b8..5990f13 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -417,8 +417,7 @@
                                previous_reference.source_name[0].text))
       return
     assert previous_field.type.WhichOneof("type") == "atomic_type"
-    member_name = ir_data.CanonicalName()
-    member_name.CopyFrom(
+    member_name = ir_data_utils.copy(
         previous_field.type.atomic_type.reference.canonical_name)
     ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text])
     previous_field = ir_util.find_object_or_none(member_name, ir)
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index d1bef9a..42b3cff 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -111,15 +111,14 @@
               ir_data.Reference(source_name=[subfield.name.name]),
           ]
       )
-      new_existence_condition = ir_data.Expression()
-      new_existence_condition.CopyFrom(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
+      new_existence_condition = ir_data_utils.copy(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
       existence_clauses = ir_data_utils.builder(new_existence_condition).function.args
       existence_clauses[0].function.args[0].field_reference.CopyFrom(
           anonymous_field_reference)
       existence_clauses[1].function.args[0].field_reference.CopyFrom(
           alias_field_reference)
       new_read_transform = ir_data.Expression(
-          field_reference=alias_field_reference)
+          field_reference=ir_data_utils.copy(alias_field_reference))
       # This treats *most* of the alias field as synthetic, but not its name(s):
       # leaving the name(s) as "real" means that symbol collisions with the
       # surrounding structure will be properly reported to the user.
@@ -128,7 +127,7 @@
       new_alias = ir_data.Field(
           read_transform=new_read_transform,
           existence_condition=new_existence_condition,
-          name=subfield.name)
+          name=ir_data_utils.copy(subfield.name))
       if subfield.HasField("abbreviation"):
         ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation)
       _mark_as_synthetic(new_alias.existence_condition)
@@ -195,16 +194,14 @@
     # to the size of the structure.
     if ir_util.field_is_virtual(field):
       continue
-    size_clause = ir_data.Expression()
-    size_clause.CopyFrom(_SIZE_CLAUSE_SKELETON)
-    size_clause = ir_data_utils.builder(size_clause)
+    size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON)
+    size_clause = ir_data_utils.builder(size_clause_ir)
     # Copy the appropriate clauses into `existence_condition ? start + size : 0`
     size_clause.function.args[0].CopyFrom(field.existence_condition)
     size_clause.function.args[1].function.args[0].CopyFrom(field.location.start)
     size_clause.function.args[1].function.args[1].CopyFrom(field.location.size)
-    size_clauses.append(size_clause)
-  size_expression = ir_data.Expression()
-  size_expression.CopyFrom(_SIZE_SKELETON)
+    size_clauses.append(size_clause_ir)
+  size_expression = ir_data_utils.copy(_SIZE_SKELETON)
   size_expression.function.args.extend(size_clauses)
   _mark_as_synthetic(size_expression)
   size_field = ir_data.Field(
diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py
index 063f38e..6e04280 100644
--- a/compiler/util/attribute_util.py
+++ b/compiler/util/attribute_util.py
@@ -305,8 +305,7 @@
   defaults = defaults.copy()
   for attr in obj.attribute:
     if attr.is_default:
-      defaulted_attr = ir_data.Attribute()
-      defaulted_attr.CopyFrom(attr)
+      defaulted_attr = ir_data_utils.copy(attr)
       defaulted_attr.is_default = False
       defaults[attr.name.text] = defaulted_attr
   return {"defaults": defaults}
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index 479d974..ac02bb0 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -45,3 +45,18 @@
   This is a no-op and just used for annotation for now.
   """
   return obj
+
+
+def copy(ir: ir_data.Message | None) -> ir_data.Message | None:
+  """Creates a copy of the given IR data class"""
+  if not ir:
+    return None
+  ir_class = type(ir)
+  ir_copy = ir_class()
+  update(ir_copy, ir)
+  return ir_copy
+
+
+def update(ir: ir_data.Message, template: ir_data.Message):
+  """Updates `ir`s fields with all set fields in the template."""
+  ir.CopyFrom(template)