Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Mar 5, 2024
1 parent bd323d6 commit 1cf008d
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 32 deletions.
4 changes: 3 additions & 1 deletion src/ast/sls/bv_sls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ namespace bv {

bool sls::try_repair_down(app* e) {
unsigned n = e->get_num_args();
if (n == 0)
return false;
unsigned s = m_rand(n);
for (unsigned i = 0; i < n; ++i)
if (try_repair_down(e, (i + s) % n))
Expand Down Expand Up @@ -114,7 +116,7 @@ namespace bv {
if (m.is_bool(e))
return m_eval.bval0(e) == m_eval.bval1(e);
if (bv.is_bv(e))
return 0 == m_eval.wval0(e).eq(m_eval.wval1(e));
return m_eval.wval0(e).eq(m_eval.wval1(e));
UNREACHABLE();
return false;
}
Expand Down
209 changes: 180 additions & 29 deletions src/ast/sls/bv_sls_eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ namespace bv {
m_tmp2.push_back(0);
m_tmp2.push_back(0);
m_zero.push_back(0);
m_one.push_back(0);
m_one[0] = 1;
}
return r;
}
Expand Down Expand Up @@ -207,10 +209,8 @@ namespace bv {
for (unsigned i = a.nw; i < 2 * a.nw; ++i)
if (m_tmp2[i] != 0)
return true;
for (unsigned i = a.bw; i < sizeof(digit_t) * 8 * a.nw; ++i)
if (a.get(m_tmp2, i))
return true;
return false;
return !a.has_overflow(m_tmp);
return true;
};

switch (e->get_decl_kind()) {
Expand Down Expand Up @@ -247,7 +247,7 @@ namespace bv {
auto const& b = wval0(e->get_arg(1));
digit_t c = 0;
mpn.add(a.bits.data(), a.nw, b.bits.data(), b.nw, m_tmp.data(), a.nw, &c);
return c != 0;
return c != 0 || a.has_overflow(m_tmp);
}
case OP_BNEG_OVFL:
case OP_BSADD_OVFL:
Expand Down Expand Up @@ -442,23 +442,110 @@ namespace bv {
for (unsigned i = 0; i < a.bw; ++i)
val.set(val.bits, i, i + sh < a.bw && b.get(b.bits, i + sh));
if (sign)
for (unsigned i = 0; i < sh; ++i)
val.set(val.bits, a.bw - i, true);
val.set_range(val.bits, 0, a.bw - sh, true);
}
break;
}
case OP_BSDIV:
case OP_BSDIV_I:
case OP_BSDIV0:
case OP_BUDIV:
case OP_BUDIV_I:
case OP_BUDIV0:
case OP_SIGN_EXT: {
auto& a = wval0(e->get_arg(0));
a.set(val.bits);
bool sign = a.get(a.bits, a.bw - 1);
val.set_range(val.bits, a.bw, val.bw, sign);
break;
}
case OP_ZERO_EXT: {
auto& a = wval0(e->get_arg(0));
a.set(val.bits);
val.set_range(val.bits, a.bw, val.bw, false);
break;
}
case OP_BUREM:
case OP_BUREM_I:
case OP_BUREM0:
case OP_BUREM0: {
auto& a = wval0(e->get_arg(0));
auto& b = wval0(e->get_arg(1));

if (b.is_zero())
val.set(a.bits);
else {
mpn.div(a.bits.data(), a.nw,
b.bits.data(), b.nw,
m_tmp.data(), // quotient
m_tmp2.data()); // remainder
val.set(m_tmp2);
}
break;
}
case OP_BSMOD:
case OP_BSMOD_I:
case OP_BSMOD0:
case OP_BSMOD0: {
// u = mod(x,y)
// u = 0 -> 0
// y = 0 -> x
// x < 0, y < 0 -> -u
// x < 0, y >= 0 -> y - u
// x >= 0, y < 0 -> y + u
// x >= 0, y >= 0 -> u
auto& a = wval0(e->get_arg(0));
auto& b = wval0(e->get_arg(1));
if (b.is_zero())
val.set(a.bits);
else {
digit_t c;
mpn.div(a.bits.data(), a.nw,
b.bits.data(), b.nw,
m_tmp.data(), // quotient
m_tmp2.data()); // remainder
if (val.is_zero(m_tmp2))
val.set(m_tmp2);
else if (a.sign() && b.sign())
mpn.sub(m_zero.data(), a.nw, m_tmp2.data(), a.nw, m_tmp.data(), &c),
val.set(m_tmp);
else if (a.sign())
mpn.sub(b.bits.data(), a.nw, m_tmp2.data(), a.nw, m_tmp.data(), &c),
val.set(m_tmp);
else if (b.sign())
mpn.add(b.bits.data(), a.nw, m_tmp2.data(), a.nw, m_tmp.data(), a.nw, &c),
val.set(m_tmp);
else
val.set(m_tmp2);
}
break;
}
case OP_BUDIV:
case OP_BUDIV_I:
case OP_BUDIV0: {
// x div 0 = -1
auto& a = wval0(e->get_arg(0));
auto& b = wval0(e->get_arg(1));
if (b.is_zero()) {
val.set(m_zero);
for (unsigned i = 0; i < a.nw; ++i)
val.bits[i] = ~val.bits[i];
}
else {
mpn.div(a.bits.data(), a.nw,
b.bits.data(), b.nw,
m_tmp.data(), // quotient
m_tmp2.data()); // remainder
val.set(m_tmp);
}
break;
}

case OP_BSDIV:
case OP_BSDIV_I:
case OP_BSDIV0:
// d = udiv(abs(x), abs(y))
// y = 0, x > 0 -> 1
// y = 0, x <= 0 -> -1
// x = 0, y != 0 -> 0
// x > 0, y < 0 -> -d
// x < 0, y > 0 -> -d
// x > 0, y > 0 -> d
// x < 0, y < 0 -> d


case OP_BREDAND:
case OP_BREDOR:
case OP_BXNOR:
Expand Down Expand Up @@ -595,10 +682,18 @@ namespace bv {
else
return try_repair_sle(!bval0(e), wval0(e, i), wval0(e, 1 - i));
case OP_BASHR:
return try_repair_ashr(wval0(e), wval0(e, 0), wval0(e, 1), i);
case OP_BLSHR:
return try_repair_lshr(wval0(e), wval0(e, 0), wval0(e, 1), i);
case OP_BSHL:
return try_repair_shl(wval0(e), wval0(e, 0), wval0(e, 1), i);
case OP_BIT2BOOL: {
unsigned idx;
expr* arg;
VERIFY(bv.is_bit2bool(e, arg, idx));
return try_repair_bit2bool(wval0(e, 0), idx);
}
case OP_BCOMP:
case OP_BIT2BOOL:
case OP_BNAND:
case OP_BREDAND:
case OP_BREDOR:
Expand Down Expand Up @@ -751,15 +846,20 @@ namespace bv {
* 8*e = a*(2b), then a = 4e*b^-1
*/
bool sls_eval::try_repair_mul(bvval const& e, bvval& a, bvval const& b) {
unsigned parity_e = e.parity(e.bits);
unsigned parity_b = b.parity(b.bits);
if (parity_e < parity_b)
if (b.is_zero()) {
if (a.is_zero()) {
a.set(m_tmp, 1);
return a.try_set(m_tmp);
}
return false;
}
rational ne, nb;
e.get_value(e.bits, ne);
b.get_value(b.bits, nb);
unsigned parity_e = e.parity(e.bits);
unsigned parity_b = b.parity(b.bits);
if (parity_b > 0)
ne /= rational::power_of_two(parity_b);
ne /= rational::power_of_two(std::min(parity_b, parity_e));
auto inv_b = nb.pseudo_inverse(b.bw);
rational na = mod(inv_b * ne, rational::power_of_two(a.bw));
a.set_value(m_tmp, na);
Expand All @@ -774,18 +874,16 @@ namespace bv {

bool sls_eval::try_repair_bneg(bvval const& e, bvval& a) {
digit_t c;
mpn.sub(m_zero.data(), e.nw, e.bits.data(), e.nw, m_tmp.data(), &c);
mpn.sub(m_zero.data(), e.nw, e.bits.data(), e.nw, m_tmp.data(), &c);
return a.try_set(m_tmp);
}

bool sls_eval::try_repair_ule(bool e, bvval& a, bvval const& b) {
if (e)
return a.try_set(b.bits);
else {
digit_t c;
a.set(m_zero, 0, true);
mpn.add(b.bits.data(), a.nw, m_zero.data(), a.nw, &c, a.nw, m_tmp.data());
a.set(m_zero, 0, false);
digit_t c;
mpn.add(b.bits.data(), a.nw, m_one.data(), a.nw, &c, a.nw, m_tmp.data());
return a.try_set(m_tmp);
}
}
Expand All @@ -795,18 +893,71 @@ namespace bv {
return a.try_set(b.bits);
else {
digit_t c;
a.set(m_zero, 0, true);
mpn.sub(b.bits.data(), a.nw, m_zero.data(), a.nw, m_tmp.data(), &c);
a.set(m_zero, 0, false);
mpn.sub(b.bits.data(), a.nw, m_one.data(), a.nw, m_tmp.data(), &c);
return a.try_set(m_tmp);
}
}

bool sls_eval::try_repair_sle(bool e, bvval& a, bvval const& b) {
return false;
return try_repair_ule(e, a, b);
}

bool sls_eval::try_repair_sge(bool e, bvval& a, bvval const& b) {
return try_repair_uge(e, a, b);
}

bool sls_eval::try_repair_bit2bool(bvval& a, unsigned idx) {
for (unsigned i = 0; i < a.nw; ++i)
m_tmp[i] = a.bits[i];
a.set(m_tmp, idx, !a.get(a.bits, idx));
return a.try_set(m_tmp);
}

bool sls_eval::try_repair_shl(bvval const& e, bvval& a, bvval& b, unsigned i) {
if (i == 0) {
unsigned sh = b.to_nat(b.bits, b.bw);
if (sh == 0)
return a.try_set(e.bits);
else if (sh >= b.bw) {
return false;
}
else {
//
// e = a << sh
// set bw - sh low order bits to bw - sh high-order of e.
// a[bw - sh - 1: 0] = e[bw - 1: sh]
// a[bw - 1: bw - sh] = unchanged
//
for (unsigned i = 0; i < e.bw - sh; ++i)
e.set(m_tmp, i, e.get(e.bits, sh + i));
for (unsigned i = e.bw - sh; i < e.bw; ++i)
e.set(m_tmp, i, e.get(a.bits, i));
return a.try_set(m_tmp);
}
}
else {
SASSERT(i == 1);
}
return false;
}

bool sls_eval::try_repair_ashr(bvval const& e, bvval & a, bvval& b, unsigned i) {
if (i == 0) {

}
else {

}
return false;
}

bool sls_eval::try_repair_lshr(bvval const& e, bvval& a, bvval& b, unsigned i) {
if (i == 0) {

}
else {

}
return false;
}

Expand Down
8 changes: 6 additions & 2 deletions src/ast/sls/bv_sls_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ namespace bv {
ast_manager& m;
bv_util bv;
sls_fixed m_fix;
mpn_manager mpn;
mutable mpn_manager mpn;
ptr_vector<expr> m_todo;
random_gen m_rand;

scoped_ptr_vector<sls_valuation> m_values0, m_values1; // expr-id -> bv valuation
bool_vector m_eval; // expr-id -> boolean valuation
bool_vector m_fixed; // expr-id -> is Boolean fixed

mutable svector<digit_t> m_tmp, m_tmp2, m_zero;
mutable svector<digit_t> m_tmp, m_tmp2, m_zero, m_one;

using bvval = sls_valuation;

Expand Down Expand Up @@ -78,6 +78,10 @@ namespace bv {
bool try_repair_uge(bool e, bvval& a, bvval const& b);
bool try_repair_sle(bool e, bvval& a, bvval const& b);
bool try_repair_sge(bool e, bvval& a, bvval const& b);
bool try_repair_shl(bvval const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_ashr(bvval const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_lshr(bvval const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_bit2bool(bvval& a, unsigned idx);

sls_valuation& wval0(app* e, unsigned i) { return wval0(e->get_arg(i)); }

Expand Down
29 changes: 29 additions & 0 deletions src/ast/sls/sls_valuation.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ namespace bv {
return 0 > memcmp(a.data(), b.data(), num_bytes());
}

bool is_zero() const { return is_zero(bits); }
bool is_zero(svector<digit_t> const& a) const {
for (unsigned i = 0; i < nw; ++i)
if (a[i] != 0)
return false;
return true;
}

bool sign() const { return get(bits, bw - 1); }

bool has_overflow(svector<digit_t> const& bits) const {
for (unsigned i = bw; i < nw * sizeof(digit_t) * 8; ++i)
if (get(bits, i))
return true;
return false;
}

unsigned parity(svector<digit_t> const& bits) const {
unsigned i = 0;
for (; i < bw && !get(bits, i); ++i);
Expand All @@ -73,16 +90,28 @@ namespace bv {
clear_overflow_bits(bits);
}


void set_fixed(svector<digit_t> const& src) {
for (unsigned i = nw; i-- > 0; )
fixed[i] = src[i];
}

void set_range(svector<digit_t>& dst, unsigned lo, unsigned hi, bool b) {
for (unsigned i = lo; i < hi; ++i)
set(dst, i, b);
}

void set(svector<digit_t>& d, unsigned bit_idx, bool val) const {
auto _val = static_cast<digit_t>(0 - static_cast<digit_t>(val));
get_bit_word(d, bit_idx) ^= (_val ^ get_bit_word(d, bit_idx)) & get_pos_mask(bit_idx);
}

void set(svector<digit_t>& dst, unsigned v) const {
dst[0] = v;
for (unsigned i = 1; i < nw; ++i)
dst[i] = 0;
}

bool get(svector<digit_t> const& d, unsigned bit_idx) const {
return (get_bit_word(d, bit_idx) & get_pos_mask(bit_idx)) != 0;
}
Expand Down

0 comments on commit 1cf008d

Please sign in to comment.