From d272becade506ef4a247a26a427222c7f2826839 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 12 Aug 2022 11:54:26 -0400 Subject: [PATCH] fixes for division --- src/math/simplex/model_based_opt.cpp | 70 +++++++++++++--------------- 1 file changed, 32 insertions(+), 38 deletions(-) diff --git a/src/math/simplex/model_based_opt.cpp b/src/math/simplex/model_based_opt.cpp index be2b4fe2b02..577ce46330c 100644 --- a/src/math/simplex/model_based_opt.cpp +++ b/src/math/simplex/model_based_opt.cpp @@ -1101,15 +1101,12 @@ namespace opt { // There are only upper or only lower bounds. if (row_index == UINT_MAX) { if (compute_def) { - if (lub_index != UINT_MAX) { - result = solve_for(lub_index, x, true); - } - else if (glb_index != UINT_MAX) { - result = solve_for(glb_index, x, true); - } - else { - result = def() + m_var2value[x]; - } + if (lub_index != UINT_MAX) + result = solve_for(lub_index, x, true); + else if (glb_index != UINT_MAX) + result = solve_for(glb_index, x, true); + else + result = def() + m_var2value[x]; SASSERT(eval(result) == eval(x)); } else { @@ -1122,12 +1119,10 @@ namespace opt { SASSERT(lub_index != UINT_MAX); SASSERT(glb_index != UINT_MAX); if (compute_def) { - if (lub_size <= glb_size) { - result = def(m_rows[lub_index], x); - } - else { - result = def(m_rows[glb_index], x); - } + if (lub_size <= glb_size) + result = def(m_rows[lub_index], x); + else + result = def(m_rows[glb_index], x); } // The number of matching lower and upper bounds is small. @@ -1148,7 +1143,8 @@ namespace opt { } } } - for (unsigned row_id : lub_rows) retire_row(row_id); + for (unsigned row_id : lub_rows) + retire_row(row_id); return result; } @@ -1281,13 +1277,14 @@ namespace opt { // // Given v = a*x + b div m - // Replace x |-> m*y + a_inv*z + // Replace x |-> m*y + z // - w = b div m - // - v = ((m*y + g*z) + b) div m - // = a*y + (a_inv*z + b) div m - // = a*y + b div m + (b mod m + g*z) div m + // - v = ((a*m*y + a*z) + b) div m + // = a*y + (a*z + b) div m + // = a*y + b div m + (b mod m + a*z) div m // = a*y + b div m + k - // where k := (b.value mod m + g*z.value) div m + // where k := (b.value mod m + a*z.value) div m + // k is between 0 and a // model_based_opt::def model_based_opt::solve_div(unsigned x, unsigned_vector const& div_rows, bool compute_def) { def result; @@ -1302,32 +1299,24 @@ namespace opt { replace_var(row_index, x, rational::zero()); rational b_value = m_rows[row_index].m_value; - // compute a_inv - rational a_inv, m_inv; - rational g = gcd(a, m, a_inv, m_inv); - if (a_inv.is_neg()) - a_inv = mod(a_inv, m); - SASSERT(mod(a_inv * a, m) == g); - - // solve for x_value = m*y_value + a^-1*z_value, 0 <= z_value < m. + // solve for x_value = m*y_value + z_value, 0 <= z_value < m. rational z_value = mod(x_value, m); - rational y_value = div(x_value, m) - div(a_inv*z_value, m); - SASSERT(x_value == m*y_value + a_inv*z_value); + rational y_value = div(x_value, m); + SASSERT(x_value == m*y_value + z_value); SASSERT(0 <= z_value && z_value < m); // add new variables unsigned y = add_var(y_value, true); unsigned z = add_var(z_value, true); - // TODO: we could recycle x by either y or z. - // replace x by m*y + a^-1*z in other rows. + // replace x by m*y + z in other rows. unsigned_vector const& row_ids = m_var2row_ids[x]; uint_set visited; visited.insert(row_index); for (unsigned row_id : row_ids) { if (visited.contains(row_id)) continue; - replace_var(row_id, x, m, y, a_inv, z); + replace_var(row_id, x, m, y, rational::one(), z); visited.insert(row_id); normalize(row_id); } @@ -1339,9 +1328,16 @@ namespace opt { // add w = b div m vector coeffs = m_rows[row_index].m_vars; rational coeff = m_rows[row_index].m_coeff; - unsigned w = add_div(coeffs, coeff, m); - rational k = div(g*z_value + mod(b_value, m), m); + + + // + // w = b div m + // v = a*y + w + k + // k = (a*z_value + (b_value mod m)) div m + // + + rational k = div(a*z_value + mod(b_value, m), m); vector div_coeffs; div_coeffs.push_back(var(v, rational::minus_one())); div_coeffs.push_back(var(y, a)); @@ -1358,9 +1354,7 @@ namespace opt { result = (y_def * m) + z_def; m_var2value[x] = eval(result); } - return result; - } //