Fix inference of `AsyncFnX` return type
diff --git a/crates/hir-ty/src/chalk_db.rs b/crates/hir-ty/src/chalk_db.rs
index cd799c0..22b96b5 100644
--- a/crates/hir-ty/src/chalk_db.rs
+++ b/crates/hir-ty/src/chalk_db.rs
@@ -259,7 +259,7 @@
}
fn well_known_trait_id(
&self,
- well_known_trait: rust_ir::WellKnownTrait,
+ well_known_trait: WellKnownTrait,
) -> Option<chalk_ir::TraitId<Interner>> {
let lang_attr = lang_item_from_well_known_trait(well_known_trait);
let trait_ = lang_attr.resolve_trait(self.db, self.krate)?;
diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs
index f0989d9..f210dd8 100644
--- a/crates/hir-ty/src/display.rs
+++ b/crates/hir-ty/src/display.rs
@@ -1463,6 +1463,8 @@
}
if f.closure_style == ClosureStyle::RANotation || !sig.ret().is_unit() {
write!(f, " -> ")?;
+ // FIXME: We display `AsyncFn` as `-> impl Future`, but this is hard to fix because
+ // we don't have a trait environment here, required to normalize `<Ret as Future>::Output`.
sig.ret().hir_fmt(f)?;
}
} else {
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 800897c..bd57ca8 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -38,7 +38,7 @@
infer::{BreakableKind, CoerceMany, Diverges, coerce::CoerceNever},
make_binders,
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
- to_chalk_trait_id,
+ to_assoc_type_id, to_chalk_trait_id,
traits::FnTrait,
utils::{self, elaborate_clause_supertraits},
};
@@ -245,7 +245,7 @@
}
fn deduce_closure_kind_from_predicate_clauses(
- &self,
+ &mut self,
expected_ty: &Ty,
clauses: impl DoubleEndedIterator<Item = WhereClause>,
closure_kind: ClosureKind,
@@ -378,7 +378,7 @@
}
fn deduce_sig_from_projection(
- &self,
+ &mut self,
closure_kind: ClosureKind,
projection_ty: &ProjectionTy,
projected_ty: &Ty,
@@ -392,13 +392,16 @@
// For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits,
// for closures and async closures, respectively.
- match closure_kind {
- ClosureKind::Closure | ClosureKind::Async
- if self.fn_trait_kind_from_trait_id(trait_).is_some() =>
- {
- self.extract_sig_from_projection(projection_ty, projected_ty)
- }
- _ => None,
+ let fn_trait_kind = self.fn_trait_kind_from_trait_id(trait_)?;
+ if !matches!(closure_kind, ClosureKind::Closure | ClosureKind::Async) {
+ return None;
+ }
+ if fn_trait_kind.is_async() {
+ // If the expected trait is `AsyncFn(...) -> X`, we don't know what the return type is,
+ // but we do know it must implement `Future<Output = X>`.
+ self.extract_async_fn_sig_from_projection(projection_ty, projected_ty)
+ } else {
+ self.extract_sig_from_projection(projection_ty, projected_ty)
}
}
@@ -424,6 +427,39 @@
)))
}
+ fn extract_async_fn_sig_from_projection(
+ &mut self,
+ projection_ty: &ProjectionTy,
+ projected_ty: &Ty,
+ ) -> Option<FnSubst<Interner>> {
+ let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner);
+
+ let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else {
+ return None;
+ };
+
+ let ret_param_future_output = projected_ty;
+ let ret_param_future = self.table.new_type_var();
+ let future_output =
+ LangItem::FutureOutput.resolve_type_alias(self.db, self.resolver.krate())?;
+ let future_projection = crate::AliasTy::Projection(crate::ProjectionTy {
+ associated_ty_id: to_assoc_type_id(future_output),
+ substitution: Substitution::from1(Interner, ret_param_future.clone()),
+ });
+ self.table.register_obligation(
+ crate::AliasEq { alias: future_projection, ty: ret_param_future_output.clone() }
+ .cast(Interner),
+ );
+
+ Some(FnSubst(Substitution::from_iter(
+ Interner,
+ input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new(
+ Interner,
+ chalk_ir::GenericArgData::Ty(ret_param_future),
+ ))),
+ )))
+ }
+
fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?)
}
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 2b527a4..e5d1fbe 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -4903,3 +4903,30 @@
"#]],
);
}
+
+#[test]
+fn async_fn_return_type() {
+ check_infer(
+ r#"
+//- minicore: async_fn
+fn foo<F: AsyncFn() -> R, R>(_: F) -> R {
+ loop {}
+}
+
+fn main() {
+ foo(async move || ());
+}
+ "#,
+ expect![[r#"
+ 29..30 '_': F
+ 40..55 '{ loop {} }': R
+ 46..53 'loop {}': !
+ 51..53 '{}': ()
+ 67..97 '{ ...()); }': ()
+ 73..76 'foo': fn foo<impl AsyncFn() -> impl Future<Output = ()>, ()>(impl AsyncFn() -> impl Future<Output = ()>)
+ 73..94 'foo(as...|| ())': ()
+ 77..93 'async ... || ()': impl AsyncFn() -> impl Future<Output = ()>
+ 91..93 '()': ()
+ "#]],
+ );
+}
diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs
index f9f8776..7414b4f 100644
--- a/crates/hir-ty/src/traits.rs
+++ b/crates/hir-ty/src/traits.rs
@@ -291,4 +291,9 @@
pub fn get_id(self, db: &dyn HirDatabase, krate: Crate) -> Option<TraitId> {
self.lang_item().resolve_trait(db, krate)
}
+
+ #[inline]
+ pub(crate) fn is_async(self) -> bool {
+ matches!(self, FnTrait::AsyncFn | FnTrait::AsyncFnMut | FnTrait::AsyncFnOnce)
+ }
}