From 7dc4ce8259b6352445601b8103ab121adcbb1ee2 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner
Date: Tue, 20 Feb 2024 01:40:27 -0800
Subject: [PATCH] use tuned gcd to compute mult inverse

---
// Reviewer notes. This region (after the "---" separator, before the first
// "diff --git") is free-form text and is not part of the change itself.
//
// Layout
//   The line breaks of this patch were restored from a copy in which all
//   newlines had been collapsed. Hunk line counts were checked against every
//   "@@" header and against the diffstat below. Leading indentation is
//   reconstructed and may not match the target files byte-for-byte. In the
//   set_div hunk the removed and the added "else {" line differ only in
//   whitespace that could not be recovered.
//
// Summary of the change
//   bv_sls.cpp
//     - should_keep() in sls::reinit_eval: threshold moves from ">= 95" to
//       ">= 98" on m_rand() % 100.
//     - sls::updt_params additionally calls m_rand.set_seed(p.random_seed()).
//   bv_sls_eval.cpp / bv_sls_eval.h
//     - New scratch vectors m_a, m_b, m_nexta, m_nextb, m_aux; each receives a
//       push_back(0) next to the existing temporaries.
//     - In the multiplication-repair hunk, the rational-based path
//       (pseudo_inverse) is kept under "#else" and an in-place loop is enabled
//       under "#if 1". The loop runs while y > 0 and uses set_div, set_mul
//       (with check_overflow = false) and set_sub on the scratch vectors.
//       y starts as b's bits, shifted right by parity_b when parity_b > 0;
//       x is zeroed and then bit b.bw is set (in-code comment: x = 2 ^ b.bw).
//     - For the loop, a.nw is incremented and a.bw is set to
//       8 * sizeof(digit_t) * a.nw; afterwards both are reset from b.bw / b.nw.
//     - The general branch of set_div zeroes rem[i] and quot[i] for i < nw
//       before calling mpn.div.
//   sls_valuation.cpp / sls_valuation.h
//     - New const member shift_right(out, shift).
//     - set_mul gains "bool check_overflow" (default true in the header). When
//       it is false the overflow scan is skipped and false is returned;
//       clear_overflow_bits(out) is still called in both cases.
//     - New const helpers set_zero(out) and set_one(out); the existing
//       set_zero() now forwards to set_zero(bits).
//
// NOTE(review)
//     - shift_right() writes every bit i < bw of "out" from this valuation's
//       own "bits" (get(bits, i + shift)); the previous contents of "out" are
//       not read. In "e.get(m_tmp2); ... b.shift_right(m_tmp2, ...)" this means
//       m_tmp2 is overwritten from b's bits whenever the guard holds -
//       confirm this is intended.
//     - The locals "n" and "s" in shift_right() are computed but never read.
//     - "y[a.nw] = 0" writes the word at index a.nw; presumably the scratch
//       vectors are allocated with spare words - confirm against the sizing
//       code, which is only partly visible in this patch.
//     - a.bw / a.nw are restored from b, which presumes a and b have the same
//       width - confirm.
//
 src/ast/sls/bv_sls.cpp        |  3 +-
 src/ast/sls/bv_sls_eval.cpp   | 94 +++++++++++++++++++++++++++++++++--
 src/ast/sls/bv_sls_eval.h     |  1 +
 src/ast/sls/sls_valuation.cpp | 21 ++++++--
 src/ast/sls/sls_valuation.h   | 17 +++++--
 5 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp
index 7476c56e92f..d63bd4f0f72 100644
--- a/src/ast/sls/bv_sls.cpp
+++ b/src/ast/sls/bv_sls.cpp
@@ -50,7 +50,7 @@ namespace bv {
     void sls::reinit_eval() {
         std::function eval = [&](expr* e, unsigned i) {
             auto should_keep = [&]() {
-                return m_rand() % 100 >= 95;
+                return m_rand() % 100 >= 98;
             };
             if (m.is_bool(e)) {
                 if (m_eval.is_fixed0(e) || should_keep())
@@ -225,5 +225,6 @@ namespace bv {
     void sls::updt_params(params_ref const& _p) {
         sls_params p(_p);
         m_config.m_max_restarts = p.max_restarts();
+        m_rand.set_seed(p.random_seed());
     }
 }
diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/bv_sls_eval.cpp
index 7bd8316b5ad..c01acf88a8b 100644
--- a/src/ast/sls/bv_sls_eval.cpp
+++ b/src/ast/sls/bv_sls_eval.cpp
@@ -96,6 +96,11 @@ namespace bv {
                     m_tmp4.push_back(0);
                     m_zero.push_back(0);
                     m_one.push_back(0);
+                    m_a.push_back(0);
+                    m_b.push_back(0);
+                    m_nexta.push_back(0);
+                    m_nextb.push_back(0);
+                    m_aux.push_back(0);
                     m_minus_one.push_back(~0);
                     m_one[0] = 1;
                 }
@@ -1011,17 +1016,98 @@ namespace bv {
             }
             return false;
         }
+
+        unsigned parity_e = e.parity(e.bits);
+        unsigned parity_b = b.parity(b.bits);
+
+#if 1
+
+        auto& x = m_tmp;
+        auto& y = m_tmp2;
+        auto& quot = m_tmp3;
+        auto& rem = m_tmp4;
+        auto& ta = m_a;
+        auto& tb = m_b;
+        auto& nexta = m_nexta;
+        auto& nextb = m_nextb;
+        auto& aux = m_aux;
+
+
+        // x*ta + y*tb = x
+        b.get(y);
+        if (parity_b > 0)
+            b.shift_right(y, parity_b);
+        y[a.nw] = 0;
+        a.nw = a.nw + 1;
+        a.bw = 8 * sizeof(digit_t) * a.nw;
+        // x = 2 ^ b.bw
+        a.set_zero(x);
+        a.set(x, b.bw, true);
+
+        a.set_one(ta);
+        a.set_zero(tb);
+        a.set_zero(nexta);
+        a.set_one(nextb);
+
+        rem.reserve(2 * a.nw);
+        SASSERT(a.le(y, x));
+        while (a.gt(y, m_zero)) {
+            SASSERT(a.le(y, x));
+            set_div(x, y, a.bw, quot, rem); // quot, rem := quot_rem(x, y)
+            SASSERT(a.le(rem, y));
+            a.set(x, y); // x := y
+            a.set(y, rem); // y := rem
+            a.set(aux, nexta); // aux := nexta
+            a.set_mul(rem, quot, nexta, false);
+            a.set_sub(nexta, ta, rem); // nexta := ta - quot*nexta
+            a.set(ta, aux); // ta := aux
+            a.set(aux, nextb); // aux := nextb
+            a.set_mul(rem, quot, nextb, false);
+            a.set_sub(nextb, tb, rem); // nextb := tb - quot*nextb
+            a.set(tb, aux); // tb := aux
+        }
+
+        a.bw = b.bw;
+        a.nw = b.nw;
+        // x*a + y*b = 1
+
+#if Z3DEBUG
+        b.get(y);
+        if (parity_b > 0)
+            b.shift_right(y, parity_b);
+        a.set_mul(m_tmp, tb, y);
+#if 0
+        for (unsigned i = a.nw; i-- > 0; )
+            verbose_stream() << tb[i];
+        verbose_stream() << "\n";
+        for (unsigned i = a.nw; i-- > 0; )
+            verbose_stream() << y[i];
+        verbose_stream() << "\n";
+        for (unsigned i = a.nw; i-- > 0; )
+            verbose_stream() << m_tmp[i];
+        verbose_stream() << "\n";
+#endif
+        SASSERT(b.is_one(m_tmp));
+#endif
+        e.get(m_tmp2);
+        if (parity_e > 0 && parity_b > 0)
+            b.shift_right(m_tmp2, std::min(parity_b, parity_e));
+        a.set_mul(m_tmp, tb, m_tmp2);
+        a.set_repair(random_bool(), m_tmp);
+
+#else
+
         rational ne, nb;
         e.get_value(e.bits, ne);
         b.get_value(b.bits, nb);
-        unsigned parity_e = e.parity(e.bits);
-        unsigned parity_b = b.parity(b.bits);
+
         if (parity_b > 0)
             ne /= rational::power_of_two(std::min(parity_b, parity_e));
         auto inv_b = nb.pseudo_inverse(b.bw);
         rational na = mod(inv_b * ne, rational::power_of_two(a.bw));
         a.set_value(m_tmp, na);
         a.set_repair(random_bool(), m_tmp);
+#endif
         return true;
     }
 
@@ -1454,7 +1540,9 @@ namespace bv {
             }
             quot[nw - 1] = (1 << (bw % (8 * sizeof(digit_t)))) - 1;
         }
-        else {
+        else {
+            for (unsigned i = 0; i < nw; ++i)
+                rem[i] = quot[i] = 0;
             mpn.div(a.data(), nw, b.data(), bnw, quot.data(), rem.data());
         }
     }
diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/bv_sls_eval.h
index d2c20e6add7..bdd6db0ed2f 100644
--- a/src/ast/sls/bv_sls_eval.h
+++ b/src/ast/sls/bv_sls_eval.h
@@ -40,6 +40,7 @@ namespace bv {
         bool_vector m_fixed; // expr-id -> is Boolean fixed
 
         mutable svector m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one;
+        svector m_a, m_b, m_nextb, m_nexta, m_aux;
 
         using bvval = sls_valuation;
 
diff --git a/src/ast/sls/sls_valuation.cpp b/src/ast/sls/sls_valuation.cpp
index 1954b177774..db977cc2e00 100644
--- a/src/ast/sls/sls_valuation.cpp
+++ b/src/ast/sls/sls_valuation.cpp
@@ -291,6 +291,15 @@ namespace bv {
         return value;
     }
 
+    void sls_valuation::shift_right(svector& out, unsigned shift) const {
+        SASSERT(shift < bw);
+        unsigned n = shift / (8 * sizeof(digit_t));
+        unsigned s = shift % (8 * sizeof(digit_t));
+        for (unsigned i = 0; i < bw; ++i)
+            set(out, i, i + shift < bw ? get(bits, i + shift) : false);
+        SASSERT(!has_overflow(out));
+    }
+
     void sls_valuation::add_range(rational l, rational h) {
         l = mod(l, rational::power_of_two(bw));
         h = mod(h, rational::power_of_two(bw));
@@ -427,11 +436,15 @@ namespace bv {
         return ovfl;
     }
 
-    bool sls_valuation::set_mul(svector& out, svector const& a, svector const& b) const {
+    bool sls_valuation::set_mul(svector& out, svector const& a, svector const& b, bool check_overflow) const {
         mpn_manager().mul(a.data(), nw, b.data(), nw, out.data());
-        bool ovfl = has_overflow(out);
-        for (unsigned i = nw; i < 2 * nw; ++i)
-            ovfl |= out[i] != 0;
+
+        bool ovfl = false;
+        if (check_overflow) {
+            ovfl = has_overflow(out);
+            for (unsigned i = nw; i < 2 * nw; ++i)
+                ovfl |= out[i] != 0;
+        }
         clear_overflow_bits(out);
         return ovfl;
     }
diff --git a/src/ast/sls/sls_valuation.h b/src/ast/sls/sls_valuation.h
index bb8100f7546..a1ec5257fb7 100644
--- a/src/ast/sls/sls_valuation.h
+++ b/src/ast/sls/sls_valuation.h
@@ -131,9 +131,19 @@ namespace bv {
             clear_overflow_bits(bits);
         }
 
-        void set_zero() {
+        void set_zero(svector& out) const {
             for (unsigned i = 0; i < nw; ++i)
-                bits[i] = 0;
+                out[i] = 0;
+        }
+
+        void set_one(svector& out) const {
+            for (unsigned i = 1; i < nw; ++i)
+                out[i] = 0;
+            out[0] = 1;
+        }
+
+        void set_zero() {
+            set_zero(bits);
         }
 
         void sub1(svector& out) const {
@@ -149,7 +159,8 @@ namespace bv {
 
         void set_sub(svector& out, svector const& a, svector const& b) const;
         bool set_add(svector& out, svector const& a, svector const& b) const;
-        bool set_mul(svector& out, svector const& a, svector const& b) const;
+        bool set_mul(svector& out, svector const& a, svector const& b, bool check_overflow = true) const;
+        void shift_right(svector& out, unsigned shift) const;
 
         void set_range(svector& dst, unsigned lo, unsigned hi, bool b) {
             for (unsigned i = lo; i < hi; ++i)