Skip to content

Commit

Permalink
experimental feature to access congruence closure of SimpleSolver
Browse files Browse the repository at this point in the history
This update includes an experimental feature to access a congruence closure data-structure after search.
It comes with several caveats as pre-processing is free to eliminate terms. It is therefore necessary to use a solver that does not eliminate the terms you want to track for congruence of. This is partially addressed by using SimpleSolver or incremental mode solving.

```python
from z3 import *
s = SimpleSolver()
x, y, z = Ints('x y z')
s.add(x == y)
s.add(y == z)
s.check()
print(s.root(x), s.root(y), s.root(z))
print(s.next(x), s.next(y), s.next(z))
```
  • Loading branch information
NikolajBjorner committed Dec 31, 2022
1 parent c0f1f33 commit f6d411d
Show file tree
Hide file tree
Showing 21 changed files with 145 additions and 12 deletions.
20 changes: 20 additions & 0 deletions src/api/api_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,26 @@ extern "C" {
Z3_CATCH_RETURN(nullptr);
}

Z3_ast Z3_API Z3_solver_congruence_root(Z3_context c, Z3_solver s, Z3_ast a) {
Z3_TRY;
LOG_Z3_solver_congruence_root(c, s, a);
RESET_ERROR_CODE();
init_solver(c, s);
expr* r = to_solver_ref(s)->congruence_root(to_expr(a));
RETURN_Z3(of_expr(r));
Z3_CATCH_RETURN(nullptr);
}

Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a) {
Z3_TRY;
LOG_Z3_solver_congruence_next(c, s, a);
RESET_ERROR_CODE();
init_solver(c, s);
expr* sib = to_solver_ref(s)->congruence_next(to_expr(a));
RETURN_Z3(of_expr(sib));
Z3_CATCH_RETURN(nullptr);
}

class api_context_obj : public user_propagator::context_obj {
api::context* c;
public:
Expand Down
16 changes: 16 additions & 0 deletions src/api/python/z3/z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7241,6 +7241,22 @@ def cube_vars(self):
cube are likely more useful to cube on."""
return self.cube_vs

def root(self, t):
t = _py2expr(t, self.ctx)
"""Retrieve congruence closure root of the term t relative to the current search state
The function primarily works for SimpleSolver. Terms and variables that are
eliminated during pre-processing are not visible to the congruence closure.
"""
return _to_expr_ref(Z3_solver_congruence_root(self.ctx.ref(), self.solver, t.ast), self.ctx)

def next(self, t):
t = _py2expr(t, self.ctx)
"""Retrieve congruence closure sibling of the term t relative to the current search state
The function primarily works for SimpleSolver. Terms and variables that are
eliminated during pre-processing are not visible to the congruence closure.
"""
return _to_expr_ref(Z3_solver_congruence_next(self.ctx.ref(), self.solver, t.ast), self.ctx)

def proof(self):
"""Return a proof for the last `check()`. Proof construction must be enabled."""
return _to_expr_ref(Z3_solver_get_proof(self.ctx.ref(), self.solver), self.ctx)
Expand Down
20 changes: 20 additions & 0 deletions src/api/z3_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6882,6 +6882,26 @@ extern "C" {
*/
void Z3_API Z3_solver_get_levels(Z3_context c, Z3_solver s, Z3_ast_vector literals, unsigned sz, unsigned levels[]);

/**
\brief retrieve the congruence closure root of an expression.
The root is retrieved relative to the state where the solver was in when it completed.
If it completed during a set of case splits, the congruence roots are relative to these case splits.
That is, the congruences are not consequences but they are true under the current state.
def_API('Z3_solver_congruence_root', AST, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/
Z3_ast Z3_API Z3_solver_congruence_root(Z3_context c, Z3_solver s, Z3_ast a);


/**
\brief retrieve the next expression in the congruence class. The set of congruent siblings form a cyclic list.
Repeated calls on the siblings will result in returning to the original expression.
def_API('Z3_solver_congruence_next', AST, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/
Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a);


/**
\brief register a callback to that retrieves assumed, inferred and deleted clauses during search.
Expand Down
21 changes: 14 additions & 7 deletions src/ast/converters/expr_inverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class basic_expr_inverter : public iexpr_inverter {
*
*/

bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override {
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override {
SASSERT(f->get_family_id() == m.get_basic_family_id());
switch (f->get_decl_kind()) {
case OP_ITE:
Expand Down Expand Up @@ -233,7 +233,7 @@ class arith_expr_inverter : public iexpr_inverter {
}


bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override {
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override {
SASSERT(f->get_family_id() == a.get_family_id());
switch (f->get_decl_kind()) {
case OP_ADD:
Expand Down Expand Up @@ -531,7 +531,7 @@ class bv_expr_inverter : public iexpr_inverter {
* y := 0
*
*/
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override {
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override {
SASSERT(f->get_family_id() == bv.get_family_id());
switch (f->get_decl_kind()) {
case OP_BADD:
Expand Down Expand Up @@ -611,7 +611,7 @@ class array_expr_inverter : public iexpr_inverter {

family_id get_fid() const override { return a.get_family_id(); }

bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override {
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override {
SASSERT(f->get_family_id() == a.get_family_id());
switch (f->get_decl_kind()) {
case OP_SELECT:
Expand Down Expand Up @@ -679,7 +679,7 @@ class dt_expr_inverter : public iexpr_inverter {
* head(x) -> fresh
* x := cons(fresh, arb)
*/
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override {
bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override {
if (dt.is_accessor(f)) {
SASSERT(num == 1);
if (uncnstr(args[0])) {
Expand Down Expand Up @@ -799,7 +799,7 @@ expr_inverter::expr_inverter(ast_manager& m): iexpr_inverter(m) {
}


bool expr_inverter::operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& new_expr, expr_ref& side_cond) {
bool expr_inverter::operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& new_expr, proof_ref& pr) {
if (num == 0)
return false;

Expand All @@ -812,7 +812,7 @@ bool expr_inverter::operator()(func_decl* f, unsigned num, expr* const* args, ex
return false;

auto* p = m_inverters.get(fid, nullptr);
return p && (*p)(f, num, args, new_expr, side_cond);
return p && (*p)(f, num, args, new_expr, pr);
}

bool expr_inverter::mk_diff(expr* t, expr_ref& r) {
Expand Down Expand Up @@ -849,3 +849,10 @@ void expr_inverter::set_model_converter(generic_model_converter* mc) {
if (p)
p->set_model_converter(mc);
}

void expr_inverter::set_produce_proofs(bool pr) {
m_produce_proofs = pr;
for (auto* p : m_inverters)
if (p)
p->set_produce_proofs(pr);
}
7 changes: 5 additions & 2 deletions src/ast/converters/expr_inverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class iexpr_inverter {
ast_manager& m;
std::function<bool(expr*)> m_is_var;
generic_model_converter_ref m_mc;
bool m_produce_proofs = false;

bool uncnstr(expr* e) const { return m_is_var(e); }
bool uncnstr(unsigned num, expr * const * args) const;
Expand All @@ -37,8 +38,9 @@ class iexpr_inverter {
virtual ~iexpr_inverter() {}
virtual void set_is_var(std::function<bool(expr*)>& is_var) { m_is_var = is_var; }
virtual void set_model_converter(generic_model_converter* mc) { m_mc = mc; }
virtual void set_produce_proofs(bool p) { m_produce_proofs = true; }

virtual bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, expr_ref& side_cond) = 0;
virtual bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, proof_ref& pr) = 0;
virtual bool mk_diff(expr* t, expr_ref& r) = 0;
virtual family_id get_fid() const = 0;
};
Expand All @@ -49,9 +51,10 @@ class expr_inverter : public iexpr_inverter {
public:
expr_inverter(ast_manager& m);
~expr_inverter() override;
bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, expr_ref& side_cond) override;
bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, proof_ref& pr) override;
bool mk_diff(expr* t, expr_ref& r) override;
void set_is_var(std::function<bool(expr*)>& is_var) override;
void set_model_converter(generic_model_converter* mc) override;
void set_produce_proofs(bool p) override;
family_id get_fid() const override { return null_family_id; }
};
7 changes: 4 additions & 3 deletions src/ast/simplifiers/elim_unconstrained.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ bool elim_unconstrained::is_var_lt(int v1, int v2) const {
void elim_unconstrained::eliminate() {

while (!m_heap.empty()) {
expr_ref r(m), side_cond(m);
expr_ref r(m);
proof_ref pr(m);
int v = m_heap.erase_min();
node& n = get_node(v);
if (n.m_refcount == 0)
Expand All @@ -84,7 +85,7 @@ void elim_unconstrained::eliminate() {
unsigned sz = m_args.size();
for (expr* arg : *to_app(t))
m_args.push_back(reconstruct_term(get_node(arg)));
bool inverted = m_inverter(t->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz, r, side_cond);
bool inverted = m_inverter(t->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz, r, pr);
n.m_refcount = 0;
m_args.shrink(sz);
if (!inverted) {
Expand Down Expand Up @@ -113,7 +114,7 @@ void elim_unconstrained::eliminate() {

IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(get_node(v).m_orig, m) << " " << mk_bounded_pp(t, m) << " -> " << r << " " << get_node(e).m_refcount << "\n";);

SASSERT(!side_cond && "not implemented to add side conditions\n");
SASSERT(!pr && "not implemented to add proofs\n");
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/muz/spacer/spacer_iuc_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class iuc_solver : public solver {
void set_phase(phase* p) override { m_solver.set_phase(p); }
void move_to_front(expr* e) override { m_solver.move_to_front(e); }
expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); }
expr* congruence_root(expr* e) override { return e; }
expr* congruence_next(expr* e) override { return e; }
void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); }
expr_ref_vector get_trail(unsigned max_level) override { return m_solver.get_trail(max_level); }

Expand Down
2 changes: 2 additions & 0 deletions src/opt/opt_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ namespace opt {
void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override;
expr_ref_vector get_trail(unsigned max_level) override { return m_context.get_trail(max_level); }
expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); }
expr* congruence_root(expr* e) override { return e; }
expr* congruence_next(expr* e) override { return e; }
void set_phase(expr* e) override { m_context.set_phase(e); }
phase* get_phase() override { return m_context.get_phase(); }
void set_phase(phase* p) override { m_context.set_phase(p); }
Expand Down
4 changes: 4 additions & 0 deletions src/sat/sat_solver/inc_sat_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ class inc_sat_solver : public solver {
}
return fmls;
}

expr* congruence_next(expr* e) override { return e; }
expr* congruence_root(expr* e) override { return e; }


lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) override {
init_preprocess();
Expand Down
3 changes: 3 additions & 0 deletions src/sat/sat_solver/sat_smt_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ class sat_smt_solver : public solver {
set_reason_unknown(m_solver.get_reason_unknown());
return fmls;
}

expr* congruence_next(expr* e) override { return e; }
expr* congruence_root(expr* e) override { return e; }


lbool find_mutexes(expr_ref_vector const& vars, vector<expr_ref_vector>& mutexes) override {
Expand Down
14 changes: 14 additions & 0 deletions src/smt/smt_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,20 @@ namespace smt {
return out;
}

expr* kernel::congruence_root(expr * e) {
smt::enode* n = m_imp->m_kernel.find_enode(e);
if (!n)
return e;
return n->get_root()->get_expr();
}

expr* kernel::congruence_next(expr * e) {
smt::enode* n = m_imp->m_kernel.find_enode(e);
if (!n)
return e;
return n->get_next()->get_expr();
}

void kernel::collect_statistics(::statistics & st) const {
m_imp->m_kernel.collect_statistics(st);
}
Expand Down
7 changes: 7 additions & 0 deletions src/smt/smt_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ namespace smt {
*/
expr_ref_vector cubes(unsigned depth);

/**
\brief access congruence closure
*/
expr* congruence_next(expr* e);

expr* congruence_root(expr* e);


/**
\brief retrieve depth of variables from decision stack.
Expand Down
4 changes: 4 additions & 0 deletions src/smt/smt_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ namespace {
m_context.get_units(units);
}

expr* congruence_next(expr* e) override { return m_context.congruence_next(e); }
expr* congruence_root(expr* e) override { return m_context.congruence_root(e); }


expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override {
ast_manager& m = get_manager();
if (!m_cuber) {
Expand Down
4 changes: 4 additions & 0 deletions src/solver/combined_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ class combined_solver : public solver {
return m_solver2->cube(vars, backtrack_level);
}

expr* congruence_next(expr* e) override { switch_inc_mode(); return m_solver2->congruence_next(e); }
expr* congruence_root(expr* e) override { switch_inc_mode(); return m_solver2->congruence_root(e); }


expr * get_assumption(unsigned idx) const override {
unsigned c1 = m_solver1->get_num_assumptions();
if (idx < c1) return m_solver1->get_assumption(idx);
Expand Down
9 changes: 9 additions & 0 deletions src/solver/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ class solver : public check_sat_result, public user_propagator::core {

virtual expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) = 0;

/**
\brief retrieve congruence closure root.
*/
virtual expr* congruence_root(expr* e) = 0;

/**
\brief retrieve congruence closure sibling
*/
virtual expr* congruence_next(expr* e) = 0;

/**
\brief Display the content of this solver.
Expand Down
3 changes: 3 additions & 0 deletions src/solver/solver_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ class pool_solver : public solver_na2as {

expr_ref_vector cube(expr_ref_vector& vars, unsigned ) override { return expr_ref_vector(m); }

expr* congruence_next(expr* e) override { return e; }
expr* congruence_root(expr* e) override { return e; }

ast_manager& get_manager() const override { return m_base->get_manager(); }

void refresh(solver* new_base) {
Expand Down
3 changes: 3 additions & 0 deletions src/solver/tactic2solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ class tactic2solver : public solver_na2as {
return expr_ref_vector(get_manager());
}

expr* congruence_next(expr* e) override { return e; }
expr* congruence_root(expr* e) override { return e; }

model_converter_ref get_model_converter() const override { return m_mc; }

void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override {
Expand Down
2 changes: 2 additions & 0 deletions src/tactic/fd_solver/bounded_int2bv_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class bounded_int2bv_solver : public solver_na2as {
void set_reason_unknown(char const* msg) override { m_solver->set_reason_unknown(msg); }
void get_labels(svector<symbol> & r) override { m_solver->get_labels(r); }
ast_manager& get_manager() const override { return m; }
expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); }
expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); }
expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); }
lbool find_mutexes(expr_ref_vector const& vars, vector<expr_ref_vector>& mutexes) override { return m_solver->find_mutexes(vars, mutexes); }
lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override {
Expand Down
3 changes: 3 additions & 0 deletions src/tactic/fd_solver/enum2bv_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class enum2bv_solver : public solver_na2as {
expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override {
return m_solver->cube(vars, backtrack_level);
}
expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); }
expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); }


lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override {
datatype_util dt(m);
Expand Down
2 changes: 2 additions & 0 deletions src/tactic/fd_solver/pb2bv_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class pb2bv_solver : public solver_na2as {
void get_labels(svector<symbol> & r) override { m_solver->get_labels(r); }
ast_manager& get_manager() const override { return m; }
expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); }
expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); }
expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); }
lbool find_mutexes(expr_ref_vector const& vars, vector<expr_ref_vector>& mutexes) override { return m_solver->find_mutexes(vars, mutexes); }
lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override {
flush_assertions();
Expand Down
4 changes: 4 additions & 0 deletions src/tactic/fd_solver/smtfd_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,10 @@ namespace smtfd {
expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override {
return expr_ref_vector(m);
}

expr* congruence_root(expr* e) override { return e; }

expr* congruence_next(expr* e) override { return e; }

lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override {
return l_undef;
Expand Down

0 comments on commit f6d411d

Please sign in to comment.