// blob: 2055c3151c6f61ae4d1b4e9a63c6571809d4e074 [file] [log] [blame]
//! Trait solving using next trait solver.
use core::fmt;
use std::hash::Hash;
use base_db::Crate;
use hir_def::{
AdtId, AssocItemId, BlockId, HasModule, ImplId, Lookup, TraitId,
lang_item::LangItem,
nameres::DefMap,
signatures::{ConstFlags, EnumFlags, FnFlags, StructFlags, TraitFlags, TypeAliasFlags},
};
use hir_expand::name::Name;
use intern::sym;
use rustc_next_trait_solver::solve::{HasChanged, SolverDelegateEvalExt};
use rustc_type_ir::{
TypingMode,
inherent::{AdtDef, BoundExistentialPredicates, IntoKind, Span as _},
solve::Certainty,
};
use triomphe::Arc;
use crate::{
db::HirDatabase,
next_solver::{
Canonical, DbInterner, GenericArgs, Goal, ParamEnv, Predicate, SolverContext, Span, Ty,
TyKind,
infer::{DbInternerInferExt, InferCtxt, traits::ObligationCause},
obligation_ctxt::ObligationCtxt,
},
};
/// The clauses taken for granted while solving. For instance, inside
/// ```rust
/// fn foo<T: Default>(t: T) {}
/// ```
/// `T: Default` is assumed to hold.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TraitEnvironment<'db> {
    /// Crate this environment was built for.
    pub krate: Crate,
    /// Enclosing block, if one has been recorded (see `with_block`).
    pub block: Option<BlockId>,
    // FIXME make this a BTreeMap
    /// `(self type, trait)` pairs; scanned linearly by `traits_in_scope_from_clauses`.
    traits_from_clauses: Box<[(Ty<'db>, TraitId)]>,
    /// Parameter environment used when building solver goals.
    pub env: ParamEnv<'db>,
}
impl<'db> TraitEnvironment<'db> {
    /// Builds an environment for `krate` with no block, no clause-derived traits and an
    /// empty `ParamEnv`.
    pub fn empty(krate: Crate) -> Arc<Self> {
        Self::new(krate, None, Box::default(), ParamEnv::empty())
    }

    /// Bundles the given parts into a newly allocated `Arc`.
    pub fn new(
        krate: Crate,
        block: Option<BlockId>,
        traits_from_clauses: Box<[(Ty<'db>, TraitId)]>,
        env: ParamEnv<'db>,
    ) -> Arc<Self> {
        Arc::new(Self { krate, block, traits_from_clauses, env })
    }

    /// Records `block` on the environment behind `this`, obtaining mutable access through
    /// `Arc::make_mut`.
    ///
    /// This is an associated function taking `this`, not a `self` method: call it as
    /// `TraitEnvironment::with_block(&mut env, block)`.
    pub fn with_block(this: &mut Arc<Self>, block: BlockId) {
        let env = Arc::make_mut(this);
        env.block = Some(block);
    }

    /// Yields every trait stored alongside a self type equal to `ty`.
    pub fn traits_in_scope_from_clauses(&self, ty: Ty<'db>) -> impl Iterator<Item = TraitId> + '_ {
        self.traits_from_clauses
            .iter()
            .filter(move |(self_ty, _)| *self_ty == ty)
            .map(|(_, trait_id)| *trait_id)
    }
}
/// This should be used in `hir` only.
///
/// Non-alias types are returned untouched. For an alias type, an `ObligationCtxt` built on
/// `infcx` is asked to structurally normalize it under `env.env`; if that yields nothing,
/// the original type is kept. The result is then passed through
/// `replace_infer_with_error`.
pub fn structurally_normalize_ty<'db>(
    infcx: &InferCtxt<'db>,
    ty: Ty<'db>,
    env: Arc<TraitEnvironment<'db>>,
) -> Ty<'db> {
    if !matches!(ty.kind(), TyKind::Alias(..)) {
        return ty;
    }

    let mut ocx = ObligationCtxt::new(infcx);
    let cause = ObligationCause::dummy();
    let normalized = ocx.structurally_normalize_ty(&cause, env.env, ty).unwrap_or(ty);
    normalized.replace_infer_with_error(infcx.interner)
}
/// Outcome of evaluating a goal with the next trait solver, reduced to three states.
///
/// The enum is fieldless, so it is `Copy` and `Eq` in addition to `Clone`/`PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NextTraitSolveResult {
    /// The solver answered with `Certainty::Yes`.
    Certain,
    /// The solver answered with `Certainty::Maybe`.
    Uncertain,
    /// The solver reported an error, i.e. there is no solution.
    NoSolution,
}

impl NextTraitSolveResult {
    /// Returns `true` for [`NextTraitSolveResult::NoSolution`].
    pub fn no_solution(&self) -> bool {
        matches!(self, NextTraitSolveResult::NoSolution)
    }

    /// Returns `true` for [`NextTraitSolveResult::Certain`].
    pub fn certain(&self) -> bool {
        matches!(self, NextTraitSolveResult::Certain)
    }

    /// Returns `true` for [`NextTraitSolveResult::Uncertain`].
    pub fn uncertain(&self) -> bool {
        matches!(self, NextTraitSolveResult::Uncertain)
    }
}
/// Instantiates the canonical `goal` and evaluates it as a root goal, all inside
/// `infer_ctxt.probe`, then collapses the answer into a [`NextTraitSolveResult`]:
/// an error becomes `NoSolution`, `Certainty::Yes` becomes `Certain` and
/// `Certainty::Maybe` becomes `Uncertain`.
pub fn next_trait_solve_canonical_in_ctxt<'db>(
    infer_ctxt: &InferCtxt<'db>,
    goal: Canonical<'db, Goal<'db, Predicate<'db>>>,
) -> NextTraitSolveResult {
    infer_ctxt.probe(|_| {
        let context = <&SolverContext<'db>>::from(infer_ctxt);
        tracing::info!(?goal);

        let (instantiated, var_values) = context.instantiate_canonical(&goal);
        tracing::info!(?var_values);

        // Only the certainty decides the result; `has_changed` is kept for the log line.
        let outcome = context
            .evaluate_root_goal(instantiated, Span::dummy(), None)
            .map(|eval| (eval.has_changed, eval.certainty));
        tracing::debug!("solve_nextsolver({:?}) => {:?}", instantiated, outcome);

        match outcome {
            Ok((_, Certainty::Yes)) => NextTraitSolveResult::Certain,
            Ok((_, Certainty::Maybe { .. })) => NextTraitSolveResult::Uncertain,
            Err(_) => NextTraitSolveResult::NoSolution,
        }
    })
}
/// Solve a trait goal using next trait solver.
///
/// Evaluates `goal` as a root goal on `infer_ctxt` and returns the evaluation's
/// `has_changed` and `certainty`, or the solver's `NoSolution` error unchanged.
pub fn next_trait_solve_in_ctxt<'db, 'a>(
    infer_ctxt: &'a InferCtxt<'db>,
    goal: Goal<'db, Predicate<'db>>,
) -> Result<(HasChanged, Certainty), rustc_type_ir::solve::NoSolution> {
    tracing::info!(?goal);
    let solver = <&SolverContext<'db>>::from(infer_ctxt);
    let outcome = match solver.evaluate_root_goal(goal, Span::dummy(), None) {
        Ok(eval) => Ok((eval.has_changed, eval.certainty)),
        Err(no_solution) => Err(no_solution),
    };
    tracing::debug!("solve_nextsolver({:?}) => {:?}", goal, outcome);
    outcome
}
/// The closure-call traits: the `Fn*` family followed by the `AsyncFn*` family.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FnTrait {
    // Warning: Order is important. If something implements `x` it should also implement
    // `y` if `y <= x`.
    FnOnce,
    FnMut,
    Fn,
    AsyncFnOnce,
    AsyncFnMut,
    AsyncFn,
}

/// Prints the bare trait name, e.g. `FnOnce` or `AsyncFn`.
impl fmt::Display for FnTrait {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::FnOnce => "FnOnce",
            Self::FnMut => "FnMut",
            Self::Fn => "Fn",
            Self::AsyncFnOnce => "AsyncFnOnce",
            Self::AsyncFnMut => "AsyncFnMut",
            Self::AsyncFn => "AsyncFn",
        };
        f.write_str(name)
    }
}
impl FnTrait {
pub const fn function_name(&self) -> &'static str {
match self {
FnTrait::FnOnce => "call_once",
FnTrait::FnMut => "call_mut",
FnTrait::Fn => "call",
FnTrait::AsyncFnOnce => "async_call_once",
FnTrait::AsyncFnMut => "async_call_mut",
FnTrait::AsyncFn => "async_call",
}
}
const fn lang_item(self) -> LangItem {
match self {
FnTrait::FnOnce => LangItem::FnOnce,
FnTrait::FnMut => LangItem::FnMut,
FnTrait::Fn => LangItem::Fn,
FnTrait::AsyncFnOnce => LangItem::AsyncFnOnce,
FnTrait::AsyncFnMut => LangItem::AsyncFnMut,
FnTrait::AsyncFn => LangItem::AsyncFn,
}
}
pub const fn from_lang_item(lang_item: LangItem) -> Option<Self> {
match lang_item {
LangItem::FnOnce => Some(FnTrait::FnOnce),
LangItem::FnMut => Some(FnTrait::FnMut),
LangItem::Fn => Some(FnTrait::Fn),
LangItem::AsyncFnOnce => Some(FnTrait::AsyncFnOnce),
LangItem::AsyncFnMut => Some(FnTrait::AsyncFnMut),
LangItem::AsyncFn => Some(FnTrait::AsyncFn),
_ => None,
}
}
pub fn method_name(self) -> Name {
match self {
FnTrait::FnOnce => Name::new_symbol_root(sym::call_once),
FnTrait::FnMut => Name::new_symbol_root(sym::call_mut),
FnTrait::Fn => Name::new_symbol_root(sym::call),
FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once),
FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut),
FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call),
}
}
pub fn get_id(self, db: &dyn HirDatabase, krate: Crate) -> Option<TraitId> {
self.lang_item().resolve_trait(db, krate)
}
}
/// This should not be used in `hir-ty`, only in `hir`.
pub fn implements_trait_unique<'db>(
ty: Ty<'db>,
db: &'db dyn HirDatabase,
env: Arc<TraitEnvironment<'db>>,
trait_: TraitId,
) -> bool {
implements_trait_unique_impl(db, env, trait_, &mut |infcx| {
infcx.fill_rest_fresh_args(trait_.into(), [ty.into()])
})
}
/// This should not be used in `hir-ty`, only in `hir`.
pub fn implements_trait_unique_with_args<'db>(
db: &'db dyn HirDatabase,
env: Arc<TraitEnvironment<'db>>,
trait_: TraitId,
args: GenericArgs<'db>,
) -> bool {
implements_trait_unique_impl(db, env, trait_, &mut |_| args)
}
/// Shared worker for the `implements_trait_unique*` entry points.
///
/// Builds an inference context for `env`'s crate and block, lets `create_args` produce the
/// trait's generic arguments, and solves the resulting trait-ref goal under `env.env`.
/// Only `Certainty::Yes` counts as success; `Maybe` and `NoSolution` both yield `false`.
fn implements_trait_unique_impl<'db>(
    db: &'db dyn HirDatabase,
    env: Arc<TraitEnvironment<'db>>,
    trait_: TraitId,
    create_args: &mut dyn FnMut(&InferCtxt<'db>) -> GenericArgs<'db>,
) -> bool {
    let interner = DbInterner::new_with(db, Some(env.krate), env.block);
    // FIXME(next-solver): I believe this should be `PostAnalysis`.
    let infcx = interner.infer_ctxt().build(TypingMode::non_body_analysis());

    let args = create_args(&infcx);
    let trait_ref = rustc_type_ir::TraitRef::new_from_args(interner, trait_.into(), args);
    let goal = Goal::new(interner, env.env, trait_ref);

    match crate::traits::next_trait_solve_in_ctxt(&infcx, goal) {
        Ok((_, certainty)) => matches!(certainty, Certainty::Yes),
        Err(_) => false,
    }
}
pub fn is_inherent_impl_coherent(db: &dyn HirDatabase, def_map: &DefMap, impl_id: ImplId) -> bool {
let self_ty = db.impl_self_ty(impl_id).instantiate_identity();
let self_ty = self_ty.kind();
let impl_allowed = match self_ty {
TyKind::Tuple(_)
| TyKind::FnDef(_, _)
| TyKind::Array(_, _)
| TyKind::Never
| TyKind::RawPtr(_, _)
| TyKind::Ref(_, _, _)
| TyKind::Slice(_)
| TyKind::Str
| TyKind::Bool
| TyKind::Char
| TyKind::Int(_)
| TyKind::Uint(_)
| TyKind::Float(_) => def_map.is_rustc_coherence_is_core(),
TyKind::Adt(adt_def, _) => adt_def.def_id().0.module(db).krate() == def_map.krate(),
TyKind::Dynamic(it, _) => it
.principal_def_id()
.is_some_and(|trait_id| trait_id.0.module(db).krate() == def_map.krate()),
_ => true,
};
impl_allowed || {
let rustc_has_incoherent_inherent_impls = match self_ty {
TyKind::Tuple(_)
| TyKind::FnDef(_, _)
| TyKind::Array(_, _)
| TyKind::Never
| TyKind::RawPtr(_, _)
| TyKind::Ref(_, _, _)
| TyKind::Slice(_)
| TyKind::Str
| TyKind::Bool
| TyKind::Char
| TyKind::Int(_)
| TyKind::Uint(_)
| TyKind::Float(_) => true,
TyKind::Adt(adt_def, _) => match adt_def.def_id().0 {
hir_def::AdtId::StructId(id) => db
.struct_signature(id)
.flags
.contains(StructFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS),
hir_def::AdtId::UnionId(id) => db
.union_signature(id)
.flags
.contains(StructFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS),
hir_def::AdtId::EnumId(it) => db
.enum_signature(it)
.flags
.contains(EnumFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS),
},
TyKind::Dynamic(it, _) => it.principal_def_id().is_some_and(|trait_id| {
db.trait_signature(trait_id.0)
.flags
.contains(TraitFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS)
}),
_ => false,
};
let items = impl_id.impl_items(db);
rustc_has_incoherent_inherent_impls
&& !items.items.is_empty()
&& items.items.iter().all(|&(_, assoc)| match assoc {
AssocItemId::FunctionId(it) => {
db.function_signature(it).flags.contains(FnFlags::RUSTC_ALLOW_INCOHERENT_IMPL)
}
AssocItemId::ConstId(it) => {
db.const_signature(it).flags.contains(ConstFlags::RUSTC_ALLOW_INCOHERENT_IMPL)
}
AssocItemId::TypeAliasId(it) => db
.type_alias_signature(it)
.flags
.contains(TypeAliasFlags::RUSTC_ALLOW_INCOHERENT_IMPL),
})
}
}
/// Checks whether the impl satisfies the orphan rules.
///
/// Given `impl<P1..=Pn> Trait<T1..=Tn> for T0`, an `impl` is valid only if at least one of the
/// following is true:
/// - Trait is a local trait
/// - All of
///   - At least one of the types `T0..=Tn` must be a local type. Let `Ti` be the first such type.
///   - No uncovered type parameters `P1..=Pn` may appear in `T0..Ti` (excluding `Ti`)
///
/// Impls that are not trait impls always pass. The uncovered-parameter condition is not
/// checked yet (see the FIXME in the body).
pub fn check_orphan_rules<'db>(db: &'db dyn HirDatabase, impl_: ImplId) -> bool {
    let Some(impl_trait) = db.impl_trait(impl_) else {
        // not a trait impl
        return true;
    };

    let local_crate = impl_.lookup(db).container.krate();
    let is_local = |tgt_crate| tgt_crate == local_crate;

    let trait_ref = impl_trait.instantiate_identity();
    let trait_id = trait_ref.def_id.0;
    if is_local(trait_id.module(db).krate()) {
        // trait to be implemented is local
        return true;
    }

    // Peels references and `FUNDAMENTAL` structs (following their first type argument)
    // until a type that is neither remains.
    let unwrap_fundamental = |ty: Ty<'db>| {
        let mut current = ty;
        loop {
            let inner = match current.kind() {
                TyKind::Ref(_, referenced, _) => Some(referenced),
                TyKind::Adt(adt_def, subs) => match adt_def.def_id().0 {
                    AdtId::StructId(s)
                        if db.struct_signature(s).flags.contains(StructFlags::FUNDAMENTAL) =>
                    {
                        subs.types().next()
                    }
                    _ => None,
                },
                _ => None,
            };
            match inner {
                Some(next) => current = next,
                None => break current,
            }
        }
    };

    // - At least one of the types `T0..=Tn` must be a local type. Let `Ti` be the first such type.
    // FIXME: param coverage
    // - No uncovered type parameters `P1..=Pn` may appear in `T0..Ti` (excluding `Ti`)
    trait_ref.args.types().any(|ty| match unwrap_fundamental(ty).kind() {
        TyKind::Adt(adt_def, _) => is_local(adt_def.def_id().0.module(db).krate()),
        // Error types are counted as local.
        TyKind::Error(_) => true,
        TyKind::Dynamic(it, _) => {
            it.principal_def_id().is_some_and(|trait_id| is_local(trait_id.0.module(db).krate()))
        }
        _ => false,
    })
}