diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 92d7cb120e1..cf21e13b87f 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -66,7 +66,7 @@ namespace arith { vector> m_literals; svector> m_eqs; hint_type m_ty; - unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail; + unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0; void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } void add(euf::enode* a, euf::enode* b, bool is_eq) { if (m_eq_tail < m_eqs.size()) diff --git a/src/sat/smt/bv_internalize.cpp b/src/sat/smt/bv_internalize.cpp index 04fde22f1e2..8199f539eae 100644 --- a/src/sat/smt/bv_internalize.cpp +++ b/src/sat/smt/bv_internalize.cpp @@ -431,7 +431,7 @@ namespace bv { sat::literal lit = eq_internalize(n, sum); m_bv2ints.push_back(expr2enode(n)); ctx.push(push_back_vector(m_bv2ints)); - add_unit(lit); + add_unit(lit); } void solver::internalize_int2bv(app* n) { @@ -460,8 +460,8 @@ namespace bv { unsigned sz = bv.get_bv_size(n); numeral mod = power(numeral(2), sz); rhs = m_autil.mk_mod(e, m_autil.mk_int(mod)); - sat::literal eq_lit = eq_internalize(lhs, rhs); - add_unit(eq_lit); + sat::literal eq_lit = eq_internalize(lhs, rhs); + add_unit(eq_lit); expr_ref_vector n_bits(m); get_bits(n_enode, n_bits); @@ -472,8 +472,8 @@ namespace bv { rhs = m_autil.mk_mod(rhs, m_autil.mk_int(2)); rhs = mk_eq(rhs, m_autil.mk_int(1)); lhs = n_bits.get(i); - eq_lit = eq_internalize(lhs, rhs); - add_unit(eq_lit); + eq_lit = eq_internalize(lhs, rhs); + add_unit(eq_lit); } } diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index 33ff829d69f..7156058c757 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -216,10 +216,11 @@ namespace bv { euf::enode* bv2int_arg = bv2int->get_arg(0); for (euf::enode* p : euf::enode_parents(n1->get_root())) { if (bv.is_int2bv(p->get_expr()) && p->get_sort() == bv2int_arg->get_sort() && p->get_root() != bv2int_arg->get_root()) { - euf::enode_pair_vector eqs; - eqs.push_back({ n1, p->get_arg(0) }); - eqs.push_back({ n1, bv2int }); - ctx.propagate(p, bv2int_arg, euf::th_explain::propagate(*this, eqs, p, bv2int_arg)); + theory_var v1 = get_th_var(p); + theory_var v2 = get_th_var(bv2int_arg); + SASSERT(v1 != euf::null_theory_var); + SASSERT(v2 != euf::null_theory_var); + ctx.propagate(p, bv2int_arg, mk_bv2int_justification(v1, v2, n1, p->get_arg(0), bv2int)); break; } } @@ -379,6 +380,11 @@ namespace bv { r.push_back(b); break; } + case bv_justification::kind_t::bv2int: { + ctx.add_antecedent(c.a, c.b); + ctx.add_antecedent(c.a, c.c); + break; + } } if (!probing && ctx.use_drat()) log_drat(c); @@ -386,19 +392,26 @@ namespace bv { void solver::log_drat(bv_justification const& c) { // introduce dummy literal for equality. - sat::literal leq(s().num_vars() + 1, false); - expr_ref eq(m); - if (c.m_kind != bv_justification::kind_t::bit2ne) { + sat::literal leq1(s().num_vars() + 1, false); + sat::literal leq2(s().num_vars() + 2, false); + expr_ref eq1(m), eq2(m); + if (c.m_kind == bv_justification::kind_t::bv2int) { + eq1 = m.mk_eq(c.a->get_expr(), c.b->get_expr()); + eq2 = m.mk_eq(c.a->get_expr(), c.c->get_expr()); + ctx.set_tmp_bool_var(leq1.var(), eq1); + ctx.set_tmp_bool_var(leq2.var(), eq1); + } + else if (c.m_kind != bv_justification::kind_t::bit2ne) { expr* e1 = var2expr(c.m_v1); expr* e2 = var2expr(c.m_v2); - eq = m.mk_eq(e1, e2); - ctx.set_tmp_bool_var(leq.var(), eq); + eq1 = m.mk_eq(e1, e2); + ctx.set_tmp_bool_var(leq1.var(), eq1); } sat::literal_vector lits; switch (c.m_kind) { case bv_justification::kind_t::eq2bit: - lits.push_back(~leq); + lits.push_back(~leq1); lits.push_back(~c.m_antecedent); lits.push_back(c.m_consequent); break; @@ -407,10 +420,10 @@ namespace bv { lits.push_back(c.m_consequent); break; case bv_justification::kind_t::bit2eq: - get_antecedents(leq, c.to_index(), lits, true); + get_antecedents(leq1, c.to_index(), lits, true); for (auto& lit : lits) lit.neg(); - lits.push_back(leq); + lits.push_back(leq1); break; case bv_justification::kind_t::bit2ne: get_antecedents(c.m_consequent, c.to_index(), lits, true); @@ -418,6 +431,14 @@ namespace bv { lit.neg(); lits.push_back(c.m_consequent); break; + case bv_justification::kind_t::bv2int: + get_antecedents(leq1, c.to_index(), lits, true); + get_antecedents(leq2, c.to_index(), lits, true); + for (auto& lit : lits) + lit.neg(); + lits.push_back(leq1); + lits.push_back(leq2); + break; } ctx.get_drat().add(lits, status()); // TBD, a proper way would be to delete the lemma after use. @@ -665,7 +686,9 @@ namespace bv { return out << "bv <- v" << v1 << "[" << cidx << "] != v" << v2 << "[" << cidx << "] " << m_bits[v1][cidx] << " != " << m_bits[v2][cidx]; } case bv_justification::kind_t::ne2bit: - return out << "bv <- " << m_bits[v1] << " != " << m_bits[v2] << " @" << cidx; + return out << "bv <- " << m_bits[v1] << " != " << m_bits[v2] << " @" << cidx; + case bv_justification::kind_t::bv2int: + return out << "bv <- v" << v1 << " == v" << v2 << " <== " << ctx.bpp(c.a) << " == " << ctx.bpp(c.b) << " == " << ctx.bpp(c.c); default: UNREACHABLE(); break; @@ -818,28 +841,41 @@ namespace bv { void* mem = get_region().allocate(bv_justification::get_obj_size()); sat::constraint_base::initialize(mem, this); auto* constraint = new (sat::constraint_base::ptr2mem(mem)) bv_justification(v1, v2, c, a); - return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + auto jst = sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + TRACE("bv", tout << jst << " " << constraint << "\n"); + return jst; } sat::ext_justification_idx solver::mk_bit2eq_justification(theory_var v1, theory_var v2) { void* mem = get_region().allocate(bv_justification::get_obj_size()); sat::constraint_base::initialize(mem, this); auto* constraint = new (sat::constraint_base::ptr2mem(mem)) bv_justification(v1, v2); - return constraint->to_index(); + auto jst = constraint->to_index(); + return jst; } sat::justification solver::mk_bit2ne_justification(unsigned idx, sat::literal c) { void* mem = get_region().allocate(bv_justification::get_obj_size()); sat::constraint_base::initialize(mem, this); auto* constraint = new (sat::constraint_base::ptr2mem(mem)) bv_justification(idx, c); - return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + auto jst = sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + return jst; } sat::justification solver::mk_ne2bit_justification(unsigned idx, theory_var v1, theory_var v2, sat::literal c, sat::literal a) { void* mem = get_region().allocate(bv_justification::get_obj_size()); sat::constraint_base::initialize(mem, this); auto* constraint = new (sat::constraint_base::ptr2mem(mem)) bv_justification(idx, v1, v2, c, a); - return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + auto jst = sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + return jst; + } + + sat::ext_constraint_idx solver::mk_bv2int_justification(theory_var v1, theory_var v2, euf::enode* a, euf::enode* b, euf::enode* c) { + void* mem = get_region().allocate(bv_justification::get_obj_size()); + sat::constraint_base::initialize(mem, this); + auto* constraint = new (sat::constraint_base::ptr2mem(mem)) bv_justification(v1, v2, a, b, c); + auto jst = constraint->to_index(); + return jst; } bool solver::assign_bit(literal consequent, theory_var v1, theory_var v2, unsigned idx, literal antecedent, bool propagate_eqc) { diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index b77343e3055..4166ebf7bec 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -51,13 +51,21 @@ namespace bv { }; struct bv_justification { - enum kind_t { eq2bit, ne2bit, bit2eq, bit2ne }; + enum kind_t { eq2bit, ne2bit, bit2eq, bit2ne, bv2int }; kind_t m_kind; - unsigned m_idx{ UINT_MAX }; - theory_var m_v1{ euf::null_theory_var }; - theory_var m_v2 { euf::null_theory_var }; - sat::literal m_consequent; - sat::literal m_antecedent; + unsigned m_idx = UINT_MAX; + theory_var m_v1 = euf::null_theory_var; + theory_var m_v2 = euf::null_theory_var; + union { + struct { + sat::literal m_consequent; + sat::literal m_antecedent; + }; + struct { + euf::enode* a, *b, *c; + }; + }; + bv_justification(theory_var v1, theory_var v2, sat::literal c, sat::literal a) : m_kind(bv_justification::kind_t::eq2bit), m_v1(v1), m_v2(v2), m_consequent(c), m_antecedent(a) {} bv_justification(theory_var v1, theory_var v2): @@ -66,6 +74,8 @@ namespace bv { m_kind(bv_justification::kind_t::bit2ne), m_idx(idx), m_consequent(c) {} bv_justification(unsigned idx, theory_var v1, theory_var v2, sat::literal c, sat::literal a) : m_kind(bv_justification::kind_t::ne2bit), m_idx(idx), m_v1(v1), m_v2(v2), m_consequent(c), m_antecedent(a) {} + bv_justification(theory_var v1, theory_var v2, euf::enode* a, euf::enode* b, euf::enode* c): + m_kind(bv_justification::kind_t::bv2int), m_v1(v1), m_v2(v2), a(a), b(b), c(c) {} sat::ext_constraint_idx to_index() const { return sat::constraint_base::mem2base(this); } @@ -79,6 +89,7 @@ namespace bv { sat::ext_justification_idx mk_bit2eq_justification(theory_var v1, theory_var v2); sat::justification mk_bit2ne_justification(unsigned idx, sat::literal c); sat::justification mk_ne2bit_justification(unsigned idx, theory_var v1, theory_var v2, sat::literal c, sat::literal a); + sat::ext_constraint_idx mk_bv2int_justification(theory_var v1, theory_var v2, euf::enode* a, euf::enode* b, euf::enode* c); void log_drat(bv_justification const& c);