Merge pull request #143 from EricRahm/add_builder
Add `ir_data_utils.builder`
diff --git a/compiler/back_end/cpp/header_generator_test.py b/compiler/back_end/cpp/header_generator_test.py
index daea7ea..d58f798 100644
--- a/compiler/back_end/cpp/header_generator_test.py
+++ b/compiler/back_end/cpp/header_generator_test.py
@@ -19,6 +19,7 @@
from compiler.front_end import glue
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import test_util
def _make_ir_from_emb(emb_text, name="m.emb"):
@@ -95,6 +96,7 @@
attr = ir.module[0].type[0].attribute[0]
bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
bad_case_source_location.CopyFrom(attr.value.source_location)
# Location of SHORTY_CASE in the attribute line.
bad_case_source_location.start.column = 30
@@ -114,6 +116,7 @@
attr = ir.module[0].type[0].attribute[0]
bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
bad_case_source_location.CopyFrom(attr.value.source_location)
# Location of bad_CASE in the attribute line.
bad_case_source_location.start.column = 43
@@ -133,6 +136,7 @@
attr = ir.module[0].type[0].attribute[0]
bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
bad_case_source_location.CopyFrom(attr.value.source_location)
# Location of BAD_case in the attribute line.
bad_case_source_location.start.column = 55
@@ -152,6 +156,7 @@
attr = ir.module[0].type[0].attribute[0]
bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
bad_case_source_location.CopyFrom(attr.value.source_location)
# Location of the second SHOUTY_CASE in the attribute line.
bad_case_source_location.start.column = 43
@@ -172,6 +177,7 @@
attr = ir.module[0].type[0].attribute[0]
bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
bad_case_source_location.CopyFrom(attr.value.source_location)
# Location of excess comma.
bad_case_source_location.start.column = 42
diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py
index e9b423b..cd6a851 100644
--- a/compiler/front_end/expression_bounds.py
+++ b/compiler/front_end/expression_bounds.py
@@ -19,6 +19,7 @@
import operator
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -65,6 +66,7 @@
def _compute_constant_value_of_constant_reference(expression, ir):
referred_object = ir_util.find_object(
expression.constant_reference.canonical_name, ir)
+ expression = ir_data_utils.builder(expression)
if isinstance(referred_object, ir_data.EnumValue):
compute_constraints_of_expression(referred_object.value, ir)
assert ir_util.is_constant(referred_object.value)
@@ -111,7 +113,7 @@
field_path = expression.function.args[0].field_reference.path[-1]
field = ir_util.find_object(field_path, ir)
compute_constraints_of_expression(field.existence_condition, ir)
- expression.type.CopyFrom(field.existence_condition.type)
+ ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type)
def _compute_constraints_of_field_reference(expression, ir):
@@ -122,7 +124,7 @@
# References to virtual fields should have the virtual field's constraints
# copied over.
compute_constraints_of_expression(field.read_transform, ir)
- expression.type.CopyFrom(field.read_transform.type)
+ ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
return
# Non-virtual non-integer fields do not (yet) have constraints.
if expression.type.WhichOneof("type") == "integer":
@@ -633,6 +635,7 @@
def _compute_constraints_of_choice_operator(expression):
"""Computes the constraints of a choice operation '?:'."""
condition, if_true, if_false = expression.function.args
+ expression = ir_data_utils.builder(expression)
if condition.type.boolean.HasField("value"):
# The generated expressions for $size_in_bits and $size_in_bytes look like
#
diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py
index ed27a1e..163bf6e 100644
--- a/compiler/front_end/module_ir.py
+++ b/compiler/front_end/module_ir.py
@@ -26,6 +26,7 @@
import sys
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import name_conversion
from compiler.util import parser_types
@@ -793,9 +794,9 @@
def _structure(struct, name, parameters, colon, comment, newline, struct_body):
"""Composes the top-level IR for an Emboss structure."""
del colon, comment, newline # Unused.
- struct_body.structure.source_location.start.CopyFrom(
+ ir_data_utils.builder(struct_body.structure).source_location.start.CopyFrom(
struct.source_location.start)
- struct_body.structure.source_location.end.CopyFrom(
+ ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom(
struct_body.source_location.end)
struct_body.name.CopyFrom(name)
if parameters.list:
@@ -879,11 +880,11 @@
' unconditional-anonymous-bits-field anonymous-bits-field-block')
def _unconditional_block_plus_field_block(field, block):
"""Prepends an unconditional field to block."""
- field.field.existence_condition.source_location.CopyFrom(
+ ir_data_utils.builder(field.field).existence_condition.source_location.CopyFrom(
field.source_location)
- field.field.existence_condition.boolean_constant.source_location.CopyFrom(
+ ir_data_utils.builder(field.field).existence_condition.boolean_constant.source_location.CopyFrom(
field.source_location)
- field.field.existence_condition.boolean_constant.value = True
+ ir_data_utils.builder(field.field).existence_condition.boolean_constant.value = True
return _List([field] + block.list)
@@ -929,7 +930,7 @@
"""Applies an existence_condition to each element of fields."""
del if_keyword, newline, colon, comment, indent, dedent # Unused.
for field in fields.list:
- condition = field.field.existence_condition
+ condition = ir_data_utils.builder(field.field).existence_condition
condition.CopyFrom(expression)
condition.source_location.is_disjoint_from_parent = True
return fields
@@ -967,11 +968,12 @@
newline, field_body):
"""Constructs an ir_data.Field from the given components."""
del comment # Unused
- field = ir_data.Field(location=location,
+ field_ir = ir_data.Field(location=location,
type=field_type,
name=name,
attribute=attributes.list,
documentation=doc.list)
+ field = ir_data_utils.builder(field_ir)
if field_body.list:
field.attribute.extend(field_body.list[0].attribute)
field.documentation.extend(field_body.list[0].documentation)
@@ -982,7 +984,7 @@
field.source_location.end.CopyFrom(field_body.source_location.end)
else:
field.source_location.end.CopyFrom(newline.source_location.end)
- return _FieldWithType(field=field)
+ return _FieldWithType(field=field_ir)
# A "virtual field" is:
@@ -996,7 +998,8 @@
def _virtual_field(let, name, equals, value, comment, newline, field_body):
"""Constructs an ir_data.Field from the given components."""
del equals, comment # Unused
- field = ir_data.Field(read_transform=value, name=name)
+ field_ir = ir_data.Field(read_transform=value, name=name)
+ field = ir_data_utils.builder(field_ir)
if field_body.list:
field.attribute.extend(field_body.list[0].attribute)
field.documentation.extend(field_body.list[0].documentation)
@@ -1005,7 +1008,7 @@
field.source_location.end.CopyFrom(field_body.source_location.end)
else:
field.source_location.end.CopyFrom(newline.source_location.end)
- return _FieldWithType(field=field)
+ return _FieldWithType(field=field_ir)
# An inline enum is:
@@ -1047,17 +1050,18 @@
def _inline_type_field(location, name, abbreviation, body):
"""Shared implementation of _inline_enum_field and _anonymous_bit_field."""
- field = ir_data.Field(location=location,
+ field_ir = ir_data.Field(location=location,
name=name,
attribute=body.attribute,
documentation=body.documentation)
+ field = ir_data_utils.builder(field_ir)
# All attributes should be attached to the field, not the type definition: if
# the user wants to use type attributes, they should create a separate type
# definition and reference it.
del body.attribute[:]
type_name = ir_data.NameDefinition()
type_name.CopyFrom(name)
- type_name.name.text = name_conversion.snake_to_camel(type_name.name.text)
+ ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text)
field.type.atomic_type.reference.source_name.extend([type_name.name])
field.type.atomic_type.reference.source_location.CopyFrom(
type_name.source_location)
@@ -1067,17 +1071,17 @@
if abbreviation.list:
field.abbreviation.CopyFrom(abbreviation.list[0])
field.source_location.start.CopyFrom(location.source_location.start)
- body.source_location.start.CopyFrom(location.source_location.start)
+ ir_data_utils.builder(body.source_location).start.CopyFrom(location.source_location.start)
if body.HasField('enumeration'):
- body.enumeration.source_location.CopyFrom(body.source_location)
+ ir_data_utils.builder(body.enumeration).source_location.CopyFrom(body.source_location)
else:
assert body.HasField('structure')
- body.structure.source_location.CopyFrom(body.source_location)
- body.name.CopyFrom(type_name)
+ ir_data_utils.builder(body.structure).source_location.CopyFrom(body.source_location)
+ ir_data_utils.builder(body).name.CopyFrom(type_name)
field.source_location.end.CopyFrom(body.source_location.end)
subtypes = [body] + list(body.subtype)
del body.subtype[:]
- return _FieldWithType(field=field, subtypes=subtypes)
+ return _FieldWithType(field=field_ir, subtypes=subtypes)
@_handles('anonymous-bits-field-definition ->'
@@ -1113,11 +1117,11 @@
@_handles('enum -> "enum" type-name ":" Comment? eol enum-body')
def _enum(enum, name, colon, comment, newline, enum_body):
del colon, comment, newline # Unused.
- enum_body.enumeration.source_location.start.CopyFrom(
+ ir_data_utils.builder(enum_body.enumeration).source_location.start.CopyFrom(
enum.source_location.start)
- enum_body.enumeration.source_location.end.CopyFrom(
+ ir_data_utils.builder(enum_body.enumeration).source_location.end.CopyFrom(
enum_body.source_location.end)
- enum_body.name.CopyFrom(name)
+ ir_data_utils.builder(enum_body).name.CopyFrom(name)
return enum_body
@@ -1161,7 +1165,7 @@
@_handles('external -> "external" type-name ":" Comment? eol external-body')
def _external(external, name, colon, comment, newline, external_body):
del colon, comment, newline # Unused.
- external_body.source_location.start.CopyFrom(external.source_location.start)
+ ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start)
external_body.name.CopyFrom(name)
return external_body
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 6f7c030..09675b8 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -22,6 +22,7 @@
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -133,7 +134,7 @@
"""Adds the given name_ir to scope and sets its canonical_name."""
name = name_ir.name.text
canonical_name = _nested_name(scope.canonical_name, name)
- name_ir.canonical_name.CopyFrom(canonical_name)
+ ir_data_utils.builder(name_ir).canonical_name.CopyFrom(canonical_name)
return _add_name_to_scope(name_ir.name, scope, canonical_name, visibility,
errors)
@@ -282,7 +283,7 @@
visible_scopes, source_file_name, errors)
if target is not None:
assert not target.alias
- reference.canonical_name.CopyFrom(target.canonical_name)
+ ir_data_utils.builder(reference).canonical_name.CopyFrom(target.canonical_name)
def _find_target_of_reference(reference, table, current_scope, visible_scopes,
@@ -419,7 +420,7 @@
member_name = ir_data.CanonicalName()
member_name.CopyFrom(
previous_field.type.atomic_type.reference.canonical_name)
- member_name.object_path.extend([ref.source_name[0].text])
+ ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text])
previous_field = ir_util.find_object_or_none(member_name, ir)
if previous_field is None:
errors.append(
@@ -427,7 +428,7 @@
ref.source_name[0].source_location,
ref.source_name[0].text))
return
- ref.canonical_name.CopyFrom(member_name)
+ ir_data_utils.builder(ref).canonical_name.CopyFrom(member_name)
previous_reference = ref
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index 7f6aabb..ac79c47 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -18,6 +18,7 @@
from compiler.util import error
from compiler.util import expression_parser
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -27,7 +28,7 @@
if not isinstance(proto, ir_data.Message):
return
if hasattr(proto, "source_location"):
- proto.source_location.is_synthetic = True
+ ir_data_utils.builder(proto).source_location.is_synthetic = True
for name, value in proto.raw_fields.items():
if name != "source_location":
if isinstance(value, ir_data.TypedScopedList):
@@ -112,7 +113,7 @@
)
new_existence_condition = ir_data.Expression()
new_existence_condition.CopyFrom(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
- existence_clauses = new_existence_condition.function.args
+ existence_clauses = ir_data_utils.builder(new_existence_condition).function.args
existence_clauses[0].function.args[0].field_reference.CopyFrom(
anonymous_field_reference)
existence_clauses[1].function.args[0].field_reference.CopyFrom(
@@ -129,7 +130,7 @@
existence_condition=new_existence_condition,
name=subfield.name)
if subfield.HasField("abbreviation"):
- new_alias.abbreviation.CopyFrom(subfield.abbreviation)
+ ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation)
_mark_as_synthetic(new_alias.existence_condition)
_mark_as_synthetic(new_alias.read_transform)
new_fields.append(new_alias)
@@ -196,6 +197,7 @@
continue
size_clause = ir_data.Expression()
size_clause.CopyFrom(_SIZE_CLAUSE_SKELETON)
+ size_clause = ir_data_utils.builder(size_clause)
# Copy the appropriate clauses into `existence_condition ? start + size : 0`
size_clause.function.args[0].CopyFrom(field.existence_condition)
size_clause.function.args[1].function.args[0].CopyFrom(field.location.start)
@@ -221,20 +223,21 @@
_NEXT_KEYWORD_REPLACEMENT_EXPRESSION = expression_parser.parse("x + y")
-def _maybe_replace_next_keyword_in_expression(expression, last_location,
+def _maybe_replace_next_keyword_in_expression(expression_ir, last_location,
source_file_name, errors):
- if not expression.HasField("builtin_reference"):
+ if not expression_ir.HasField("builtin_reference"):
return
- if expression.builtin_reference.canonical_name.object_path[0] != "$next":
+ if expression_ir.builtin_reference.canonical_name.object_path[0] != "$next":
return
if not last_location:
errors.append([
- error.error(source_file_name, expression.source_location,
+ error.error(source_file_name, expression_ir.source_location,
"`$next` may not be used in the first physical field of a " +
"structure; perhaps you meant `0`?")
])
return
- original_location = expression.source_location
+ original_location = expression_ir.source_location
+ expression = ir_data_utils.builder(expression_ir)
expression.CopyFrom(_NEXT_KEYWORD_REPLACEMENT_EXPRESSION)
expression.function.args[0].CopyFrom(last_location.start)
expression.function.args[1].CopyFrom(last_location.size)
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index 727989f..d8a226a 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -17,6 +17,7 @@
from compiler.front_end import attributes
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -44,11 +45,11 @@
def _annotate_as_integer(expression):
- expression.type.integer.CopyFrom(ir_data.IntegerType())
+ ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
def _annotate_as_boolean(expression):
- expression.type.boolean.CopyFrom(ir_data.BooleanType())
+ ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
def _type_check(expression, source_file_name, errors, type_oneof, type_name,
@@ -88,7 +89,7 @@
referred_name = expression.constant_reference.canonical_name
referred_object = ir_util.find_object(referred_name, ir)
if isinstance(referred_object, ir_data.EnumValue):
- expression.type.enumeration.name.CopyFrom(expression.constant_reference)
+ ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(expression.constant_reference)
del expression.type.enumeration.name.canonical_name.object_path[-1]
elif isinstance(referred_object, ir_data.Field):
if not ir_util.field_is_virtual(referred_object):
@@ -102,7 +103,7 @@
return
_type_check_expression(referred_object.read_transform,
referred_name.module_file, ir, errors)
- expression.type.CopyFrom(referred_object.read_transform.type)
+ ir_data_utils.builder(expression).type.CopyFrom(referred_object.read_transform.type)
else:
assert False, "Unexpected constant reference type."
@@ -189,10 +190,10 @@
if ir_util.field_is_virtual(field):
_type_check_expression(field.read_transform,
expression.field_reference.path[0], ir, errors)
- expression.type.CopyFrom(field.read_transform.type)
+ ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
return
if not field.type.HasField("atomic_type"):
- expression.type.opaque.CopyFrom(ir_data.OpaqueType())
+ ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType())
else:
_set_expression_type_from_physical_type_reference(
expression, field.type.atomic_type.reference, ir)
@@ -232,7 +233,7 @@
"""Sets the type of an expression to match a physical type."""
field_type = ir_util.find_object(type_reference, ir)
assert field_type, "Field type should be non-None after name resolution."
- expression.type.CopyFrom(
+ ir_data_utils.builder(expression).type.CopyFrom(
unbounded_expression_type_for_physical_type(field_type))
@@ -323,7 +324,7 @@
elif if_true.type.WhichOneof("type") == "boolean":
_annotate_as_boolean(expression)
elif if_true.type.WhichOneof("type") == "enumeration":
- expression.type.enumeration.name.CopyFrom(if_true.type.enumeration.name)
+ ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(if_true.type.enumeration.name)
else:
assert False, "Unexpected type for if_true."
diff --git a/compiler/front_end/write_inference.py b/compiler/front_end/write_inference.py
index ac58b34..db555bc 100644
--- a/compiler/front_end/write_inference.py
+++ b/compiler/front_end/write_inference.py
@@ -17,6 +17,7 @@
from compiler.front_end import attributes
from compiler.front_end import expression_bounds
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -216,9 +217,11 @@
if not ir_util.field_is_virtual(field):
# If the field is not virtual, writes are physical.
- field.write_method.physical = True
+ ir_data_utils.builder(field).write_method.physical = True
return
+ field_builder = ir_data_utils.builder(field)
+
# A virtual field cannot be a direct alias if it has an additional
# requirement.
requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
@@ -235,17 +238,17 @@
_add_write_method(referenced_field, ir)
reference_is_read_only = referenced_field.write_method.read_only
if not reference_is_read_only:
- field.write_method.transform.destination.CopyFrom(
+ field_builder.write_method.transform.destination.CopyFrom(
field_reference.field_reference)
- field.write_method.transform.function_body.CopyFrom(function_body)
+ field_builder.write_method.transform.function_body.CopyFrom(function_body)
else:
# If the virtual field's expression is invertible, but its target field
# is read-only, it is also read-only.
- field.write_method.read_only = True
+ field_builder.write_method.read_only = True
else:
# If the virtual field's expression is not invertible, it is
# read-only.
- field.write_method.read_only = True
+ field_builder.write_method.read_only = True
return
referenced_field = ir_util.find_object(
@@ -253,17 +256,17 @@
if not isinstance(referenced_field, ir_data.Field):
# If the virtual field aliases a non-field (i.e., a parameter), it is
# read-only.
- field.write_method.read_only = True
+ field_builder.write_method.read_only = True
return
_add_write_method(referenced_field, ir)
if referenced_field.write_method.read_only:
# If the virtual field directly aliases a read-only field, it is read-only.
- field.write_method.read_only = True
+ field_builder.write_method.read_only = True
return
# Otherwise, it can be written as a direct alias.
- field.write_method.alias.CopyFrom(
+ field_builder.write_method.alias.CopyFrom(
field.read_transform.field_reference)
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index 63b55b7..ed158b9 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -29,3 +29,11 @@
def from_json(data_cls: type[ir_data.Message], data):
"""Constructs an IR data class from the given JSON string"""
return data_cls.from_json(data)
+
+
+def builder(target: ir_data.Message) -> ir_data.Message:
+ """Returns `target` itself, marking call sites that mutate IR data in place.
+
+ No-op for now. NOTE(review): `-> ir_data.Message` hides the concrete subclass.
+ """
+ return target
diff --git a/compiler/util/ir_util_test.py b/compiler/util/ir_util_test.py
index b92ffb9..8e0b37a 100644
--- a/compiler/util/ir_util_test.py
+++ b/compiler/util/ir_util_test.py
@@ -32,14 +32,14 @@
self.assertTrue(ir_util.is_constant(_parse_expression("6")))
expression = _parse_expression("12")
# The type information should be ignored for constants like this one.
- expression.type.integer.CopyFrom(ir_data.IntegerType())
+ ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
self.assertTrue(ir_util.is_constant(expression))
def test_is_constant_boolean(self):
self.assertTrue(ir_util.is_constant(_parse_expression("true")))
expression = _parse_expression("true")
# The type information should be ignored for constants like this one.
- expression.type.boolean.CopyFrom(ir_data.BooleanType())
+ ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
self.assertTrue(ir_util.is_constant(expression))
def test_is_constant_enum(self):