Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement deferred_projection_equality for erica solver #107507

Merged
merged 4 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions compiler/rustc_borrowck/src/type_check/relate_tys.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_infer::infer::nll_relate::{NormalizationStrategy, TypeRelating, TypeRelatingDelegate};
use rustc_infer::infer::nll_relate::{TypeRelating, TypeRelatingDelegate};
use rustc_infer::infer::NllRegionVariableOrigin;
use rustc_infer::traits::PredicateObligations;
use rustc_middle::mir::ConstraintCategory;
Expand Down Expand Up @@ -140,10 +140,6 @@ impl<'tcx> TypeRelatingDelegate<'tcx> for NllTypeRelatingDelegate<'_, '_, 'tcx>
);
}

fn normalization() -> NormalizationStrategy {
NormalizationStrategy::Eager
}

fn forbid_inference_vars() -> bool {
true
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_analysis/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ty::Clause::RegionOutlives(_) => bug!(),
},
ty::PredicateKind::WellFormed(_)
| ty::PredicateKind::AliasEq(..)
BoxyUwU marked this conversation as resolved.
Show resolved Hide resolved
| ty::PredicateKind::ObjectSafe(_)
| ty::PredicateKind::ClosureKind(_, _, _)
| ty::PredicateKind::Subtype(_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ fn trait_predicate_kind<'tcx>(
ty::PredicateKind::Clause(ty::Clause::RegionOutlives(_))
| ty::PredicateKind::Clause(ty::Clause::TypeOutlives(_))
| ty::PredicateKind::Clause(ty::Clause::Projection(_))
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::WellFormed(_)
| ty::PredicateKind::Subtype(_)
| ty::PredicateKind::Coerce(_)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_analysis/src/outlives/explicit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl<'tcx> ExplicitPredicatesMap<'tcx> {
ty::PredicateKind::Clause(ty::Clause::Trait(..))
| ty::PredicateKind::Clause(ty::Clause::Projection(..))
| ty::PredicateKind::WellFormed(..)
| ty::PredicateKind::AliasEq(..)
BoxyUwU marked this conversation as resolved.
Show resolved Hide resolved
| ty::PredicateKind::ObjectSafe(..)
| ty::PredicateKind::ClosureKind(..)
| ty::PredicateKind::Subtype(..)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
| ty::PredicateKind::Clause(ty::Clause::TypeOutlives(..))
| ty::PredicateKind::WellFormed(..)
| ty::PredicateKind::ObjectSafe(..)
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::ConstEvaluatable(..)
| ty::PredicateKind::ConstEquate(..)
// N.B., this predicate is created by breaking down a
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_typeck/src/method/probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> {
| ty::PredicateKind::ConstEvaluatable(..)
| ty::PredicateKind::ConstEquate(..)
| ty::PredicateKind::Ambiguous
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::TypeWellFormedFromEnv(..) => None,
}
});
Expand Down
6 changes: 1 addition & 5 deletions compiler/rustc_infer/src/infer/canonical/query_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::infer::canonical::{
Canonical, CanonicalQueryResponse, CanonicalVarValues, Certainty, OriginalQueryValues,
QueryOutlivesConstraint, QueryRegionConstraints, QueryResponse,
};
use crate::infer::nll_relate::{NormalizationStrategy, TypeRelating, TypeRelatingDelegate};
use crate::infer::nll_relate::{TypeRelating, TypeRelatingDelegate};
use crate::infer::region_constraints::{Constraint, RegionConstraintData};
use crate::infer::{InferCtxt, InferOk, InferResult, NllRegionVariableOrigin};
use crate::traits::query::{Fallible, NoSolution};
Expand Down Expand Up @@ -717,10 +717,6 @@ impl<'tcx> TypeRelatingDelegate<'tcx> for QueryTypeRelatingDelegate<'_, 'tcx> {
});
}

fn normalization() -> NormalizationStrategy {
NormalizationStrategy::Eager
}

fn forbid_inference_vars() -> bool {
true
}
Expand Down
87 changes: 58 additions & 29 deletions compiler/rustc_infer/src/infer/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::subst::SubstsRef;
use rustc_middle::ty::{
self, FallibleTypeFolder, InferConst, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable,
TypeVisitable,
self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable,
TypeSuperFoldable, TypeVisitable,
};
use rustc_middle::ty::{IntType, UintType};
use rustc_span::{Span, DUMMY_SP};
Expand Down Expand Up @@ -74,7 +74,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: Ty<'tcx>,
) -> RelateResult<'tcx, Ty<'tcx>>
where
R: TypeRelation<'tcx>,
R: ObligationEmittingRelation<'tcx>,
{
let a_is_expected = relation.a_is_expected();

Expand Down Expand Up @@ -122,6 +122,15 @@ impl<'tcx> InferCtxt<'tcx> {
Err(TypeError::Sorts(ty::relate::expected_found(relation, a, b)))
}

(ty::Alias(AliasKind::Projection, _), _) if self.tcx.trait_solver_next() => {
relation.register_type_equate_obligation(a.into(), b.into());
Ok(b)
}
(_, ty::Alias(AliasKind::Projection, _)) if self.tcx.trait_solver_next() => {
relation.register_type_equate_obligation(b.into(), a.into());
Ok(a)
}

_ => ty::relate::super_relate_tys(relation, a, b),
}
}
Expand All @@ -133,7 +142,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>>
where
R: ConstEquateRelation<'tcx>,
R: ObligationEmittingRelation<'tcx>,
{
debug!("{}.consts({:?}, {:?})", relation.tag(), a, b);
if a == b {
Expand Down Expand Up @@ -169,15 +178,15 @@ impl<'tcx> InferCtxt<'tcx> {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
relation.const_equate_obligation(a, b);
relation.register_const_equate_obligation(a, b);
}
return Ok(b);
}
(_, ty::ConstKind::Unevaluated(..)) if self.tcx.lazy_normalization() => {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
relation.const_equate_obligation(a, b);
relation.register_const_equate_obligation(a, b);
}
return Ok(a);
}
Expand Down Expand Up @@ -435,32 +444,21 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
Ok(Generalization { ty, needs_wf })
}

pub fn add_const_equate_obligation(
pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.obligations.extend(obligations.into_iter());
}

pub fn register_predicates(
&mut self,
a_is_expected: bool,
a: ty::Const<'tcx>,
b: ty::Const<'tcx>,
obligations: impl IntoIterator<Item = impl ToPredicate<'tcx>>,
) {
let predicate = if a_is_expected {
ty::PredicateKind::ConstEquate(a, b)
} else {
ty::PredicateKind::ConstEquate(b, a)
};
self.obligations.push(Obligation::new(
self.tcx(),
self.trace.cause.clone(),
self.param_env,
ty::Binder::dummy(predicate),
));
self.obligations.extend(obligations.into_iter().map(|to_pred| {
Obligation::new(self.infcx.tcx, self.trace.cause.clone(), self.param_env, to_pred)
}))
}

pub fn mark_ambiguous(&mut self) {
self.obligations.push(Obligation::new(
self.tcx(),
self.trace.cause.clone(),
self.param_env,
ty::Binder::dummy(ty::PredicateKind::Ambiguous),
));
self.register_predicates([ty::Binder::dummy(ty::PredicateKind::Ambiguous)]);
}
}

Expand Down Expand Up @@ -775,11 +773,42 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
}
}

pub trait ConstEquateRelation<'tcx>: TypeRelation<'tcx> {
pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
/// Register obligations that must hold in order for this relation to hold
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>);

/// Register predicates that must hold in order for this relation to hold. Uses
/// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should
/// be used if control over the obligaton causes is required.
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ToPredicate<'tcx>>,
);

/// Register an obligation that both constants must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>);
fn register_const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };

self.register_predicates([ty::Binder::dummy(if self.tcx().trait_solver_next() {
ty::PredicateKind::AliasEq(a.into(), b.into())
} else {
ty::PredicateKind::ConstEquate(a, b)
})]);
}

/// Register an obligation that both types must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
fn register_type_equate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };

self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasEq(
a.into(),
b.into(),
))]);
}
}

fn int_unification_error<'tcx>(
Expand Down
17 changes: 13 additions & 4 deletions compiler/rustc_infer/src/infer/equate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::combine::{CombineFields, ConstEquateRelation, RelationDir};
use crate::traits::PredicateObligations;

use super::combine::{CombineFields, ObligationEmittingRelation, RelationDir};
use super::Subtype;

use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
Expand Down Expand Up @@ -198,8 +200,15 @@ impl<'tcx> TypeRelation<'tcx> for Equate<'_, '_, 'tcx> {
}
}

impl<'tcx> ConstEquateRelation<'tcx> for Equate<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}

fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}
22 changes: 12 additions & 10 deletions compiler/rustc_infer/src/infer/glb.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Greatest lower bound. See [`lattice`].

use super::combine::CombineFields;
use super::combine::{CombineFields, ObligationEmittingRelation};
use super::lattice::{self, LatticeDir};
use super::InferCtxt;
use super::Subtype;

use crate::infer::combine::ConstEquateRelation;
use crate::traits::{ObligationCause, PredicateObligation};
use crate::traits::{ObligationCause, PredicateObligations};
use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, Ty, TyCtxt};

Expand Down Expand Up @@ -136,10 +135,6 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
&self.fields.trace.cause
}

fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
self.fields.obligations.extend(obligations)
}

fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
sub.relate(v, a)?;
Expand All @@ -152,8 +147,15 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
}
}

impl<'tcx> ConstEquateRelation<'tcx> for Glb<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}

fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}
11 changes: 5 additions & 6 deletions compiler/rustc_infer/src/infer/lattice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
//!
//! [lattices]: https://en.wikipedia.org/wiki/Lattice_(order)

use super::combine::ObligationEmittingRelation;
use super::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use super::InferCtxt;

use crate::traits::{ObligationCause, PredicateObligation};
use rustc_middle::ty::relate::{RelateResult, TypeRelation};
use crate::traits::ObligationCause;
use rustc_middle::ty::relate::RelateResult;
use rustc_middle::ty::TyVar;
use rustc_middle::ty::{self, Ty};

Expand All @@ -30,13 +31,11 @@ use rustc_middle::ty::{self, Ty};
///
/// GLB moves "down" the lattice (to smaller values); LUB moves
/// "up" the lattice (to bigger values).
pub trait LatticeDir<'f, 'tcx>: TypeRelation<'tcx> {
pub trait LatticeDir<'f, 'tcx>: ObligationEmittingRelation<'tcx> {
fn infcx(&self) -> &'f InferCtxt<'tcx>;

fn cause(&self) -> &ObligationCause<'tcx>;

fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>);

fn define_opaque_types(&self) -> bool;

// Relates the type `v` to `a` and `b` such that `v` represents
Expand Down Expand Up @@ -113,7 +112,7 @@ where
| (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }))
if this.define_opaque_types() && def_id.is_local() =>
{
this.add_obligations(
this.register_obligations(
infcx
.handle_opaque_type(a, b, this.a_is_expected(), this.cause(), this.param_env())?
.obligations,
Expand Down
28 changes: 15 additions & 13 deletions compiler/rustc_infer/src/infer/lub.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Least upper bound. See [`lattice`].

use super::combine::CombineFields;
use super::combine::{CombineFields, ObligationEmittingRelation};
use super::lattice::{self, LatticeDir};
use super::InferCtxt;
use super::Subtype;

use crate::infer::combine::ConstEquateRelation;
use crate::traits::{ObligationCause, PredicateObligation};
use crate::traits::{ObligationCause, PredicateObligations};
use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, Ty, TyCtxt};

Expand Down Expand Up @@ -127,12 +126,6 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
}
}

impl<'tcx> ConstEquateRelation<'tcx> for Lub<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
}
}

impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx, 'tcx> {
fn infcx(&self) -> &'infcx InferCtxt<'tcx> {
self.fields.infcx
Expand All @@ -142,10 +135,6 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
&self.fields.trace.cause
}

fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
self.fields.obligations.extend(obligations)
}

fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
sub.relate(a, v)?;
Expand All @@ -157,3 +146,16 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
self.fields.define_opaque_types
}
}

impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}

fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations)
}
}
Loading