| //! Fold impls for the next-trait-solver. |
| |
| use rustc_type_ir::{ |
| BoundVarIndexKind, DebruijnIndex, RegionKind, TypeFoldable, TypeFolder, TypeSuperFoldable, |
| TypeVisitableExt, inherent::IntoKind, |
| }; |
| |
| use crate::next_solver::{BoundConst, FxIndexMap}; |
| |
| use super::{ |
| Binder, BoundRegion, BoundTy, Const, ConstKind, DbInterner, Predicate, Region, Ty, TyKind, |
| }; |
| |
| /// A delegate used when instantiating bound vars. |
| /// |
| /// Any implementation must make sure that each bound variable always |
| /// gets mapped to the same result. `BoundVarReplacer` caches by using |
| /// a `DelayedMap` which does not cache the first few types it encounters. |
| pub trait BoundVarReplacerDelegate<'db> { |
| fn replace_region(&mut self, br: BoundRegion) -> Region<'db>; |
| fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>; |
| fn replace_const(&mut self, bv: BoundConst) -> Const<'db>; |
| } |
| |
| /// A simple delegate taking 3 mutable functions. The used functions must |
| /// always return the same result for each bound variable, no matter how |
| /// frequently they are called. |
| pub struct FnMutDelegate<'db, 'a> { |
| pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a), |
| pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a), |
| pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a), |
| } |
| |
| impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> { |
| fn replace_region(&mut self, br: BoundRegion) -> Region<'db> { |
| (self.regions)(br) |
| } |
| fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> { |
| (self.types)(bt) |
| } |
| fn replace_const(&mut self, bv: BoundConst) -> Const<'db> { |
| (self.consts)(bv) |
| } |
| } |
| |
| /// Replaces the escaping bound vars (late bound regions or bound types) in a type. |
| pub(crate) struct BoundVarReplacer<'db, D> { |
| interner: DbInterner<'db>, |
| /// As with `RegionFolder`, represents the index of a binder *just outside* |
| /// the ones we have visited. |
| current_index: DebruijnIndex, |
| |
| delegate: D, |
| } |
| |
| impl<'db, D: BoundVarReplacerDelegate<'db>> BoundVarReplacer<'db, D> { |
| pub(crate) fn new(tcx: DbInterner<'db>, delegate: D) -> Self { |
| BoundVarReplacer { interner: tcx, current_index: DebruijnIndex::ZERO, delegate } |
| } |
| } |
| |
| impl<'db, D> TypeFolder<DbInterner<'db>> for BoundVarReplacer<'db, D> |
| where |
| D: BoundVarReplacerDelegate<'db>, |
| { |
| fn cx(&self) -> DbInterner<'db> { |
| self.interner |
| } |
| |
| fn fold_binder<T: TypeFoldable<DbInterner<'db>>>( |
| &mut self, |
| t: Binder<'db, T>, |
| ) -> Binder<'db, T> { |
| self.current_index.shift_in(1); |
| let t = t.super_fold_with(self); |
| self.current_index.shift_out(1); |
| t |
| } |
| |
| fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> { |
| match t.kind() { |
| TyKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_ty) |
| if debruijn == self.current_index => |
| { |
| let ty = self.delegate.replace_ty(bound_ty); |
| debug_assert!(!ty.has_vars_bound_above(DebruijnIndex::ZERO)); |
| rustc_type_ir::shift_vars(self.interner, ty, self.current_index.as_u32()) |
| } |
| _ => { |
| if !t.has_vars_bound_at_or_above(self.current_index) { |
| t |
| } else { |
| t.super_fold_with(self) |
| } |
| } |
| } |
| } |
| |
| fn fold_region(&mut self, r: Region<'db>) -> Region<'db> { |
| match r.kind() { |
| RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn), br) |
| if debruijn == self.current_index => |
| { |
| let region = self.delegate.replace_region(br); |
| if let RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn1), br) = region.kind() |
| { |
| // If the callback returns a bound region, |
| // that region should always use the INNERMOST |
| // debruijn index. Then we adjust it to the |
| // correct depth. |
| assert_eq!(debruijn1, DebruijnIndex::ZERO); |
| Region::new_bound(self.interner, debruijn, br) |
| } else { |
| region |
| } |
| } |
| _ => r, |
| } |
| } |
| |
| fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> { |
| match ct.kind() { |
| ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_const) |
| if debruijn == self.current_index => |
| { |
| let ct = self.delegate.replace_const(bound_const); |
| debug_assert!(!ct.has_vars_bound_above(DebruijnIndex::ZERO)); |
| rustc_type_ir::shift_vars(self.interner, ct, self.current_index.as_u32()) |
| } |
| _ => ct.super_fold_with(self), |
| } |
| } |
| |
| fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> { |
| if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p } |
| } |
| } |
| |
| pub fn fold_tys<'db, T: TypeFoldable<DbInterner<'db>>>( |
| interner: DbInterner<'db>, |
| t: T, |
| callback: impl FnMut(Ty<'db>) -> Ty<'db>, |
| ) -> T { |
| struct Folder<'db, F> { |
| interner: DbInterner<'db>, |
| callback: F, |
| } |
| impl<'db, F: FnMut(Ty<'db>) -> Ty<'db>> TypeFolder<DbInterner<'db>> for Folder<'db, F> { |
| fn cx(&self) -> DbInterner<'db> { |
| self.interner |
| } |
| |
| fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> { |
| let t = t.super_fold_with(self); |
| (self.callback)(t) |
| } |
| } |
| |
| t.fold_with(&mut Folder { interner, callback }) |
| } |
| |
| impl<'db> DbInterner<'db> { |
| /// Replaces all regions bound by the given `Binder` with the |
| /// results returned by the closure; the closure is expected to |
| /// return a free region (relative to this binder), and hence the |
| /// binder is removed in the return type. The closure is invoked |
| /// once for each unique `BoundRegionKind`; multiple references to the |
| /// same `BoundRegionKind` will reuse the previous result. A map is |
| /// returned at the end with each bound region and the free region |
| /// that replaced it. |
| /// |
| /// # Panics |
| /// |
| /// This method only replaces late bound regions. Any types or |
| /// constants bound by `value` will cause an ICE. |
| pub fn instantiate_bound_regions<T, F>( |
| self, |
| value: Binder<'db, T>, |
| mut fld_r: F, |
| ) -> (T, FxIndexMap<BoundRegion, Region<'db>>) |
| where |
| F: FnMut(BoundRegion) -> Region<'db>, |
| T: TypeFoldable<DbInterner<'db>>, |
| { |
| let mut region_map = FxIndexMap::default(); |
| let real_fld_r = |br: BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br)); |
| let value = self.instantiate_bound_regions_uncached(value, real_fld_r); |
| (value, region_map) |
| } |
| |
| pub fn instantiate_bound_regions_uncached<T, F>( |
| self, |
| value: Binder<'db, T>, |
| mut replace_regions: F, |
| ) -> T |
| where |
| F: FnMut(BoundRegion) -> Region<'db>, |
| T: TypeFoldable<DbInterner<'db>>, |
| { |
| let value = value.skip_binder(); |
| if !value.has_escaping_bound_vars() { |
| value |
| } else { |
| let delegate = FnMutDelegate { |
| regions: &mut replace_regions, |
| types: &mut |b| panic!("unexpected bound ty in binder: {b:?}"), |
| consts: &mut |b| panic!("unexpected bound ct in binder: {b:?}"), |
| }; |
| let mut replacer = BoundVarReplacer::new(self, delegate); |
| value.fold_with(&mut replacer) |
| } |
| } |
| |
| /// Replaces any late-bound regions bound in `value` with `'erased`. Useful in codegen but also |
| /// method lookup and a few other places where precise region relationships are not required. |
| pub fn instantiate_bound_regions_with_erased<T>(self, value: Binder<'db, T>) -> T |
| where |
| T: TypeFoldable<DbInterner<'db>>, |
| { |
| self.instantiate_bound_regions(value, |_| Region::new_erased(self)).0 |
| } |
| } |