Auto merge of #63658 - RalfJung:miri-op, r=oli-obk

Refactor Miri ops (unary, binary) to have more types

This is the part of https://github.com/rust-lang/rust/pull/63448 that is just a refactoring. It helps that PR by making it easier to perform machine arithmetic.

r? @oli-obk @eddyb
diff --git a/src/librustc_mir/const_eval.rs b/src/librustc_mir/const_eval.rs
index 52225ea..76ee76a 100644
--- a/src/librustc_mir/const_eval.rs
+++ b/src/librustc_mir/const_eval.rs
@@ -11,9 +11,8 @@
 use rustc::hir::def_id::DefId;
 use rustc::mir::interpret::{ConstEvalErr, ErrorHandled, ScalarMaybeUndef};
 use rustc::mir;
-use rustc::ty::{self, TyCtxt};
+use rustc::ty::{self, Ty, TyCtxt, subst::Subst};
 use rustc::ty::layout::{self, LayoutOf, VariantIdx};
-use rustc::ty::subst::Subst;
 use rustc::traits::Reveal;
 use rustc_data_structures::fx::FxHashMap;
 
@@ -415,7 +414,7 @@
         _bin_op: mir::BinOp,
         _left: ImmTy<'tcx>,
         _right: ImmTy<'tcx>,
-    ) -> InterpResult<'tcx, (Scalar, bool)> {
+    ) -> InterpResult<'tcx, (Scalar, bool, Ty<'tcx>)> {
         Err(
             ConstEvalError::NeedsRfc("pointer arithmetic or comparison".to_string()).into(),
         )
diff --git a/src/librustc_mir/interpret/intrinsics.rs b/src/librustc_mir/interpret/intrinsics.rs
index ee105fe..4c86c53 100644
--- a/src/librustc_mir/interpret/intrinsics.rs
+++ b/src/librustc_mir/interpret/intrinsics.rs
@@ -137,7 +137,7 @@
                 let l = self.read_immediate(args[0])?;
                 let r = self.read_immediate(args[1])?;
                 let is_add = intrinsic_name == "saturating_add";
-                let (val, overflowed) = self.binary_op(if is_add {
+                let (val, overflowed, _ty) = self.overflowing_binary_op(if is_add {
                     BinOp::Add
                 } else {
                     BinOp::Sub
@@ -184,7 +184,7 @@
                     "unchecked_shr" => BinOp::Shr,
                     _ => bug!("Already checked for int ops")
                 };
-                let (val, overflowed) = self.binary_op(bin_op, l, r)?;
+                let (val, overflowed, _ty) = self.overflowing_binary_op(bin_op, l, r)?;
                 if overflowed {
                     let layout = self.layout_of(substs.type_at(0))?;
                     let r_val = r.to_scalar()?.to_bits(layout.size)?;
diff --git a/src/librustc_mir/interpret/machine.rs b/src/librustc_mir/interpret/machine.rs
index 33ffb1d..bb74a50 100644
--- a/src/librustc_mir/interpret/machine.rs
+++ b/src/librustc_mir/interpret/machine.rs
@@ -7,7 +7,7 @@
 
 use rustc::hir::def_id::DefId;
 use rustc::mir;
-use rustc::ty::{self, TyCtxt};
+use rustc::ty::{self, Ty, TyCtxt};
 
 use super::{
     Allocation, AllocId, InterpResult, Scalar, AllocationExtra,
@@ -176,7 +176,7 @@
         bin_op: mir::BinOp,
         left: ImmTy<'tcx, Self::PointerTag>,
         right: ImmTy<'tcx, Self::PointerTag>,
-    ) -> InterpResult<'tcx, (Scalar<Self::PointerTag>, bool)>;
+    ) -> InterpResult<'tcx, (Scalar<Self::PointerTag>, bool, Ty<'tcx>)>;
 
     /// Heap allocations via the `box` keyword.
     fn box_alloc(
diff --git a/src/librustc_mir/interpret/operand.rs b/src/librustc_mir/interpret/operand.rs
index 139a92c..726ae6f 100644
--- a/src/librustc_mir/interpret/operand.rs
+++ b/src/librustc_mir/interpret/operand.rs
@@ -108,7 +108,7 @@
 // as input for binary and cast operations.
 #[derive(Copy, Clone, Debug)]
 pub struct ImmTy<'tcx, Tag=()> {
-    pub imm: Immediate<Tag>,
+    pub(crate) imm: Immediate<Tag>,
     pub layout: TyLayout<'tcx>,
 }
 
@@ -155,7 +155,7 @@
 
 #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
 pub struct OpTy<'tcx, Tag=()> {
-    op: Operand<Tag>,
+    op: Operand<Tag>, // Keep this private; it helps enforce invariants (e.g., matching `layout`)
     pub layout: TyLayout<'tcx>,
 }
 
@@ -187,14 +187,23 @@
     }
 }
 
-impl<'tcx, Tag: Copy> ImmTy<'tcx, Tag>
-{
+impl<'tcx, Tag: Copy> ImmTy<'tcx, Tag> {
     #[inline]
     pub fn from_scalar(val: Scalar<Tag>, layout: TyLayout<'tcx>) -> Self {
         ImmTy { imm: val.into(), layout }
     }
 
     #[inline]
+    pub fn from_uint(i: impl Into<u128>, layout: TyLayout<'tcx>) -> Self {
+        Self::from_scalar(Scalar::from_uint(i, layout.size), layout)
+    }
+
+    #[inline]
+    pub fn from_int(i: impl Into<i128>, layout: TyLayout<'tcx>) -> Self {
+        Self::from_scalar(Scalar::from_int(i, layout.size), layout)
+    }
+
+    #[inline]
     pub fn to_bits(self) -> InterpResult<'tcx, u128> {
         self.to_scalar()?.to_bits(self.layout.size)
     }
diff --git a/src/librustc_mir/interpret/operator.rs b/src/librustc_mir/interpret/operator.rs
index e638ebc..470cc93 100644
--- a/src/librustc_mir/interpret/operator.rs
+++ b/src/librustc_mir/interpret/operator.rs
@@ -1,5 +1,5 @@
 use rustc::mir;
-use rustc::ty::{self, layout::TyLayout};
+use rustc::ty::{self, Ty, layout::{TyLayout, LayoutOf}};
 use syntax::ast::FloatTy;
 use rustc_apfloat::Float;
 use rustc::mir::interpret::{InterpResult, Scalar};
@@ -17,7 +17,12 @@
         right: ImmTy<'tcx, M::PointerTag>,
         dest: PlaceTy<'tcx, M::PointerTag>,
     ) -> InterpResult<'tcx> {
-        let (val, overflowed) = self.binary_op(op, left, right)?;
+        let (val, overflowed, ty) = self.overflowing_binary_op(op, left, right)?;
+        debug_assert_eq!(
+            self.tcx.intern_tup(&[ty, self.tcx.types.bool]),
+            dest.layout.ty,
+            "type mismatch for result of {:?}", op,
+        );
         let val = Immediate::ScalarPair(val.into(), Scalar::from_bool(overflowed).into());
         self.write_immediate(val, dest)
     }
@@ -31,7 +36,8 @@
         right: ImmTy<'tcx, M::PointerTag>,
         dest: PlaceTy<'tcx, M::PointerTag>,
     ) -> InterpResult<'tcx> {
-        let (val, _overflowed) = self.binary_op(op, left, right)?;
+        let (val, _overflowed, ty) = self.overflowing_binary_op(op, left, right)?;
+        assert_eq!(ty, dest.layout.ty, "type mismatch for result of {:?}", op);
         self.write_scalar(val, dest)
     }
 }
@@ -42,7 +48,7 @@
         bin_op: mir::BinOp,
         l: char,
         r: char,
-    ) -> (Scalar<M::PointerTag>, bool) {
+    ) -> (Scalar<M::PointerTag>, bool, Ty<'tcx>) {
         use rustc::mir::BinOp::*;
 
         let res = match bin_op {
@@ -54,7 +60,7 @@
             Ge => l >= r,
             _ => bug!("Invalid operation on char: {:?}", bin_op),
         };
-        return (Scalar::from_bool(res), false);
+        return (Scalar::from_bool(res), false, self.tcx.types.bool);
     }
 
     fn binary_bool_op(
@@ -62,7 +68,7 @@
         bin_op: mir::BinOp,
         l: bool,
         r: bool,
-    ) -> (Scalar<M::PointerTag>, bool) {
+    ) -> (Scalar<M::PointerTag>, bool, Ty<'tcx>) {
         use rustc::mir::BinOp::*;
 
         let res = match bin_op {
@@ -77,32 +83,33 @@
             BitXor => l ^ r,
             _ => bug!("Invalid operation on bool: {:?}", bin_op),
         };
-        return (Scalar::from_bool(res), false);
+        return (Scalar::from_bool(res), false, self.tcx.types.bool);
     }
 
     fn binary_float_op<F: Float + Into<Scalar<M::PointerTag>>>(
         &self,
         bin_op: mir::BinOp,
+        ty: Ty<'tcx>,
         l: F,
         r: F,
-    ) -> (Scalar<M::PointerTag>, bool) {
+    ) -> (Scalar<M::PointerTag>, bool, Ty<'tcx>) {
         use rustc::mir::BinOp::*;
 
-        let val = match bin_op {
-            Eq => Scalar::from_bool(l == r),
-            Ne => Scalar::from_bool(l != r),
-            Lt => Scalar::from_bool(l < r),
-            Le => Scalar::from_bool(l <= r),
-            Gt => Scalar::from_bool(l > r),
-            Ge => Scalar::from_bool(l >= r),
-            Add => (l + r).value.into(),
-            Sub => (l - r).value.into(),
-            Mul => (l * r).value.into(),
-            Div => (l / r).value.into(),
-            Rem => (l % r).value.into(),
+        let (val, ty) = match bin_op {
+            Eq => (Scalar::from_bool(l == r), self.tcx.types.bool),
+            Ne => (Scalar::from_bool(l != r), self.tcx.types.bool),
+            Lt => (Scalar::from_bool(l < r), self.tcx.types.bool),
+            Le => (Scalar::from_bool(l <= r), self.tcx.types.bool),
+            Gt => (Scalar::from_bool(l > r), self.tcx.types.bool),
+            Ge => (Scalar::from_bool(l >= r), self.tcx.types.bool),
+            Add => ((l + r).value.into(), ty),
+            Sub => ((l - r).value.into(), ty),
+            Mul => ((l * r).value.into(), ty),
+            Div => ((l / r).value.into(), ty),
+            Rem => ((l % r).value.into(), ty),
             _ => bug!("invalid float op: `{:?}`", bin_op),
         };
-        return (val, false);
+        return (val, false, ty);
     }
 
     fn binary_int_op(
@@ -113,7 +120,7 @@
         left_layout: TyLayout<'tcx>,
         r: u128,
         right_layout: TyLayout<'tcx>,
-    ) -> InterpResult<'tcx, (Scalar<M::PointerTag>, bool)> {
+    ) -> InterpResult<'tcx, (Scalar<M::PointerTag>, bool, Ty<'tcx>)> {
         use rustc::mir::BinOp::*;
 
         // Shift ops can have an RHS with a different numeric type.
@@ -142,7 +149,7 @@
                 }
             };
             let truncated = self.truncate(result, left_layout);
-            return Ok((Scalar::from_uint(truncated, size), oflo));
+            return Ok((Scalar::from_uint(truncated, size), oflo, left_layout.ty));
         }
 
         // For the remaining ops, the types must be the same on both sides
@@ -167,7 +174,7 @@
             if let Some(op) = op {
                 let l = self.sign_extend(l, left_layout) as i128;
                 let r = self.sign_extend(r, right_layout) as i128;
-                return Ok((Scalar::from_bool(op(&l, &r)), false));
+                return Ok((Scalar::from_bool(op(&l, &r)), false, self.tcx.types.bool));
             }
             let op: Option<fn(i128, i128) -> (i128, bool)> = match bin_op {
                 Div if r == 0 => throw_panic!(DivisionByZero),
@@ -187,7 +194,7 @@
                     Rem | Div => {
                         // int_min / -1
                         if r == -1 && l == (1 << (size.bits() - 1)) {
-                            return Ok((Scalar::from_uint(l, size), true));
+                            return Ok((Scalar::from_uint(l, size), true, left_layout.ty));
                         }
                     },
                     _ => {},
@@ -202,25 +209,24 @@
                 // this may be out-of-bounds for the result type, so we have to truncate ourselves
                 let result = result as u128;
                 let truncated = self.truncate(result, left_layout);
-                return Ok((Scalar::from_uint(truncated, size), oflo));
+                return Ok((Scalar::from_uint(truncated, size), oflo, left_layout.ty));
             }
         }
 
         let size = left_layout.size;
 
-        // only ints left
-        let val = match bin_op {
-            Eq => Scalar::from_bool(l == r),
-            Ne => Scalar::from_bool(l != r),
+        let (val, ty) = match bin_op {
+            Eq => (Scalar::from_bool(l == r), self.tcx.types.bool),
+            Ne => (Scalar::from_bool(l != r), self.tcx.types.bool),
 
-            Lt => Scalar::from_bool(l < r),
-            Le => Scalar::from_bool(l <= r),
-            Gt => Scalar::from_bool(l > r),
-            Ge => Scalar::from_bool(l >= r),
+            Lt => (Scalar::from_bool(l < r), self.tcx.types.bool),
+            Le => (Scalar::from_bool(l <= r), self.tcx.types.bool),
+            Gt => (Scalar::from_bool(l > r), self.tcx.types.bool),
+            Ge => (Scalar::from_bool(l >= r), self.tcx.types.bool),
 
-            BitOr => Scalar::from_uint(l | r, size),
-            BitAnd => Scalar::from_uint(l & r, size),
-            BitXor => Scalar::from_uint(l ^ r, size),
+            BitOr => (Scalar::from_uint(l | r, size), left_layout.ty),
+            BitAnd => (Scalar::from_uint(l & r, size), left_layout.ty),
+            BitXor => (Scalar::from_uint(l ^ r, size), left_layout.ty),
 
             Add | Sub | Mul | Rem | Div => {
                 debug_assert!(!left_layout.abi.is_signed());
@@ -236,7 +242,11 @@
                 };
                 let (result, oflo) = op(l, r);
                 let truncated = self.truncate(result, left_layout);
-                return Ok((Scalar::from_uint(truncated, size), oflo || truncated != result));
+                return Ok((
+                    Scalar::from_uint(truncated, size),
+                    oflo || truncated != result,
+                    left_layout.ty,
+                ));
             }
 
             _ => {
@@ -250,17 +260,17 @@
             }
         };
 
-        Ok((val, false))
+        Ok((val, false, ty))
     }
 
-    /// Returns the result of the specified operation and whether it overflowed.
-    #[inline]
-    pub fn binary_op(
+    /// Returns the result of the specified operation, whether it overflowed, and
+    /// the result type.
+    pub fn overflowing_binary_op(
         &self,
         bin_op: mir::BinOp,
         left: ImmTy<'tcx, M::PointerTag>,
         right: ImmTy<'tcx, M::PointerTag>,
-    ) -> InterpResult<'tcx, (Scalar<M::PointerTag>, bool)> {
+    ) -> InterpResult<'tcx, (Scalar<M::PointerTag>, bool, Ty<'tcx>)> {
         trace!("Running binary op {:?}: {:?} ({:?}), {:?} ({:?})",
             bin_op, *left, left.layout.ty, *right, right.layout.ty);
 
@@ -279,11 +289,14 @@
             }
             ty::Float(fty) => {
                 assert_eq!(left.layout.ty, right.layout.ty);
+                let ty = left.layout.ty;
                 let left = left.to_scalar()?;
                 let right = right.to_scalar()?;
                 Ok(match fty {
-                    FloatTy::F32 => self.binary_float_op(bin_op, left.to_f32()?, right.to_f32()?),
-                    FloatTy::F64 => self.binary_float_op(bin_op, left.to_f64()?, right.to_f64()?),
+                    FloatTy::F32 =>
+                        self.binary_float_op(bin_op, ty, left.to_f32()?, right.to_f32()?),
+                    FloatTy::F64 =>
+                        self.binary_float_op(bin_op, ty, left.to_f64()?, right.to_f64()?),
                 })
             }
             _ if left.layout.ty.is_integral() => {
@@ -312,11 +325,23 @@
         }
     }
 
+    /// Typed version of `overflowing_binary_op`, returning an `ImmTy`. Also ignores overflows.
+    #[inline]
+    pub fn binary_op(
+        &self,
+        bin_op: mir::BinOp,
+        left: ImmTy<'tcx, M::PointerTag>,
+        right: ImmTy<'tcx, M::PointerTag>,
+    ) -> InterpResult<'tcx, ImmTy<'tcx, M::PointerTag>> {
+        let (val, _overflow, ty) = self.overflowing_binary_op(bin_op, left, right)?;
+        Ok(ImmTy::from_scalar(val, self.layout_of(ty)?))
+    }
+
     pub fn unary_op(
         &self,
         un_op: mir::UnOp,
         val: ImmTy<'tcx, M::PointerTag>,
-    ) -> InterpResult<'tcx, Scalar<M::PointerTag>> {
+    ) -> InterpResult<'tcx, ImmTy<'tcx, M::PointerTag>> {
         use rustc::mir::UnOp::*;
 
         let layout = val.layout;
@@ -330,7 +355,7 @@
                     Not => !val,
                     _ => bug!("Invalid bool op {:?}", un_op)
                 };
-                Ok(Scalar::from_bool(res))
+                Ok(ImmTy::from_scalar(Scalar::from_bool(res), self.layout_of(self.tcx.types.bool)?))
             }
             ty::Float(fty) => {
                 let res = match (un_op, fty) {
@@ -338,7 +363,7 @@
                     (Neg, FloatTy::F64) => Scalar::from_f64(-val.to_f64()?),
                     _ => bug!("Invalid float op {:?}", un_op)
                 };
-                Ok(res)
+                Ok(ImmTy::from_scalar(res, layout))
             }
             _ => {
                 assert!(layout.ty.is_integral());
@@ -351,7 +376,7 @@
                     }
                 };
                 // res needs tuncating
-                Ok(Scalar::from_uint(self.truncate(res, layout), layout.size))
+                Ok(ImmTy::from_uint(self.truncate(res, layout), layout))
             }
         }
     }
diff --git a/src/librustc_mir/interpret/place.rs b/src/librustc_mir/interpret/place.rs
index 16686c3..ef9f20d 100644
--- a/src/librustc_mir/interpret/place.rs
+++ b/src/librustc_mir/interpret/place.rs
@@ -45,7 +45,7 @@
 
 #[derive(Copy, Clone, Debug)]
 pub struct PlaceTy<'tcx, Tag=()> {
-    place: Place<Tag>,
+    place: Place<Tag>, // Keep this private; it helps enforce invariants (e.g., matching `layout`)
     pub layout: TyLayout<'tcx>,
 }
 
diff --git a/src/librustc_mir/interpret/step.rs b/src/librustc_mir/interpret/step.rs
index e558278..b010bf0 100644
--- a/src/librustc_mir/interpret/step.rs
+++ b/src/librustc_mir/interpret/step.rs
@@ -177,7 +177,8 @@
                 // The operand always has the same type as the result.
                 let val = self.read_immediate(self.eval_operand(operand, Some(dest.layout))?)?;
                 let val = self.unary_op(un_op, val)?;
-                self.write_scalar(val, dest)?;
+                assert_eq!(val.layout, dest.layout, "layout mismatch for result of {:?}", un_op);
+                self.write_immediate(*val, dest)?;
             }
 
             Aggregate(ref kind, ref operands) => {
diff --git a/src/librustc_mir/interpret/terminator.rs b/src/librustc_mir/interpret/terminator.rs
index 1d6b48e..5de2979 100644
--- a/src/librustc_mir/interpret/terminator.rs
+++ b/src/librustc_mir/interpret/terminator.rs
@@ -7,7 +7,7 @@
 use rustc_target::spec::abi::Abi;
 
 use super::{
-    InterpResult, PointerArithmetic, Scalar,
+    InterpResult, PointerArithmetic,
     InterpCx, Machine, OpTy, ImmTy, PlaceTy, MPlaceTy, StackPopCleanup, FnVal,
 };
 
@@ -50,11 +50,10 @@
 
                 for (index, &const_int) in values.iter().enumerate() {
                     // Compare using binary_op, to also support pointer values
-                    let const_int = Scalar::from_uint(const_int, discr.layout.size);
-                    let (res, _) = self.binary_op(mir::BinOp::Eq,
+                    let res = self.overflowing_binary_op(mir::BinOp::Eq,
                         discr,
-                        ImmTy::from_scalar(const_int, discr.layout),
-                    )?;
+                        ImmTy::from_uint(const_int, discr.layout),
+                    )?.0;
                     if res.to_bool()? {
                         target_block = targets[index];
                         break;
diff --git a/src/librustc_mir/transform/const_prop.rs b/src/librustc_mir/transform/const_prop.rs
index c3c432d..98d8ca5 100644
--- a/src/librustc_mir/transform/const_prop.rs
+++ b/src/librustc_mir/transform/const_prop.rs
@@ -19,7 +19,7 @@
 use rustc::ty::subst::InternalSubsts;
 use rustc_data_structures::indexed_vec::IndexVec;
 use rustc::ty::layout::{
-    LayoutOf, TyLayout, LayoutError, HasTyCtxt, TargetDataLayout, HasDataLayout, Size,
+    LayoutOf, TyLayout, LayoutError, HasTyCtxt, TargetDataLayout, HasDataLayout,
 };
 
 use crate::interpret::{
@@ -396,17 +396,10 @@
                 if let ty::Slice(_) = mplace.layout.ty.sty {
                     let len = mplace.meta.unwrap().to_usize(&self.ecx).unwrap();
 
-                    Some(ImmTy {
-                        imm: Immediate::Scalar(
-                            Scalar::from_uint(
-                                len,
-                                Size::from_bits(
-                                    self.tcx.sess.target.usize_ty.bit_width().unwrap() as u64
-                                )
-                            ).into(),
-                        ),
-                        layout: self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).ok()?,
-                    }.into())
+                    Some(ImmTy::from_uint(
+                        len,
+                        self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).ok()?,
+                    ).into())
                 } else {
                     trace!("not slice: {:?}", mplace.layout.ty.sty);
                     None
@@ -414,12 +407,10 @@
             },
             Rvalue::NullaryOp(NullOp::SizeOf, ty) => {
                 type_size_of(self.tcx, self.param_env, ty).and_then(|n| Some(
-                    ImmTy {
-                        imm: Immediate::Scalar(
-                            Scalar::from_uint(n, self.tcx.data_layout.pointer_size).into()
-                        ),
-                        layout: self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).ok()?,
-                    }.into()
+                    ImmTy::from_uint(
+                        n,
+                        self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).ok()?,
+                    ).into()
                 ))
             }
             Rvalue::UnaryOp(op, ref arg) => {
@@ -452,11 +443,7 @@
                     // Now run the actual operation.
                     this.ecx.unary_op(op, prim)
                 })?;
-                let res = ImmTy {
-                    imm: Immediate::Scalar(val.into()),
-                    layout: place_layout,
-                };
-                Some(res.into())
+                Some(val.into())
             }
             Rvalue::CheckedBinaryOp(op, ref left, ref right) |
             Rvalue::BinaryOp(op, ref left, ref right) => {
@@ -510,8 +497,8 @@
                     this.ecx.read_immediate(left)
                 })?;
                 trace!("const evaluating {:?} for {:?} and {:?}", op, left, right);
-                let (val, overflow) = self.use_ecx(source_info, |this| {
-                    this.ecx.binary_op(op, l, r)
+                let (val, overflow, _ty) = self.use_ecx(source_info, |this| {
+                    this.ecx.overflowing_binary_op(op, l, r)
                 })?;
                 let val = if let Rvalue::CheckedBinaryOp(..) = *rvalue {
                     Immediate::ScalarPair(