diff --git a/src/ast/rewriter/bool_rewriter.cpp b/src/ast/rewriter/bool_rewriter.cpp index 9afab7a2911..13a392d247d 100644 --- a/src/ast/rewriter/bool_rewriter.cpp +++ b/src/ast/rewriter/bool_rewriter.cpp @@ -699,6 +699,22 @@ app* bool_rewriter::mk_eq(expr* lhs, expr* rhs) { return m().mk_eq(lhs, rhs); } +bool bool_rewriter::try_ite_eq(expr* lhs, expr* rhs, expr_ref& r) { + expr* c, *t, *e; + if (!m().is_ite(lhs, c, t, e)) + return false; + if (m().are_equal(t, rhs) && m().are_distinct(e, rhs)) { + r = c; + return true; + } + if (m().are_equal(e, rhs) && m().are_distinct(t, rhs)) { + r = m().mk_not(c); + return true; + } + return false; +} + + br_status bool_rewriter::mk_eq_core(expr * lhs, expr * rhs, expr_ref & result) { if (m().are_equal(lhs, rhs)) { result = m().mk_true(); @@ -713,6 +729,12 @@ br_status bool_rewriter::mk_eq_core(expr * lhs, expr * rhs, expr_ref & result) { br_status r = BR_FAILED; + if (try_ite_eq(lhs, rhs, result)) + return BR_REWRITE1; + + if (try_ite_eq(rhs, lhs, result)) + return BR_REWRITE1; + if (m_ite_extra_rules) { if (m().is_ite(lhs) && m().is_value(rhs)) { r = try_ite_value(to_app(lhs), to_app(rhs), result); diff --git a/src/ast/rewriter/bool_rewriter.h b/src/ast/rewriter/bool_rewriter.h index 7c840b6478c..421811ed4d1 100644 --- a/src/ast/rewriter/bool_rewriter.h +++ b/src/ast/rewriter/bool_rewriter.h @@ -71,6 +71,8 @@ class bool_rewriter { void mk_and_as_or(unsigned num_args, expr * const * args, expr_ref & result); + bool try_ite_eq(expr* lhs, expr* rhs, expr_ref& r); + expr * mk_or_app(unsigned num_args, expr * const * args); bool simp_nested_not_or(unsigned num_args, expr * const * args, expr_fast_mark1 & neg_lits, expr_fast_mark2 & pos_lits, expr_ref & result); expr * simp_arg(expr * arg, expr_fast_mark1 & neg_lits, expr_fast_mark2 & pos_lits, bool & modified); diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index aa79bab5b41..76d234d8d5a 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -390,7 +390,7 @@ namespace q { m_qs.log_instantiation(lits, &j); euf::th_proof_hint* ph = nullptr; if (ctx.use_drat()) - ph = q_proof_hint::mk(ctx, j.m_generation, lits, j.m_clause.num_decls(), j.m_binding); + ph = q_proof_hint::mk(ctx, m_ematch, j.m_generation, lits, j.m_clause.num_decls(), j.m_binding); m_qs.add_clause(lits, ph); } diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index cbeb34679bf..f7de55fb813 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -90,6 +90,7 @@ namespace q { unsigned_vector m_clause_queue; euf::enode_pair_vector m_evidence; bool m_enable_propagate = true; + symbol m_ematch = symbol("ematch"); euf::enode* const* copy_nodes(clause& c, euf::enode* const* _binding); binding* tmp_binding(clause& c, app* pat, euf::enode* const* _binding); diff --git a/src/sat/smt/q_mbi.cpp b/src/sat/smt/q_mbi.cpp index 539c4f943ba..07d4880c966 100644 --- a/src/sat/smt/q_mbi.cpp +++ b/src/sat/smt/q_mbi.cpp @@ -71,7 +71,7 @@ namespace q { for (auto const& [qlit, fml, inst, generation] : m_instantiations) { euf::solver::scoped_generation sg(ctx, generation + 1); sat::literal lit = ~ctx.mk_literal(fml); - auto* ph = ctx.use_drat()? q_proof_hint::mk(ctx, generation, ~qlit, lit, inst.size(), inst.data()) : nullptr; + auto* ph = ctx.use_drat()? q_proof_hint::mk(ctx, m_mbqi, generation, ~qlit, lit, inst.size(), inst.data()) : nullptr; m_qs.add_clause(~qlit, lit, ph); m_qs.log_instantiation(~qlit, lit); } diff --git a/src/sat/smt/q_mbi.h b/src/sat/smt/q_mbi.h index 96e3ba56f97..71a15be7473 100644 --- a/src/sat/smt/q_mbi.h +++ b/src/sat/smt/q_mbi.h @@ -72,6 +72,7 @@ namespace q { unsigned m_max_choose_candidates = 10; unsigned m_generation_bound = UINT_MAX; unsigned m_generation_max = UINT_MAX; + symbol m_mbqi = symbol("mbqi"); typedef std::tuple instantiation_t; vector m_instantiations; vector m_defs; diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index fff11898c7d..aec10607248 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -364,10 +364,10 @@ namespace q { } } - q_proof_hint* q_proof_hint::mk(euf::solver& s, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings) { + q_proof_hint* q_proof_hint::mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings) { SASSERT(n > 0); auto* mem = s.get_region().allocate(q_proof_hint::get_obj_size(n, lits.size())); - q_proof_hint* ph = new (mem) q_proof_hint(generation, n, lits.size()); + q_proof_hint* ph = new (mem) q_proof_hint(method, generation, n, lits.size()); for (unsigned i = 0; i < n; ++i) ph->m_bindings[i] = bindings[i]->get_expr(); for (unsigned i = 0; i < lits.size(); ++i) @@ -375,10 +375,10 @@ namespace q { return ph; } - q_proof_hint* q_proof_hint::mk(euf::solver& s, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings) { + q_proof_hint* q_proof_hint::mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings) { SASSERT(n > 0); auto* mem = s.get_region().allocate(q_proof_hint::get_obj_size(n, 2)); - q_proof_hint* ph = new (mem) q_proof_hint(generation, n, 2); + q_proof_hint* ph = new (mem) q_proof_hint(method, generation, n, 2); for (unsigned i = 0; i < n; ++i) ph->m_bindings[i] = bindings[i]; ph->m_literals[0] = l1; @@ -402,6 +402,7 @@ namespace q { args.push_back(s.literal2expr(~m_literals[i])); args.push_back(binding); args.push_back(m.mk_app(symbol("gen"), 1, gens, range)); + args.push_back(m.mk_const(m_method, range)); return m.mk_app(symbol("inst"), args.size(), args.data(), range); } diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index d0581f85203..a7220e68b9d 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -30,21 +30,23 @@ namespace euf { namespace q { struct q_proof_hint : public euf::th_proof_hint { + symbol m_method; unsigned m_generation; unsigned m_num_bindings; unsigned m_num_literals; sat::literal* m_literals; expr* m_bindings[0]; - q_proof_hint(unsigned g, unsigned b, unsigned l) { + q_proof_hint(symbol const& method, unsigned g, unsigned b, unsigned l) { + m_method = method; m_generation = g; m_num_bindings = b; m_num_literals = l; m_literals = reinterpret_cast(m_bindings + m_num_bindings); } static size_t get_obj_size(unsigned num_bindings, unsigned num_lits) { return sizeof(q_proof_hint) + num_bindings*sizeof(expr*) + num_lits*sizeof(sat::literal); } - static q_proof_hint* mk(euf::solver& s, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings); - static q_proof_hint* mk(euf::solver& s, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings); + static q_proof_hint* mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings); + static q_proof_hint* mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings); expr* get_hint(euf::solver& s) const override; };