Merge pull request #71 from reventlov/more_enum_comparisons

Allow all comparison operators for enums.
diff --git a/compiler/back_end/cpp/testcode/condition_test.cc b/compiler/back_end/cpp/testcode/condition_test.cc
index 588a6f7..1645f48 100644
--- a/compiler/back_end/cpp/testcode/condition_test.cc
+++ b/compiler/back_end/cpp/testcode/condition_test.cc
@@ -439,7 +439,10 @@
   EXPECT_TRUE(writer.Ok());
   ASSERT_TRUE(writer.SizeIsKnown());
   EXPECT_EQ(2U, writer.SizeInBytes());
+  EXPECT_TRUE(writer.has_xc().Value());
   EXPECT_EQ(0, writer.xc().Read());
+  EXPECT_TRUE(writer.has_xc2().Value());
+  EXPECT_EQ(0, writer.xc2().Read());
 }
 
 TEST(Conditional, FalseEnumBasedCondition) {
@@ -449,6 +452,9 @@
   ASSERT_TRUE(writer.SizeIsKnown());
   EXPECT_EQ(1U, writer.SizeInBytes());
   EXPECT_FALSE(writer.xc().Ok());
+  EXPECT_FALSE(writer.has_xc().Value());
+  EXPECT_FALSE(writer.xc2().Ok());
+  EXPECT_FALSE(writer.has_xc2().Value());
 }
 
 TEST(Conditional, TrueEnumBasedNegativeCondition) {
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index 5369ca6..edc9d93 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -111,8 +111,10 @@
   for arg in expression.function.args:
     _type_check_expression(arg, source_file_name, ir, errors)
   function = expression.function.function
-  if function in (ir_pb2.Function.EQUALITY, ir_pb2.Function.INEQUALITY):
-    _type_check_equality_operator(expression, source_file_name, errors)
+  if function in (ir_pb2.Function.EQUALITY, ir_pb2.Function.INEQUALITY,
+                  ir_pb2.Function.LESS, ir_pb2.Function.LESS_OR_EQUAL,
+                  ir_pb2.Function.GREATER, ir_pb2.Function.GREATER_OR_EQUAL):
+    _type_check_comparison_operator(expression, source_file_name, errors)
   elif function == ir_pb2.Function.CHOICE:
     _type_check_choice_operator(expression, source_file_name, errors)
   else:
@@ -138,13 +140,6 @@
                                        "operator"),
       ir_pb2.Function.AND: (bool_result, bool_args, binary, 2, 2, "operator"),
       ir_pb2.Function.OR: (bool_result, bool_args, binary, 2, 2, "operator"),
-      ir_pb2.Function.LESS: (bool_result, int_args, binary, 2, 2, "operator"),
-      ir_pb2.Function.LESS_OR_EQUAL: (bool_result, int_args, binary, 2, 2,
-                                      "operator"),
-      ir_pb2.Function.GREATER: (bool_result, int_args, binary, 2, 2,
-                                "operator"),
-      ir_pb2.Function.GREATER_OR_EQUAL: (bool_result, int_args, binary, 2, 2,
-                                         "operator"),
       ir_pb2.Function.MAXIMUM: (int_result, int_args, n_ary, 1, None,
                                 "function"),
       ir_pb2.Function.PRESENCE: (bool_result, field_args, n_ary, 1, 1,
@@ -268,18 +263,28 @@
     assert False, "_types_are_compatible works with enums, integers, booleans."
 
 
-def _type_check_equality_operator(expression, source_file_name, errors):
-  """Checks the type of an equality operator (== or !=)."""
+def _type_check_comparison_operator(expression, source_file_name, errors):
+  """Checks the type of a comparison operator (==, !=, <, >, >=, <=)."""
+  # Applying less than or greater than to a boolean is likely a mistake, so
+  # only equality and inequality are allowed for booleans.
+  if expression.function.function in (ir_pb2.Function.EQUALITY,
+                                      ir_pb2.Function.INEQUALITY):
+    acceptable_types = ("integer", "boolean", "enumeration")
+    acceptable_types_for_humans = "an integer, boolean, or enum"
+  else:
+    acceptable_types = ("integer", "enumeration")
+    acceptable_types_for_humans = "an integer or enum"
   left = expression.function.args[0]
-  if left.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"):
-    errors.append([
-        error.error(source_file_name, left.source_location,
-                    "Left argument of operator '{}' must be an integer, "
-                    "boolean, or enum.".format(
-                        expression.function.function_name.text))
-    ])
-    return
   right = expression.function.args[1]
+  for (argument, name) in ((left, "Left"), (right, "Right")):
+    if argument.type.WhichOneof("type") not in acceptable_types:
+      errors.append([
+          error.error(source_file_name, argument.source_location,
+                      "{} argument of operator '{}' must be {}.".format(
+                          name, expression.function.function_name.text,
+                          acceptable_types_for_humans))
+      ])
+      return
   if not _types_are_compatible(left, right):
     errors.append([
         error.error(source_file_name, expression.source_location,
diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py
index 6faaa25..c5f174f 100644
--- a/compiler/front_end/type_check_test.py
+++ b/compiler/front_end/type_check_test.py
@@ -99,6 +99,20 @@
     self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
                      "enumeration")
 
+  def test_adds_enum_comparison_operation_type(self):
+    ir = self._make_ir("struct Foo:\n"
+                       "  0 [+1]                   UInt      x\n"
+                       "  1 [+Enum.VAL>=Enum.VAL]  UInt:8[]  y\n"
+                       "enum Enum:\n"
+                       "  VAL = 1\n")
+    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+    expression = ir.module[0].type[0].structure.field[1].location.size
+    self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+    self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
+                     "enumeration")
+    self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
+                     "enumeration")
+
   def test_adds_integer_field_type(self):
     ir = self._make_ir("struct Foo:\n"
                        "  0 [+1]     UInt      x\n"
@@ -172,9 +186,9 @@
                        "  1 [+1==x]    UInt:8[]  y\n")
     expression = ir.module[0].type[0].structure.field[1].location.size
     self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Both arguments of operator '==' must have the same "
-                     "type.")]
+        [error.error("m.emb", expression.function.args[1].source_location,
+                     "Right argument of operator '==' must be an integer, "
+                     "boolean, or enum.")]
     ], error.filter_errors(type_check.annotate_types(ir)))
 
   def test_error_on_equality_mismatched_operands_int_bool(self):
@@ -188,14 +202,27 @@
                      "type.")]
     ], error.filter_errors(type_check.annotate_types(ir)))
 
-  def test_error_on_equality_mismatched_operands_bool_int(self):
+  def test_error_on_mismatched_comparison_operands(self):
     ir = self._make_ir("struct Foo:\n"
-                       "  0 [+1]         UInt      x\n"
-                       "  1 [+true==1]   UInt:8[]  y\n")
+                       "  0 [+1]           UInt:8    x\n"
+                       "  1 [+x>=Bar.BAR]  UInt:8[]  y\n"
+                       "enum Bar:\n"
+                       "  BAR = 1\n")
     expression = ir.module[0].type[0].structure.field[1].location.size
     self.assertEqual([
         [error.error("m.emb", expression.source_location,
-                     "Both arguments of operator '==' must have the same "
+                     "Both arguments of operator '>=' must have the same "
+                     "type.")]
+    ], error.filter_errors(type_check.annotate_types(ir)))
+
+  def test_error_on_equality_mismatched_operands_bool_int(self):
+    ir = self._make_ir("struct Foo:\n"
+                       "  0 [+1]         UInt      x\n"
+                       "  1 [+true!=1]   UInt:8[]  y\n")
+    expression = ir.module[0].type[0].structure.field[1].location.size
+    self.assertEqual([
+        [error.error("m.emb", expression.source_location,
+                     "Both arguments of operator '!=' must have the same "
                      "type.")]
     ], error.filter_errors(type_check.annotate_types(ir)))
 
@@ -322,18 +349,20 @@
                        "  1 [+true<1]  UInt:8[]  y\n")
     expression = ir.module[0].type[0].structure.field[1].location.size
     self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Left argument of operator '<' must be an integer.")]
+        [error.error(
+            "m.emb", expression.function.args[0].source_location,
+            "Left argument of operator '<' must be an integer or enum.")]
     ], error.filter_errors(type_check.annotate_types(ir)))
 
   def test_error_on_bad_right_comparison_operand_type(self):
     ir = self._make_ir("struct Foo:\n"
-                       "  0 [+1]       UInt      x\n"
+                       "  0 [+1]        UInt      x\n"
                        "  1 [+1>=true]  UInt:8[]  y\n")
     expression = ir.module[0].type[0].structure.field[1].location.size
     self.assertEqual([
-        [error.error("m.emb", expression.function.args[1].source_location,
-                     "Right argument of operator '>=' must be an integer.")]
+        [error.error(
+            "m.emb", expression.function.args[1].source_location,
+            "Right argument of operator '>=' must be an integer or enum.")]
     ], error.filter_errors(type_check.annotate_types(ir)))
 
   def test_error_on_bad_boolean_operand_type(self):
diff --git a/testdata/condition.emb b/testdata/condition.emb
index acfa304..4d9f2a5 100644
--- a/testdata/condition.emb
+++ b/testdata/condition.emb
@@ -112,6 +112,8 @@
   0 [+1]    OnOff  x
   if x == OnOff.ON:
     1 [+1]  UInt   xc
+  if x > OnOff.OFF:
+    1 [+1]  UInt   xc2
 
 
 struct NegativeEnumCondition: