blob: 7836419e8b7516a719c15be6ad9e74ce622a563e [file] [log] [blame]
//! Fold impls for the next-trait-solver.
use rustc_type_ir::{
BoundVarIndexKind, DebruijnIndex, RegionKind, TypeFoldable, TypeFolder, TypeSuperFoldable,
TypeVisitableExt, inherent::IntoKind,
};
use crate::next_solver::{BoundConst, FxIndexMap};
use super::{
Binder, BoundRegion, BoundTy, Const, ConstKind, DbInterner, Predicate, Region, Ty, TyKind,
};
/// A delegate used when instantiating bound vars.
///
/// Any implementation must make sure that each bound variable always
/// gets mapped to the same result. `BoundVarReplacer` does not cache
/// replacements: it asks the delegate again every time it meets an
/// occurrence of a bound variable, so the delegate may be called several
/// times for the same variable and must answer consistently.
pub trait BoundVarReplacerDelegate<'db> {
    /// Returns the replacement for the bound region `br`.
    fn replace_region(&mut self, br: BoundRegion) -> Region<'db>;
    /// Returns the replacement for the bound type `bt`.
    fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>;
    /// Returns the replacement for the bound const `bv`.
    fn replace_const(&mut self, bv: BoundConst) -> Const<'db>;
}
/// A [`BoundVarReplacerDelegate`] backed by three borrowed closures, one for
/// each kind of bound variable (regions, types and consts).
///
/// Each closure must be deterministic: the replacer may invoke it more than
/// once for the same bound variable, and every call has to yield the same
/// result.
pub struct FnMutDelegate<'db, 'a> {
    pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a),
    pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a),
    pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a),
}

impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> {
    /// Forwards to the `regions` closure.
    fn replace_region(&mut self, br: BoundRegion) -> Region<'db> {
        let map_region = &mut *self.regions;
        map_region(br)
    }

    /// Forwards to the `types` closure.
    fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> {
        let map_ty = &mut *self.types;
        map_ty(bt)
    }

    /// Forwards to the `consts` closure.
    fn replace_const(&mut self, bv: BoundConst) -> Const<'db> {
        let map_const = &mut *self.consts;
        map_const(bv)
    }
}
/// Replaces the escaping bound vars (late-bound regions, bound types and bound
/// consts) in a foldable value.
///
/// Only variables bound at the binder level tracked by `current_index` are
/// replaced; each replacement is obtained from `delegate`.
pub(crate) struct BoundVarReplacer<'db, D> {
    interner: DbInterner<'db>,
    /// As with `RegionFolder`, represents the index of a binder *just outside*
    /// the ones we have visited.
    current_index: DebruijnIndex,
    /// Produces the replacement for each bound variable that is found.
    delegate: D,
}
impl<'db, D: BoundVarReplacerDelegate<'db>> BoundVarReplacer<'db, D> {
    /// Creates a replacer that starts at the innermost binder level
    /// (`DebruijnIndex::ZERO`) and asks `delegate` for every replacement.
    pub(crate) fn new(tcx: DbInterner<'db>, delegate: D) -> Self {
        let current_index = DebruijnIndex::ZERO;
        Self { interner: tcx, delegate, current_index }
    }
}
impl<'db, D> TypeFolder<DbInterner<'db>> for BoundVarReplacer<'db, D>
where
    D: BoundVarReplacerDelegate<'db>,
{
    fn cx(&self) -> DbInterner<'db> {
        self.interner
    }

    /// Entering a binder means the variables we are replacing are one level
    /// further out, so `current_index` is shifted in while the contents are
    /// folded and shifted back out afterwards.
    fn fold_binder<T: TypeFoldable<DbInterner<'db>>>(
        &mut self,
        t: Binder<'db, T>,
    ) -> Binder<'db, T> {
        self.current_index.shift_in(1);
        let t = t.super_fold_with(self);
        self.current_index.shift_out(1);
        t
    }

    fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
        match t.kind() {
            TyKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_ty)
                if debruijn == self.current_index =>
            {
                let ty = self.delegate.replace_ty(bound_ty);
                // The delegate answers relative to the innermost binder; shift the
                // result out to the depth at which it is being inserted.
                debug_assert!(!ty.has_vars_bound_above(DebruijnIndex::ZERO));
                rustc_type_ir::shift_vars(self.interner, ty, self.current_index.as_u32())
            }
            // Nothing bound at or above `current_index`: nothing here can be replaced.
            _ if !t.has_vars_bound_at_or_above(self.current_index) => t,
            _ => t.super_fold_with(self),
        }
    }

    fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
        match r.kind() {
            RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn), br)
                if debruijn == self.current_index =>
            {
                let region = self.delegate.replace_region(br);
                if let RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn1), br) = region.kind()
                {
                    // If the callback returns a bound region,
                    // that region should always use the INNERMOST
                    // debruijn index. Then we adjust it to the
                    // correct depth.
                    assert_eq!(debruijn1, DebruijnIndex::ZERO);
                    Region::new_bound(self.interner, debruijn, br)
                } else {
                    region
                }
            }
            _ => r,
        }
    }

    fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
        match ct.kind() {
            ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_const)
                if debruijn == self.current_index =>
            {
                let ct = self.delegate.replace_const(bound_const);
                // Same depth adjustment as for types above.
                debug_assert!(!ct.has_vars_bound_above(DebruijnIndex::ZERO));
                rustc_type_ir::shift_vars(self.interner, ct, self.current_index.as_u32())
            }
            // Mirror `fold_ty`/`fold_predicate`: skip walking consts that contain no
            // variables bound at or above `current_index`, since the fold could not
            // change them anyway.
            _ if !ct.has_vars_bound_at_or_above(self.current_index) => ct,
            _ => ct.super_fold_with(self),
        }
    }

    fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
    }
}
/// Rewrites every type reached while folding `t`, bottom-up: a type's
/// components are folded first, and `callback` is then applied to the rebuilt
/// type.
pub fn fold_tys<'db, T: TypeFoldable<DbInterner<'db>>>(
    interner: DbInterner<'db>,
    t: T,
    callback: impl FnMut(Ty<'db>) -> Ty<'db>,
) -> T {
    /// Folder that post-processes each folded type with a caller-supplied closure.
    struct TyMapper<'db, F> {
        interner: DbInterner<'db>,
        map: F,
    }

    impl<'db, F> TypeFolder<DbInterner<'db>> for TyMapper<'db, F>
    where
        F: FnMut(Ty<'db>) -> Ty<'db>,
    {
        fn cx(&self) -> DbInterner<'db> {
            self.interner
        }

        fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
            let inner_folded = ty.super_fold_with(self);
            (self.map)(inner_folded)
        }
    }

    let mut mapper = TyMapper { interner, map: callback };
    t.fold_with(&mut mapper)
}
impl<'db> DbInterner<'db> {
    /// Replaces all regions bound by the given `Binder` with the
    /// results returned by the closure; the closure is expected to
    /// return a free region (relative to this binder), and hence the
    /// binder is removed in the return type. The closure is invoked
    /// once for each unique `BoundRegion`; multiple references to the
    /// same `BoundRegion` will reuse the previous result. A map is
    /// returned at the end with each bound region and the free region
    /// that replaced it.
    ///
    /// # Panics
    ///
    /// This method only replaces late bound regions. Any types or
    /// constants bound by `value` will cause a panic.
    pub fn instantiate_bound_regions<T, F>(
        self,
        value: Binder<'db, T>,
        mut fld_r: F,
    ) -> (T, FxIndexMap<BoundRegion, Region<'db>>)
    where
        F: FnMut(BoundRegion) -> Region<'db>,
        T: TypeFoldable<DbInterner<'db>>,
    {
        let mut region_map = FxIndexMap::default();
        let real_fld_r = |br: BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br));
        let value = self.instantiate_bound_regions_uncached(value, real_fld_r);
        (value, region_map)
    }

    /// Like [`Self::instantiate_bound_regions`], but without memoization and
    /// without building a region map.
    ///
    /// `replace_regions` may be called more than once for the same
    /// `BoundRegion` (each occurrence is replaced independently), so it must
    /// return the same region for the same input every time. If `value` has
    /// no escaping bound vars, the closure is not called at all.
    ///
    /// # Panics
    ///
    /// Panics if `value`'s binder binds any types or constants; only bound
    /// regions are supported.
    pub fn instantiate_bound_regions_uncached<T, F>(
        self,
        value: Binder<'db, T>,
        mut replace_regions: F,
    ) -> T
    where
        F: FnMut(BoundRegion) -> Region<'db>,
        T: TypeFoldable<DbInterner<'db>>,
    {
        let value = value.skip_binder();
        if !value.has_escaping_bound_vars() {
            value
        } else {
            let delegate = FnMutDelegate {
                regions: &mut replace_regions,
                types: &mut |b| panic!("unexpected bound ty in binder: {b:?}"),
                consts: &mut |b| panic!("unexpected bound ct in binder: {b:?}"),
            };
            let mut replacer = BoundVarReplacer::new(self, delegate);
            value.fold_with(&mut replacer)
        }
    }

    /// Replaces any late-bound regions bound in `value` with `'erased`. Useful in codegen but also
    /// method lookup and a few other places where precise region relationships are not required.
    pub fn instantiate_bound_regions_with_erased<T>(self, value: Binder<'db, T>) -> T
    where
        T: TypeFoldable<DbInterner<'db>>,
    {
        // Every bound region maps to the same `'erased` region, so memoizing
        // through a map (which the caller would discard anyway) buys nothing.
        let erased = Region::new_erased(self);
        self.instantiate_bound_regions_uncached(value, |_| erased)
    }
}