Skip to content

Commit

Permalink
add demodulator tactic based on demodulator-simplifier
Browse files Browse the repository at this point in the history
- some handling for commutative operators
- fix bug in demodulator_index where fwd and bwd are swapped
  • Loading branch information
NikolajBjorner committed Dec 5, 2022
1 parent 8709595 commit de916f5
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 37 deletions.
18 changes: 16 additions & 2 deletions src/ast/simplifiers/demodulator_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,20 @@ void demodulator_index::remove_bwd(expr* e, unsigned i) {
for_each_expr(p, e);
}

std::ostream& demodulator_index::display(std::ostream& out) const {
out << "forward\n";
for (auto& [k, v] : m_fwd_index)
out << mk_pp(k, m) << " : " << *v << "\n";
out << "backward\n";
for (auto& [k, v] : m_bwd_index)
out << mk_pp(k, m) << " : " << *v << "\n";
return out;
}


demodulator_simplifier::demodulator_simplifier(ast_manager& m, params_ref const& p, dependent_expr_state& st):
dependent_expr_simplifier(m, st),
m_index(m),
m_util(m),
m_match_subst(m),
m_rewriter(m),
Expand Down Expand Up @@ -104,18 +116,19 @@ bool demodulator_simplifier::rewrite1(func_decl* f, expr_ref_vector const& args,
if (!m_index.find_fwd(f, set))
return false;

TRACE("demodulator", tout << "trying to rewrite: " << f->get_name() << " args:\n" << args << "\n";);
TRACE("demodulator", tout << "trying to rewrite: " << f->get_name() << " args:" << args << "\n"; m_index.display(tout));

for (unsigned i : *set) {

auto const& [lhs, rhs] = m_rewrites[i];

TRACE("demodulator", tout << "Matching with demodulator: " << i << " " << mk_pp(lhs, m) << "\n");

if (lhs->get_num_args() != args.size())
continue;

SASSERT(lhs->get_decl() == f);

TRACE("demodulator", tout << "Matching with demodulator: " << mk_pp(lhs, m) << "\n");

if (m_match_subst(lhs, rhs, args.data(), np)) {
TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(rhs, m) << "\n===>\n" << np << "\n");
Expand Down Expand Up @@ -186,6 +199,7 @@ void demodulator_simplifier::reduce() {
rewrite(i);
if (m_util.is_demodulator(fml(i), large, small)) {
func_decl* f = large->get_decl();
TRACE("demodulator", tout << i << " " << mk_pp(fml(i), m) << ": " << large << " ==> " << small << "\n");
reschedule_processed(f);
reschedule_demodulators(f, large);
m_index.insert_fwd(f, i);
Expand Down
9 changes: 7 additions & 2 deletions src/ast/simplifiers/demodulator_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@ Module Name:
#include "util/uint_set.h"

class demodulator_index {
ast_manager& m;
obj_map<func_decl, uint_set*> m_fwd_index, m_bwd_index;
void add(func_decl* f, unsigned i, obj_map<func_decl, uint_set*>& map);
void del(func_decl* f, unsigned i, obj_map<func_decl, uint_set*>& map);
public:
demodulator_index(ast_manager& m): m(m) {}
~demodulator_index();
void reset();
void insert_fwd(func_decl* f, unsigned i) { add(f, i, m_fwd_index); }
void remove_fwd(func_decl* f, unsigned i) { del(f, i, m_fwd_index); }
void insert_bwd(expr* e, unsigned i);
void remove_bwd(expr* e, unsigned i);
bool find_fwd(func_decl* f, uint_set*& s) { return m_bwd_index.find(f, s); }
bool find_bwd(func_decl* f, uint_set*& s) { return m_fwd_index.find(f, s); }
bool find_fwd(func_decl* f, uint_set*& s) { return m_fwd_index.find(f, s); }
bool find_bwd(func_decl* f, uint_set*& s) { return m_bwd_index.find(f, s); }
bool empty() const { return m_fwd_index.empty(); }
std::ostream& display(std::ostream& out) const;
};

class demodulator_simplifier : public dependent_expr_simplifier {
Expand All @@ -56,4 +59,6 @@ class demodulator_simplifier : public dependent_expr_simplifier {
demodulator_simplifier(ast_manager& m, params_ref const& p, dependent_expr_state& st);

void reduce() override;

char const* name() const override { return "demodulator"; }
};
92 changes: 59 additions & 33 deletions src/ast/substitution/demodulator_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,19 +852,45 @@ bool demodulator_match_subst::match_args(app * lhs, expr * const * args) {
m_cache.reset();
m_todo.reset();

auto fill_commutative = [&](app* lhs, expr * const* args) {
if (!lhs->get_decl()->is_commutative())
return false;
if (lhs->get_num_args() != 2)
return false;
expr* l1 = lhs->get_arg(0);
expr* l2 = lhs->get_arg(1);
expr* r1 = args[0];
expr* r2 = args[1];

if (is_app(l1) && is_app(r1) && to_app(l1)->get_decl() != to_app(r1)->get_decl()) {
m_all_args_eq = false;
m_todo.push_back(expr_pair(l1, r2));
m_todo.push_back(expr_pair(l2, r1));
return true;
}
if (is_app(l2) && is_app(r2) && to_app(l2)->get_decl() != to_app(r2)->get_decl()) {
m_all_args_eq = false;
m_todo.push_back(expr_pair(l1, r2));
m_todo.push_back(expr_pair(l2, r1));
return true;
}
return false;
};
// fill todo-list, and perform quick success/failure tests
m_all_args_eq = true;
unsigned num_args = lhs->get_num_args();
for (unsigned i = 0; i < num_args; i++) {
expr * t_arg = lhs->get_arg(i);
expr * i_arg = args[i];
if (t_arg != i_arg)
m_all_args_eq = false;
if (is_app(t_arg) && is_app(i_arg) && to_app(t_arg)->get_decl() != to_app(i_arg)->get_decl()) {
// quick failure...
return false;
if (!fill_commutative(lhs, args)) {
for (unsigned i = 0; i < num_args; i++) {
expr * t_arg = lhs->get_arg(i);
expr * i_arg = args[i];
if (t_arg != i_arg)
m_all_args_eq = false;
if (is_app(t_arg) && is_app(i_arg) && to_app(t_arg)->get_decl() != to_app(i_arg)->get_decl()) {
// quick failure...
return false;
}
m_todo.push_back(expr_pair(t_arg, i_arg));
}
m_todo.push_back(expr_pair(t_arg, i_arg));
}

if (m_all_args_eq) {
Expand All @@ -875,48 +901,47 @@ bool demodulator_match_subst::match_args(app * lhs, expr * const * args) {
m_subst.reset();

while (!m_todo.empty()) {
expr_pair const & p = m_todo.back();
auto const & [a, b] = m_todo.back();

if (is_var(p.first)) {
if (is_var(a)) {
expr_offset r;
if (m_subst.find(to_var(p.first), 0, r)) {
if (r.get_expr() != p.second)
if (m_subst.find(to_var(a), 0, r)) {
if (r.get_expr() != b)
return false;
}
else {
m_subst.insert(to_var(p.first), 0, expr_offset(p.second, 1));
m_subst.insert(to_var(a), 0, expr_offset(b, 1));
}
m_todo.pop_back();
continue;
}

if (is_var(p.second))
if (is_var(b))
return false;

// we may have nested quantifiers.
if (is_quantifier(p.first) || is_quantifier(p.second))
if (is_quantifier(a) || is_quantifier(b))
return false;

SASSERT(is_app(p.first) && is_app(p.second));
SASSERT(is_app(a) && is_app(b));

if (to_app(p.first)->is_ground() && !to_app(p.second)->is_ground())
if (to_app(a)->is_ground() && !to_app(b)->is_ground())
return false;

if (p.first == p.second && to_app(p.first)->is_ground()) {
SASSERT(to_app(p.second)->is_ground());
if (a == b && to_app(a)->is_ground()) {
m_todo.pop_back();
continue;
}

if (m_cache.contains(p)) {
if (m_cache.contains(expr_pair(a, b))) {
m_todo.pop_back();
continue;
}

if (p.first == p.second) {
// p.first and p.second is not ground...
if (a == b) {
// a and b is not ground...

// Traverse p.first and check whether every variable X:0 in p.first
// Traverse a and check whether every variable X:0 in a
// 1) is unbounded (then we bind X:0 -> X:1)
// 2) or, is already bounded to X:1
// If that is, the case, we execute:
Expand All @@ -927,19 +952,19 @@ bool demodulator_match_subst::match_args(app * lhs, expr * const * args) {
// return false;
match_args_aux_proc proc(m_subst);
try {
for_each_expr(proc, p.first);
for_each_expr(proc, a);
// succeeded
m_todo.pop_back();
m_cache.insert(p);
m_cache.insert(expr_pair(a, b));
continue;
}
catch (const match_args_aux_proc::no_match &) {
return false;
}
}

app * n1 = to_app(p.first);
app * n2 = to_app(p.second);
app * n1 = to_app(a);
app * n2 = to_app(b);

if (n1->get_decl() != n2->get_decl())
return false;
Expand All @@ -953,12 +978,13 @@ bool demodulator_match_subst::match_args(app * lhs, expr * const * args) {
if (num_args1 == 0)
continue;

m_cache.insert(p);
unsigned j = num_args1;
while (j > 0) {
--j;
m_cache.insert(expr_pair(a, b));

if (fill_commutative(n1, n2->get_args()))
continue;

for (unsigned j = num_args1; j-- > 0; )
m_todo.push_back(expr_pair(n1->get_arg(j), n2->get_arg(j)));
}
}
return true;
}
Expand Down
1 change: 1 addition & 0 deletions src/tactic/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ z3_add_component(core_tactics
cofactor_term_ite_tactic.h
collect_statistics_tactic.h
ctx_simplify_tactic.h
demodulator_tactic.h
der_tactic.h
distribute_forall_tactic.h
dom_simplify_tactic.h
Expand Down
40 changes: 40 additions & 0 deletions src/tactic/core/demodulator_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*++
Copyright (c) 2022 Microsoft Corporation
Module Name:
demodulator_tactic.h
Abstract:
Tactic for solving variables
Author:
Nikolaj Bjorner (nbjorner) 2022-10-30
--*/
#pragma once

#include "util/params.h"
#include "tactic/tactic.h"
#include "tactic/dependent_expr_state_tactic.h"
#include "ast/simplifiers/demodulator_simplifier.h"


class demodulator_tactic_factory : public dependent_expr_simplifier_factory {
public:
dependent_expr_simplifier* mk(ast_manager& m, params_ref const& p, dependent_expr_state& s) override {
return alloc(demodulator_simplifier, m, p, s);
}
};

inline tactic * mk_demodulator_tactic(ast_manager& m, params_ref const& p = params_ref()) {
return alloc(dependent_expr_state_tactic, m, p, alloc(demodulator_tactic_factory));
}

/*
ADD_TACTIC("demodulator", "solve for variables.", "mk_demodulator_tactic(m, p)")
*/


0 comments on commit de916f5

Please sign in to comment.