From 1cf008dd0a4433ed14442822bddd1611ddffac07 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 14 Feb 2024 16:59:52 +0700 Subject: [PATCH] updates --- src/ast/sls/bv_sls.cpp | 4 +- src/ast/sls/bv_sls_eval.cpp | 209 +++++++++++++++++++++++++++++++----- src/ast/sls/bv_sls_eval.h | 8 +- src/ast/sls/sls_valuation.h | 29 +++++ 4 files changed, 218 insertions(+), 32 deletions(-) diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp index 33b63537635..46b2c2183d4 100644 --- a/src/ast/sls/bv_sls.cpp +++ b/src/ast/sls/bv_sls.cpp @@ -79,6 +79,8 @@ namespace bv { bool sls::try_repair_down(app* e) { unsigned n = e->get_num_args(); + if (n == 0) + return false; unsigned s = m_rand(n); for (unsigned i = 0; i < n; ++i) if (try_repair_down(e, (i + s) % n)) @@ -114,7 +116,7 @@ namespace bv { if (m.is_bool(e)) return m_eval.bval0(e) == m_eval.bval1(e); if (bv.is_bv(e)) - return 0 == m_eval.wval0(e).eq(m_eval.wval1(e)); + return m_eval.wval0(e).eq(m_eval.wval1(e)); UNREACHABLE(); return false; } diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/bv_sls_eval.cpp index 990d87dbbed..fa4b60538ef 100644 --- a/src/ast/sls/bv_sls_eval.cpp +++ b/src/ast/sls/bv_sls_eval.cpp @@ -94,6 +94,8 @@ namespace bv { m_tmp2.push_back(0); m_tmp2.push_back(0); m_zero.push_back(0); + m_one.push_back(0); + m_one[0] = 1; } return r; } @@ -207,10 +209,8 @@ namespace bv { for (unsigned i = a.nw; i < 2 * a.nw; ++i) if (m_tmp2[i] != 0) return true; - for (unsigned i = a.bw; i < sizeof(digit_t) * 8 * a.nw; ++i) - if (a.get(m_tmp2, i)) - return true; - return false; + return !a.has_overflow(m_tmp); + return true; }; switch (e->get_decl_kind()) { @@ -247,7 +247,7 @@ namespace bv { auto const& b = wval0(e->get_arg(1)); digit_t c = 0; mpn.add(a.bits.data(), a.nw, b.bits.data(), b.nw, m_tmp.data(), a.nw, &c); - return c != 0; + return c != 0 || a.has_overflow(m_tmp); } case OP_BNEG_OVFL: case OP_BSADD_OVFL: @@ -442,23 +442,110 @@ namespace bv { for (unsigned i = 0; i < a.bw; ++i) val.set(val.bits, i, i + sh < a.bw && b.get(b.bits, i + sh)); if (sign) - for (unsigned i = 0; i < sh; ++i) - val.set(val.bits, a.bw - i, true); + val.set_range(val.bits, 0, a.bw - sh, true); } break; } - case OP_BSDIV: - case OP_BSDIV_I: - case OP_BSDIV0: - case OP_BUDIV: - case OP_BUDIV_I: - case OP_BUDIV0: + case OP_SIGN_EXT: { + auto& a = wval0(e->get_arg(0)); + a.set(val.bits); + bool sign = a.get(a.bits, a.bw - 1); + val.set_range(val.bits, a.bw, val.bw, sign); + break; + } + case OP_ZERO_EXT: { + auto& a = wval0(e->get_arg(0)); + a.set(val.bits); + val.set_range(val.bits, a.bw, val.bw, false); + break; + } case OP_BUREM: case OP_BUREM_I: - case OP_BUREM0: + case OP_BUREM0: { + auto& a = wval0(e->get_arg(0)); + auto& b = wval0(e->get_arg(1)); + + if (b.is_zero()) + val.set(a.bits); + else { + mpn.div(a.bits.data(), a.nw, + b.bits.data(), b.nw, + m_tmp.data(), // quotient + m_tmp2.data()); // remainder + val.set(m_tmp2); + } + break; + } case OP_BSMOD: case OP_BSMOD_I: - case OP_BSMOD0: + case OP_BSMOD0: { + // u = mod(x,y) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + auto& a = wval0(e->get_arg(0)); + auto& b = wval0(e->get_arg(1)); + if (b.is_zero()) + val.set(a.bits); + else { + digit_t c; + mpn.div(a.bits.data(), a.nw, + b.bits.data(), b.nw, + m_tmp.data(), // quotient + m_tmp2.data()); // remainder + if (val.is_zero(m_tmp2)) + val.set(m_tmp2); + else if (a.sign() && b.sign()) + mpn.sub(m_zero.data(), a.nw, m_tmp2.data(), a.nw, m_tmp.data(), &c), + val.set(m_tmp); + else if (a.sign()) + mpn.sub(b.bits.data(), a.nw, m_tmp2.data(), a.nw, m_tmp.data(), &c), + val.set(m_tmp); + else if (b.sign()) + mpn.add(b.bits.data(), a.nw, m_tmp2.data(), a.nw, m_tmp.data(), a.nw, &c), + val.set(m_tmp); + else + val.set(m_tmp2); + } + break; + } + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: { + // x div 0 = -1 + auto& a = wval0(e->get_arg(0)); + auto& b = wval0(e->get_arg(1)); + if (b.is_zero()) { + val.set(m_zero); + for (unsigned i = 0; i < a.nw; ++i) + val.bits[i] = ~val.bits[i]; + } + else { + mpn.div(a.bits.data(), a.nw, + b.bits.data(), b.nw, + m_tmp.data(), // quotient + m_tmp2.data()); // remainder + val.set(m_tmp); + } + break; + } + + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: + // d = udiv(abs(x), abs(y)) + // y = 0, x > 0 -> 1 + // y = 0, x <= 0 -> -1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + + case OP_BREDAND: case OP_BREDOR: case OP_BXNOR: @@ -595,10 +682,18 @@ namespace bv { else return try_repair_sle(!bval0(e), wval0(e, i), wval0(e, 1 - i)); case OP_BASHR: + return try_repair_ashr(wval0(e), wval0(e, 0), wval0(e, 1), i); case OP_BLSHR: + return try_repair_lshr(wval0(e), wval0(e, 0), wval0(e, 1), i); case OP_BSHL: + return try_repair_shl(wval0(e), wval0(e, 0), wval0(e, 1), i); + case OP_BIT2BOOL: { + unsigned idx; + expr* arg; + VERIFY(bv.is_bit2bool(e, arg, idx)); + return try_repair_bit2bool(wval0(e, 0), idx); + } case OP_BCOMP: - case OP_BIT2BOOL: case OP_BNAND: case OP_BREDAND: case OP_BREDOR: @@ -751,15 +846,20 @@ namespace bv { * 8*e = a*(2b), then a = 4e*b^-1 */ bool sls_eval::try_repair_mul(bvval const& e, bvval& a, bvval const& b) { - unsigned parity_e = e.parity(e.bits); - unsigned parity_b = b.parity(b.bits); - if (parity_e < parity_b) + if (b.is_zero()) { + if (a.is_zero()) { + a.set(m_tmp, 1); + return a.try_set(m_tmp); + } return false; + } rational ne, nb; e.get_value(e.bits, ne); b.get_value(b.bits, nb); + unsigned parity_e = e.parity(e.bits); + unsigned parity_b = b.parity(b.bits); if (parity_b > 0) - ne /= rational::power_of_two(parity_b); + ne /= rational::power_of_two(std::min(parity_b, parity_e)); auto inv_b = nb.pseudo_inverse(b.bw); rational na = mod(inv_b * ne, rational::power_of_two(a.bw)); a.set_value(m_tmp, na); @@ -774,7 +874,7 @@ namespace bv { bool sls_eval::try_repair_bneg(bvval const& e, bvval& a) { digit_t c; - mpn.sub(m_zero.data(), e.nw, e.bits.data(), e.nw, m_tmp.data(), &c); + mpn.sub(m_zero.data(), e.nw, e.bits.data(), e.nw, m_tmp.data(), &c); return a.try_set(m_tmp); } @@ -782,10 +882,8 @@ namespace bv { if (e) return a.try_set(b.bits); else { - digit_t c; - a.set(m_zero, 0, true); - mpn.add(b.bits.data(), a.nw, m_zero.data(), a.nw, &c, a.nw, m_tmp.data()); - a.set(m_zero, 0, false); + digit_t c; + mpn.add(b.bits.data(), a.nw, m_one.data(), a.nw, &c, a.nw, m_tmp.data()); return a.try_set(m_tmp); } } @@ -795,18 +893,71 @@ namespace bv { return a.try_set(b.bits); else { digit_t c; - a.set(m_zero, 0, true); - mpn.sub(b.bits.data(), a.nw, m_zero.data(), a.nw, m_tmp.data(), &c); - a.set(m_zero, 0, false); + mpn.sub(b.bits.data(), a.nw, m_one.data(), a.nw, m_tmp.data(), &c); return a.try_set(m_tmp); } } bool sls_eval::try_repair_sle(bool e, bvval& a, bvval const& b) { - return false; + return try_repair_ule(e, a, b); } bool sls_eval::try_repair_sge(bool e, bvval& a, bvval const& b) { + return try_repair_uge(e, a, b); + } + + bool sls_eval::try_repair_bit2bool(bvval& a, unsigned idx) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = a.bits[i]; + a.set(m_tmp, idx, !a.get(a.bits, idx)); + return a.try_set(m_tmp); + } + + bool sls_eval::try_repair_shl(bvval const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + unsigned sh = b.to_nat(b.bits, b.bw); + if (sh == 0) + return a.try_set(e.bits); + else if (sh >= b.bw) { + return false; + } + else { + // + // e = a << sh + // set bw - sh low order bits to bw - sh high-order of e. + // a[bw - sh - 1: 0] = e[bw - 1: sh] + // a[bw - 1: bw - sh] = unchanged + // + for (unsigned i = 0; i < e.bw - sh; ++i) + e.set(m_tmp, i, e.get(e.bits, sh + i)); + for (unsigned i = e.bw - sh; i < e.bw; ++i) + e.set(m_tmp, i, e.get(a.bits, i)); + return a.try_set(m_tmp); + } + } + else { + SASSERT(i == 1); + } + return false; + } + + bool sls_eval::try_repair_ashr(bvval const& e, bvval & a, bvval& b, unsigned i) { + if (i == 0) { + + } + else { + + } + return false; + } + + bool sls_eval::try_repair_lshr(bvval const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + + } + else { + + } return false; } diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/bv_sls_eval.h index a09065735c8..85be653777c 100644 --- a/src/ast/sls/bv_sls_eval.h +++ b/src/ast/sls/bv_sls_eval.h @@ -30,7 +30,7 @@ namespace bv { ast_manager& m; bv_util bv; sls_fixed m_fix; - mpn_manager mpn; + mutable mpn_manager mpn; ptr_vector m_todo; random_gen m_rand; @@ -38,7 +38,7 @@ namespace bv { bool_vector m_eval; // expr-id -> boolean valuation bool_vector m_fixed; // expr-id -> is Boolean fixed - mutable svector m_tmp, m_tmp2, m_zero; + mutable svector m_tmp, m_tmp2, m_zero, m_one; using bvval = sls_valuation; @@ -78,6 +78,10 @@ namespace bv { bool try_repair_uge(bool e, bvval& a, bvval const& b); bool try_repair_sle(bool e, bvval& a, bvval const& b); bool try_repair_sge(bool e, bvval& a, bvval const& b); + bool try_repair_shl(bvval const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_ashr(bvval const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_lshr(bvval const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_bit2bool(bvval& a, unsigned idx); sls_valuation& wval0(app* e, unsigned i) { return wval0(e->get_arg(i)); } diff --git a/src/ast/sls/sls_valuation.h b/src/ast/sls/sls_valuation.h index 1bf61698ce3..4caf1612a8a 100644 --- a/src/ast/sls/sls_valuation.h +++ b/src/ast/sls/sls_valuation.h @@ -54,6 +54,23 @@ namespace bv { return 0 > memcmp(a.data(), b.data(), num_bytes()); } + bool is_zero() const { return is_zero(bits); } + bool is_zero(svector const& a) const { + for (unsigned i = 0; i < nw; ++i) + if (a[i] != 0) + return false; + return true; + } + + bool sign() const { return get(bits, bw - 1); } + + bool has_overflow(svector const& bits) const { + for (unsigned i = bw; i < nw * sizeof(digit_t) * 8; ++i) + if (get(bits, i)) + return true; + return false; + } + unsigned parity(svector const& bits) const { unsigned i = 0; for (; i < bw && !get(bits, i); ++i); @@ -73,16 +90,28 @@ namespace bv { clear_overflow_bits(bits); } + void set_fixed(svector const& src) { for (unsigned i = nw; i-- > 0; ) fixed[i] = src[i]; } + void set_range(svector& dst, unsigned lo, unsigned hi, bool b) { + for (unsigned i = lo; i < hi; ++i) + set(dst, i, b); + } + void set(svector& d, unsigned bit_idx, bool val) const { auto _val = static_cast(0 - static_cast(val)); get_bit_word(d, bit_idx) ^= (_val ^ get_bit_word(d, bit_idx)) & get_pos_mask(bit_idx); } + void set(svector& dst, unsigned v) const { + dst[0] = v; + for (unsigned i = 1; i < nw; ++i) + dst[i] = 0; + } + bool get(svector const& d, unsigned bit_idx) const { return (get_bit_word(d, bit_idx) & get_pos_mask(bit_idx)) != 0; }