diff --git a/bench-multiplication/assoc.c b/bench-multiplication/assoc.c new file mode 100644 index 00000000000..37c8c80a309 --- /dev/null +++ b/bench-multiplication/assoc.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] a, b, c; + __CPROVER_bitvector[BW] ab = a * b, bc = b * c; + __CPROVER_assert(ab * c == a * bc, "associativity"); +} diff --git a/bench-multiplication/comm.c b/bench-multiplication/comm.c new file mode 100644 index 00000000000..47911b5e7ff --- /dev/null +++ b/bench-multiplication/comm.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] a, b; + __CPROVER_bitvector[BW] c = a * b, d = b * a; + __CPROVER_assert(c == d, "commutativity"); +} diff --git a/bench-multiplication/const3.c b/bench-multiplication/const3.c new file mode 100644 index 00000000000..ae21fd00374 --- /dev/null +++ b/bench-multiplication/const3.c @@ -0,0 +1,7 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] a; + __CPROVER_assert(a * 3 == a + a + a, "multiply by 3"); +} diff --git a/bench-multiplication/distrib.c b/bench-multiplication/distrib.c new file mode 100644 index 00000000000..57e79cee6d3 --- /dev/null +++ b/bench-multiplication/distrib.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] a, b, c; + __CPROVER_bitvector[BW] lhs = a * (b + c), rhs = a * b + a * c; + __CPROVER_assert(lhs == rhs, "distributivity"); +} diff --git a/bench-multiplication/factor.c b/bench-multiplication/factor.c new file mode 100644 index 00000000000..48391e34b39 --- /dev/null +++ b/bench-multiplication/factor.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_assert(p * q != 143, "not factorable"); +} diff --git a/bench-multiplication/mul_bounds.c b/bench-multiplication/mul_bounds.c new file mode 100644 index 00000000000..6ecfaa7d0b3 --- /dev/null +++ 
b/bench-multiplication/mul_bounds.c @@ -0,0 +1,9 @@ +#ifndef BW +#define BW 16 +#endif +int main() { + __CPROVER_bitvector[BW/2] a, b; + __CPROVER_bitvector[BW] wide = (__CPROVER_bitvector[BW])a * (__CPROVER_bitvector[BW])b; + __CPROVER_bitvector[BW*2] wider = (__CPROVER_bitvector[BW*2])a * (__CPROVER_bitvector[BW*2])b; + __CPROVER_assert(wide == (__CPROVER_bitvector[BW])wider, "no overflow"); +} diff --git a/bench-multiplication/mul_double.c b/bench-multiplication/mul_double.c new file mode 100644 index 00000000000..a379ab2a7b3 --- /dev/null +++ b/bench-multiplication/mul_double.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 16 +#endif +int main() { + __CPROVER_bitvector[BW] a, b; + __CPROVER_bitvector[BW/2] a_lo = a, b_lo = b, prod_lo = a * b; + __CPROVER_assert(prod_lo == (__CPROVER_bitvector[BW/2])(a_lo * b_lo), "low half"); +} diff --git a/bench-multiplication/mul_negation.c b/bench-multiplication/mul_negation.c new file mode 100644 index 00000000000..8def321741e --- /dev/null +++ b/bench-multiplication/mul_negation.c @@ -0,0 +1,7 @@ +#ifndef BW +#define BW 16 +#endif +int main() { + __CPROVER_bitvector[BW] a; + __CPROVER_assert(a * (__CPROVER_bitvector[BW])(-1) == -a, "mul by -1"); +} diff --git a/bench-multiplication/mul_overflow.c b/bench-multiplication/mul_overflow.c new file mode 100644 index 00000000000..da7dc0bb268 --- /dev/null +++ b/bench-multiplication/mul_overflow.c @@ -0,0 +1,9 @@ +#ifndef BW +#define BW 16 +#endif +int main() { + __CPROVER_bitvector[BW] a, b; + __CPROVER_bitvector[BW*2] wide = (__CPROVER_bitvector[BW*2])a * (__CPROVER_bitvector[BW*2])b; + __CPROVER_bitvector[BW] narrow = a * b; + __CPROVER_assert(narrow == (__CPROVER_bitvector[BW])wide, "truncated matches wide"); +} diff --git a/bench-multiplication/mul_shift.c b/bench-multiplication/mul_shift.c new file mode 100644 index 00000000000..9d03df531be --- /dev/null +++ b/bench-multiplication/mul_shift.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 16 +#endif +int main() { + 
__CPROVER_bitvector[BW] a; + __CPROVER_assert(a * 4 == a << 2, "mul4 == shl2"); + __CPROVER_assert(a * 8 == a << 3, "mul8 == shl3"); +} diff --git a/bench-multiplication/mul_square_nonneg.c b/bench-multiplication/mul_square_nonneg.c new file mode 100644 index 00000000000..21bc09dd543 --- /dev/null +++ b/bench-multiplication/mul_square_nonneg.c @@ -0,0 +1,8 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] a; + __CPROVER_bitvector[BW*2] wide = (__CPROVER_bitvector[BW*2])a * (__CPROVER_bitvector[BW*2])a; + __CPROVER_assert(wide >= a, "square >= original"); +} diff --git a/bench-multiplication/mul_zero_factor.c b/bench-multiplication/mul_zero_factor.c new file mode 100644 index 00000000000..ec20a720f01 --- /dev/null +++ b/bench-multiplication/mul_zero_factor.c @@ -0,0 +1,11 @@ +#ifndef BW +#define BW 8 +#endif +int main() { + __CPROVER_bitvector[BW] a, b, r; + r = a * b; + __CPROVER_assume(r == 0); + __CPROVER_assume(a != 0); + __CPROVER_assume(b != 0); + __CPROVER_assert(0, "found zero divisors"); +} diff --git a/bench-multiplication/smt-comp/assoc_8.smt2 b/bench-multiplication/smt-comp/assoc_8.smt2 new file mode 100644 index 00000000000..fda627eda7e --- /dev/null +++ b/bench-multiplication/smt-comp/assoc_8.smt2 @@ -0,0 +1,7 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 8)) +(declare-fun b () (_ BitVec 8)) +(declare-fun c () (_ BitVec 8)) +(assert (not (= (bvmul (bvmul a b) c) (bvmul a (bvmul b c))))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/comm_16.smt2 b/bench-multiplication/smt-comp/comm_16.smt2 new file mode 100644 index 00000000000..fa6b262d581 --- /dev/null +++ b/bench-multiplication/smt-comp/comm_16.smt2 @@ -0,0 +1,6 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 16)) +(declare-fun b () (_ BitVec 16)) +(assert (not (= (bvmul a b) (bvmul b a)))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/comm_32.smt2 b/bench-multiplication/smt-comp/comm_32.smt2 new file mode 100644 index 
00000000000..c3374d7749e --- /dev/null +++ b/bench-multiplication/smt-comp/comm_32.smt2 @@ -0,0 +1,6 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 32)) +(declare-fun b () (_ BitVec 32)) +(assert (not (= (bvmul a b) (bvmul b a)))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/comm_8.smt2 b/bench-multiplication/smt-comp/comm_8.smt2 new file mode 100644 index 00000000000..2ef800c494a --- /dev/null +++ b/bench-multiplication/smt-comp/comm_8.smt2 @@ -0,0 +1,6 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 8)) +(declare-fun b () (_ BitVec 8)) +(assert (not (= (bvmul a b) (bvmul b a)))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/distrib_16.smt2 b/bench-multiplication/smt-comp/distrib_16.smt2 new file mode 100644 index 00000000000..7102db62f31 --- /dev/null +++ b/bench-multiplication/smt-comp/distrib_16.smt2 @@ -0,0 +1,7 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 16)) +(declare-fun b () (_ BitVec 16)) +(declare-fun c () (_ BitVec 16)) +(assert (not (= (bvmul a (bvadd b c)) (bvadd (bvmul a b) (bvmul a c))))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/distrib_8.smt2 b/bench-multiplication/smt-comp/distrib_8.smt2 new file mode 100644 index 00000000000..fe74efbfb21 --- /dev/null +++ b/bench-multiplication/smt-comp/distrib_8.smt2 @@ -0,0 +1,7 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 8)) +(declare-fun b () (_ BitVec 8)) +(declare-fun c () (_ BitVec 8)) +(assert (not (= (bvmul a (bvadd b c)) (bvadd (bvmul a b) (bvmul a c))))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/factor_12.smt2 b/bench-multiplication/smt-comp/factor_12.smt2 new file mode 100644 index 00000000000..54ca65407a3 --- /dev/null +++ b/bench-multiplication/smt-comp/factor_12.smt2 @@ -0,0 +1,8 @@ +(set-logic QF_BV) +(declare-fun p () (_ BitVec 12)) +(declare-fun q () (_ BitVec 12)) +(assert (bvugt p (_ bv1 12))) +(assert (bvugt q (_ bv1 12))) +(assert (= (bvmul p q) (_ bv143 12))) +(check-sat) +(exit) diff 
--git a/bench-multiplication/smt-comp/factor_16.smt2 b/bench-multiplication/smt-comp/factor_16.smt2 new file mode 100644 index 00000000000..3512aa963d8 --- /dev/null +++ b/bench-multiplication/smt-comp/factor_16.smt2 @@ -0,0 +1,8 @@ +(set-logic QF_BV) +(declare-fun p () (_ BitVec 16)) +(declare-fun q () (_ BitVec 16)) +(assert (bvugt p (_ bv1 16))) +(assert (bvugt q (_ bv1 16))) +(assert (= (bvmul p q) (_ bv10403 16))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/factor_20.smt2 b/bench-multiplication/smt-comp/factor_20.smt2 new file mode 100644 index 00000000000..0ecd88784e2 --- /dev/null +++ b/bench-multiplication/smt-comp/factor_20.smt2 @@ -0,0 +1,8 @@ +(set-logic QF_BV) +(declare-fun p () (_ BitVec 20)) +(declare-fun q () (_ BitVec 20)) +(assert (bvugt p (_ bv1 20))) +(assert (bvugt q (_ bv1 20))) +(assert (= (bvmul p q) (_ bv101101 20))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/mixed_arith_8.smt2 b/bench-multiplication/smt-comp/mixed_arith_8.smt2 new file mode 100644 index 00000000000..f890fe55eb2 --- /dev/null +++ b/bench-multiplication/smt-comp/mixed_arith_8.smt2 @@ -0,0 +1,9 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 8)) +(declare-fun b () (_ BitVec 8)) +(declare-fun c () (_ BitVec 8)) +(declare-fun d () (_ BitVec 8)) +(assert (not (= (bvadd (bvmul a b) (bvmul c d)) + (bvsub (bvsub (bvmul (bvadd a c) (bvadd b d)) (bvmul a d)) (bvmul c b))))) +(check-sat) +(exit) diff --git a/bench-multiplication/smt-comp/mul_no_overflow_16.smt2 b/bench-multiplication/smt-comp/mul_no_overflow_16.smt2 new file mode 100644 index 00000000000..1977684bb62 --- /dev/null +++ b/bench-multiplication/smt-comp/mul_no_overflow_16.smt2 @@ -0,0 +1,11 @@ +(set-logic QF_BV) +(declare-fun a () (_ BitVec 16)) +(declare-fun b () (_ BitVec 16)) +(declare-fun r () (_ BitVec 16)) +(assert (= r (bvmul a b))) +; Check: if a != 0 then r/a == b (no overflow case) +(assert (not (= a (_ bv0 16)))) +(assert (not (= (bvudiv r a) b))) +; This is SAT when 
overflow occurs +(check-sat) +(exit) diff --git a/bench-multiplication/square.c b/bench-multiplication/square.c new file mode 100644 index 00000000000..5bd0c0419ef --- /dev/null +++ b/bench-multiplication/square.c @@ -0,0 +1,9 @@ +#ifndef BW +#define BW 9 +#endif +int main() { + __CPROVER_bitvector[BW] a, b, s = a + b; + __CPROVER_bitvector[BW] lhs = s * s; + __CPROVER_bitvector[BW] rhs = a*a + (__CPROVER_bitvector[BW])2*a*b + b*b; + __CPROVER_assert(lhs == rhs, "square identity"); +} diff --git a/doc/architectural/multiplication-encoding-research.md b/doc/architectural/multiplication-encoding-research.md new file mode 100644 index 00000000000..3e58afcb38f --- /dev/null +++ b/doc/architectural/multiplication-encoding-research.md @@ -0,0 +1,591 @@ +# Multiplication Encoding Research Notes + +## Objective + +Find a propositional encoding of multiplication that SAT solvers can reason +about as efficiently as possible. This is an open research problem — even +state-of-the-art SMT solvers (Bitwuzla, Z3, CVC5) struggle with +multiplication at moderate bitwidths. + +## Benchmark Suite + +### Commutativity: `a * b == b * a` +Tests whether the solver can prove the algebraic identity. Uses +`__CPROVER_bitvector[N]` for exact bitwidth control. + +```c +// multiply-comm.c +int main() { + __CPROVER_bitvector[BITWIDTH] a, b; + __CPROVER_bitvector[BITWIDTH] c = a * b; + __CPROVER_bitvector[BITWIDTH] d = b * a; + __CPROVER_assert(c == d, "commutativity"); +} +``` + +### Associativity: `(a * b) * c == a * (b * c)` +Much harder — involves three multiplications. 
+ +```c +// multiply-assoc.c +int main() { + __CPROVER_bitvector[BITWIDTH] a, b, c; + __CPROVER_bitvector[BITWIDTH] ab = a * b; + __CPROVER_bitvector[BITWIDTH] bc = b * c; + __CPROVER_assert(ab * c == a * bc, "associativity"); +} +``` + +### Distributivity: `a * (b + c) == a * b + a * c` + +```c +// multiply-distrib.c +int main() { + __CPROVER_bitvector[BITWIDTH] a, b, c; + __CPROVER_assert(a * (b + c) == a * b + a * c, "distributivity"); +} +``` + +### Factoring: given `n`, find `p * q == n` +Tests the reverse direction — SAT solver must find factors. + +```c +// multiply-factor.c +int main() { + __CPROVER_bitvector[BITWIDTH] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_assert(p * q != COMPOSITE, "not factorable"); +} +``` + +### Multiply-by-constant: `a * K == a + a + ... + a` + +```c +// multiply-const.c +int main() { + __CPROVER_bitvector[BITWIDTH] a; + __CPROVER_assert(a * 3 == a + a + a, "multiply by 3"); +} +``` + +### Square identity: `(a + b)^2 == a^2 + 2*a*b + b^2` + +```c +// multiply-square.c +int main() { + __CPROVER_bitvector[BITWIDTH] a, b; + __CPROVER_bitvector[BITWIDTH] sum = a + b; + __CPROVER_assert(sum * sum == a*a + 2*a*b + b*b, "square identity"); +} +``` + +## Encoding Schemes Tested + +### 1. Baseline (shift-add with ripple-carry) +The default CBMC encoding. Each bit of one operand selects (AND) the other +operand, shifted by the bit position. Partial products are summed using +ripple-carry addition. + +- Formula size: O(n²) variables and clauses +- Carry chain depth: O(n²) — long dependency chains +- Uses propagation-complete full adder (14 clauses) + +### 2. Dadda Tree +Reduces partial products using Dadda's sequence of full/half adders to +minimize the number of adders used. Produces two rows (carry-save form), +then does a final ripple-carry addition. 
+ +- Same formula size as baseline (same number of full adders, different wiring) +- Different clause ordering affects solver heuristics +- Better than baseline with CaDiCaL on large bitwidths (see code comments) + +### 3. Wallace Tree +Similar to Dadda but reduces columns as aggressively as possible at each +stage. Slightly larger than Dadda. + +- Generally worse than Dadda in benchmarks + +### 4. Comba (popcount-based column reduction) +**Current best performer.** Reduces each column independently using the +parallel bit-counting algorithm (Hacker's Delight popcount). Each column's +sum is a binary number; high bits carry to the next column. + +- Formula size: ~2x larger than Dadda/baseline +- BUT: ~30x faster with CaDiCaL at BW=11 +- The extra auxiliary variables from popcount create balanced trees that + enable better unit propagation +- Key insight: formula size is anti-correlated with performance here + +### 5. Radix-8 (higher radix partial products) +Pre-computes x*0 through x*7 to reduce the number of partial products by +3x. Can be combined with any reduction scheme. + +- Helps at larger bitwidths when combined with Dadda +- Overhead of pre-computation hurts at small bitwidths + +### 6. Karatsuba +Divide-and-conquer: splits n-bit multiplication into three (n/2)-bit +multiplications. Asymptotically O(n^1.585) vs O(n²). + +- Implemented in the PR but not benchmarked yet + +### 7. Toom-Cook +Generalization of Karatsuba using polynomial interpolation. Splits into +more pieces for better asymptotic complexity. + +- Implemented in the PR, based on the Brain paper's approach +- Allows incremental tightening of over-approximation + +### 8. Schönhage-Strassen +Based on the number-theoretic transform. O(n log n log log n) complexity. 
+ +- Implemented in the PR but has a bug: `bvt b = a;` in the last commit + makes it compute a*a instead of a*b +- Needs fixing and benchmarking + +## Benchmark Results + +All results use `__CPROVER_bitvector[N]` with `--no-standard-checks`. +Timeout: 300s. Commutativity uses intermediate variables. + +### Commutativity: `c=a*b; d=b*a; assert(c==d)` + +| Variant | BW=9 | BW=11 | BW=13 | BW=15 | BW=17 | BW=19 | +|---------|------|-------|-------|-------|-------|-------| +| baseline+CaDiCaL | 2.01s | 60.8s | >300s | >300s | >300s | >300s | +| **comba+CaDiCaL** | **0.26s** | **1.75s** | **8.08s** | **10.4s** | **46.4s** | **177s** | +| dadda+CaDiCaL | 0.54s | 14.1s | 180s | >300s | >300s | >300s | +| baseline+MiniSat | 5.35s | 251s | >300s | >300s | >300s | >300s | +| comba+MiniSat | 6.12s | 272s | >300s | >300s | >300s | >300s | +| dadda+MiniSat | 12.1s | >300s | >300s | >300s | >300s | >300s | +| **Bitwuzla** | **0.03s** | **0.03s** | **0.02s** | **0.02s** | **0.03s** | **0.03s** | +| **Z3** | 0.04s | 0.04s | 0.04s | 0.04s | 0.04s | 0.04s | + +### Primality: prove n has no factors (wide multiplication) + +| Variant | BW=20 | BW=24 | BW=28 | BW=30 | BW=32 | BW=34 | +|---------|-------|-------|-------|-------|-------|-------| +| **baseline+CaDiCaL** | **0.11s** | **0.42s** | **1.46s** | **3.33s** | **6.38s** | **13.8s** | +| comba+CaDiCaL | 0.16s | 0.72s | 3.09s | 5.49s | 9.76s | 17.3s | +| dadda+CaDiCaL | 0.11s | 0.63s | 2.12s | 5.02s | 9.71s | 17.0s | +| baseline+MiniSat | 0.07s | 0.26s | 1.63s | 3.66s | 7.70s | 33.2s | +| comba+MiniSat | 0.10s | 0.40s | 2.01s | 5.26s | 40.2s | 83.2s | +| dadda+MiniSat | 0.07s | 0.27s | 1.35s | 4.81s | 9.25s | 28.5s | +| **Bitwuzla** | 0.22s | 0.68s | **1.51s** | **2.94s** | 7.36s | **13.9s** | +| Z3 | 0.21s | 0.59s | 3.56s | 6.68s | 15.9s | 39.6s | + +### AWS real-world proofs + +| Encoding | aws_mul_checked | aws_list_back | aws_mul_sat | +|----------|-----------------|---------------|-------------| +| baseline+CaDiCaL | 2.53s | 
6.10s | 1.24s | +| comba+CaDiCaL | 39.4s | 8.45s | 8.84s | +| **dadda+CaDiCaL** | **1.29s** | **5.75s** | **1.20s** | +| **wallace+CaDiCaL** | **1.19s** | 6.45s | 1.27s | +| radix8+CaDiCaL | 6.97s | 4.62s | 5.06s | +| karatsuba+CaDiCaL | 3.89s | >120s | 3.25s | +| toom-cook+CaDiCaL | >120s | >120s | >120s | +| schönhage+CaDiCaL | >120s | >120s | >120s | +| **Bitwuzla** | **0.83s** | **1.63s** | **1.52s** | +| Z3 | 1.05s | 1.65s | 7.93s | +| CVC5 | 1.10s | 4.93s | 2.79s | + +### All encodings on all benchmark types (CaDiCaL) + +| Encoding | comm 9 | comm 13 | prime 20 | prime 28 | aws_mul | aws_back | aws_sat | +|----------|--------|---------|----------|----------|---------|----------|---------| +| baseline | 2.01s | >300s | 0.11s | 1.46s | 2.53s | 6.10s | 1.24s | +| **comba** | **0.26s** | **8.08s** | 0.16s | 3.09s | 39.4s | 8.45s | 8.84s | +| dadda | 0.54s | 180s | 0.11s | 2.12s | **1.29s** | 5.75s | 1.20s | +| wallace | 1.89s | >300s | 0.14s | 1.87s | **1.19s** | 6.45s | 1.27s | +| radix8 | 1.37s | >300s | 0.14s | 1.87s | 6.97s | **4.62s** | 5.06s | +| radix8+dadda | 1.36s | >300s | 0.14s | 1.87s | 7.03s | **4.62s** | 5.10s | +| karatsuba | 5.54s | 10.0s | 0.88s | 69.7s | 3.89s | >120s | 3.25s | +| toom-cook | 2.48s | 31.4s | 9.24s | 71.7s | >120s | >120s | >120s | +| schönhage | 86.4s | >300s | >120s | >120s | >120s | >120s | >120s | + +### Refinement modes (comba+CaDiCaL, --refine-arithmetic) + +| Mode | comm 9 | comm 13 | prime 20 | prime 28 | aws_mul | aws_back | +|------|--------|---------|----------|----------|---------|----------| +| 0 (original) | 0.64s | 17.9s | 0.28s | 7.30s | 13.1s | 21.2s | +| 1 (assumption) | 0.67s | 15.7s | 0.32s | 7.36s | 13.1s | 14.5s | +| 2 (karatsuba) | 0.84s | 13.2s | 1.16s | 53.1s | 14.3s | 25.8s | +| 3 (toom-cook) | 0.74s | 25.9s | 0.50s | 6.93s | 15.2s | 12.3s | + +### Summary + +1. 
**No single encoding wins everywhere.** Comba dominates commutativity + (8x faster than baseline at BW=13), Dadda/Wallace dominate AWS proofs + (30x faster than Comba on aws_mul_checked), Baseline is best for + primality. + +2. **CaDiCaL consistently outperforms MiniSat** on commutativity (10-50x). + On primality, MiniSat is competitive at small BW but falls behind at + BW≥32. + +3. **Bitwuzla dominates** on commutativity (instant at any BW) and AWS + proofs (0.83s vs best CBMC 1.19s). On primality it's competitive with + baseline CaDiCaL. + +4. **Refinement modes don't help** on primality or AWS proofs. Mode 2 + (Karatsuba) helps slightly on commutativity but hurts on primality. + +5. **Schönhage-Strassen and Toom-Cook are too expensive** at BW≤40. + +6. **Growth rates**: commutativity (comba+CaDiCaL) ~4x per 2 bits; + primality (baseline+CaDiCaL) ~2x per 2 bits. + +## Key Insights + +### 1. Formula size is NOT the right optimization target +Comba produces 2x more clauses than Dadda but is 30x faster. The extra +auxiliary variables from popcount create a structure that enables better +unit propagation. This aligns with the Brain et al. "Automatic Generation +of Propagation Complete SAT Encodings" paper. + +### 2. Auxiliary variables that "summarize" groups of bits help propagation +The popcount algorithm creates intermediate variables that represent the +count of 1-bits in groups of 2, 4, 8, etc. These act as abstractions that +the SAT solver can reason about at a higher level. + +### 3. Balanced tree depth matters +Replacing popcount with a simple full-adder tree (same clause count as +Dadda) gives 25x worse performance than popcount. The balanced tree +structure of popcount (O(log n) depth) vs the sequential chain of +full-adders (O(n) depth) is critical. + +### 4. CaDiCaL consistently outperforms MiniSat on multiplication +Likely due to better preprocessing (bounded variable elimination, +subsumption) and inprocessing techniques. + +### 5. 
CaDiCaL has XOR gate extraction +CaDiCaL's congruence closure module extracts XOR gates from CNF +(`congruencexor` option, enabled by default, arity limit 4). The full-adder +sum is a 3-input XOR, which is within the arity limit. This likely explains +why CaDiCaL outperforms MiniSat on multiplication — it recognizes the XOR +structure that MiniSat treats as opaque clauses. + +**Verified experimentally**: For BW=9 commutativity, CaDiCaL extracts: +- 131 XOR gates +- 138 AND gates +- 130 congruent pairs (corresponding gates in the two multiplier circuits) + +For BW=5 distributivity (which is much harder): +- 205 XOR gates extracted +- 0 congruent pairs (the three multiplier circuits have different structures) +- 44K conflicts (vs 11K for commutativity BW=9) + +The congruence closure is key for commutativity — it discovers that +corresponding gates in the two multiplier circuits are equivalent. For +distributivity, this doesn't help because the circuits have different +operands. + +Increasing the XOR arity limit from 4 to 16 had no effect — the default +is already sufficient for the 3-input XOR in full adders. + +### 6. SMT solvers use word-level reasoning, not bit-blasting +Z3 sees `(bvmul a b)` and `(bvmul b a)` at the word level and can +immediately recognize commutativity as a trivial rewrite. It never +bit-blasts algebraic identities. This means **the right approach for +algebraic properties is word-level reasoning, not better bit-blasting**. +However, for problems that genuinely require bit-level reasoning (factoring, +hardware verification), bit-blasting is necessary. + +### 7. Redundant clauses don't help much +Tested adding: +- LSB constraint: `result[0] == op0[0] AND op1[0]` +- Zero-operand constraint: if operand is 0, result is 0 +- MSB constraint: if top halves of operands are 0, top half of result is 0 + +Results were mixed — slight improvements at some bitwidths, regressions at +others. The overhead of the extra clauses offsets any propagation benefit. 
+ +### 8. Schönhage-Strassen is too expensive for small bitwidths +The SS encoding produces 43K variables and 227K clauses for BW=7 (vs ~600 +for Comba). It's designed for very large numbers (thousands of bits) where +O(n log n log log n) beats O(n²). For our target range (7-32 bits), it's +orders of magnitude worse. + +### 9. Intermediate variables matter for benchmark design +`assert(a*b == b*a)` is 13x slower than `c=a*b; d=b*a; assert(c==d)`. +The inline version creates a combined circuit; the intermediate variable +version creates separate circuits that the solver can reason about +independently. + +### 10. `--refine-arithmetic` exists but needs stronger initial approximation +CBMC's `--refine-arithmetic` implements abstraction-refinement for +multiplication. It starts with `x*0=0` and `x*1=x` as initial constraints +and refines incrementally. For BW=9 commutativity, it takes 11 iterations +and 0.65s (vs 0.26s for direct Comba). For BW=11, it's 6.08s vs 1.78s. +The initial approximation is too weak — adding algebraic properties +(commutativity, distributivity) as initial constraints could make many +benchmarks trivial. + +### 11. Word-level commutativity simplification is highly effective +Adding a check in `simplify_inequality` that recognizes `a op b == b op a` +for commutative operators (mult, plus, bitand, bitor, bitxor) and simplifies +to `true` makes commutativity verification instant at ANY bitwidth (tested +up to 256-bit). The simplification fires during expression simplification, +before any bit-blasting occurs. This is exactly the word-level reasoning +that Z3 does. + +**Limitation**: Only works for inline expressions (`assert(a*b == b*a)`), +not through intermediate variables (`c=a*b; d=b*a; assert(c==d)`). The +latter would require value propagation or common subexpression elimination +to expose the pattern to the simplifier. + +### 12. 
Word-level distributivity simplification also works +Added recognition of `a * (b + c) == a * b + a * c` (and commuted +variants) in `simplify_inequality`. Expands multiplication over addition +on each side and compares with deep commutativity checking. Makes +distributivity verification instant at any bitwidth (tested up to 128-bit). + +### 13. CryptoMiniSat's native XOR doesn't help +CryptoMiniSat5 (with native XOR clause support) is 2-25x slower than +CaDiCaL on all multiplication benchmarks. CaDiCaL's congruence closure +(which discovers equivalent gates between multiplier circuits) is more +effective than CryptoMiniSat's XOR handling. + +| BW | CaDiCaL | CryptoMiniSat5 | +|----|---------|----------------| +| 7 | 0.10s | 0.36s | +| 9 | 0.26s | 6.69s | +| 11 | 1.78s | >60s | + +### 14. Optimal encoding is solver-dependent +For MiniSat2, the **Baseline (shift-add)** encoding is actually faster +than Comba (5.19s vs 6.00s at BW=9, 249s vs 273s at BW=11). MiniSat +lacks CaDiCaL's congruence closure and XOR extraction, so Comba's extra +auxiliary variables are just noise. This means the encoding choice should +ideally depend on the solver being used. + +| BW | MiniSat+Baseline | MiniSat+Comba | CaDiCaL+Comba | +|----|-----------------|--------------|--------------| +| 7 | 0.42s | 0.41s | 0.10s | +| 9 | 5.19s | 6.00s | 0.26s | +| 11 | 249s | 273s | 1.78s | + +### 15. Associativity simplification via leaf-set comparison +Flattening nested applications of associative+commutative operators and +comparing the sorted multisets of leaves handles associativity: +`(a*b)*c == a*(b*c)` both flatten to `{a, b, c}`. Instant at any bitwidth. + +### 16. Assumption-gated incremental refinement works +Implemented two-stage refinement in `--refine-arithmetic` using +solve-with-assumptions: + +**Stage 0**: Build a narrow multiplier (low 4 bits of operands, +zero-extended to full width) gated by a retractable assumption literal. 
+ +**Stage 1+**: Drop the gate assumption and add the full multiplier. +Key fix: if `check_SAT` finds concrete values match at stage 0 but the +overall property still fails (re-entry), force the full multiplier. + +| BW | Original | Assumption-gated | Speedup | +|----|----------|-----------------|---------| +| 7 | 0.14s | 0.03s | 4.7x | +| 9 | 0.65s | 0.04s | 16x | +| 11 | 6.10s | 0.04s | 152x | + +### 17. Toom-Cook with non-deterministic coefficients implemented +Implemented the Brain paper's polynomial interpolation approach: +- Split operands into 4-bit chunks +- Create free coefficient variables d[i] for the result polynomial +- Constrain: result = sum(d[i] * 2^(i*chunk)) +- Add evaluation points incrementally (r(0), r(1), then full fallback) + +Correct on all benchmarks but slower than assumption-gated at small +bitwidths (0.21s vs 0.03s for comm BW=7). Would benefit at larger +bitwidths where sub-multiplications are much cheaper than full. + +## AMulet2 / Algebraic Approach (from Kaufmann & Biere 2023) + +The AMulet2 tool verifies multiplier circuits using Gröbner bases over Z[X]. +Key ideas: +- Each gate is modeled as a polynomial (e.g., AND: -u + vw = 0) +- The specification is checked to reduce to zero modulo the gate polynomials +- Complex final-stage adders (Kogge-Stone, etc.) are replaced with + ripple-carry adders (verified equivalent via SAT), then the simplified + circuit is verified algebraically +- XOR-based slicing reduces the problem size + +This is fundamentally different from bit-blasting — it reasons about the +algebraic structure of the circuit. The approach is specific to verifying +known circuit implementations, not general multiplication. + +## Open Questions + +1. **The polynomial view unification**: Toom-Cook/Karatsuba decompose + multiplication as polynomial evaluation/interpolation BEFORE encoding + to propositional logic. AMulet2 reasons about gate polynomials + (`AND: -u + vw = 0`) AFTER the circuit exists. 
These are the same + mathematical framework applied at different levels. The key unexploited + insight: if we could recognize polynomial identities (commutativity, + distributivity) at the expression level and simplify them before + creating any circuit, we'd avoid the exponential blowup entirely. + This is what Z3 does for simple cases. + +2. **Strengthening `--refine-arithmetic`**: CBMC already has + abstraction-refinement for multiplication (starts with `x*0=0` and + `x*1=x`, refines incrementally via `--refine-arithmetic`). However, + it's currently slower than direct Comba encoding (6.08s vs 1.78s at + BW=11 commutativity, 11 refinement iterations). The initial + approximation is too weak — it doesn't include algebraic properties + like commutativity. If the refinement started with stronger word-level + properties, it could potentially solve algebraic identities without + ever fully encoding the multiplier. + +3. Can we systematically identify which "summary" auxiliary variables + help propagation? The popcount's effectiveness suggests balanced-tree + summaries are key. This connects to the propagation completeness + theory from Brain et al. 2016. + +4. For distributivity, the solver needs to discover that multiplication + distributes over addition — a property requiring reasoning about the + interaction between multiplier and adder circuits. Can we add + word-level "bridge" constraints that connect these? + +5. Is there a way to encode the multiplier that makes algebraic structure + more visible to the SAT solver? E.g., encoding partial products so + that commutativity/distributivity are apparent at the clause level. + +## Recommended Next Steps (Priority Order) + +1. **Add MiniSat to the full benchmark matrix** — quick gap-fill. + +2. **Strengthen `--refine-arithmetic` with algebraic properties** — the + existing infrastructure supports abstraction-refinement. 
Adding + commutativity (`a*b == b*a`), distributivity (`a*(b+c) == a*b + a*c`), + and other word-level properties as initial constraints could make many + benchmarks trivial without full bit-blasting. This is the highest-impact + direction and builds on existing CBMC infrastructure. + +3. **Try CryptoMiniSat** on the benchmarks — it has native XOR support + which could help since the full-adder sum is XOR. + +4. **Explore the Toom-Cook incremental approach** — connects to the + incremental symex work and could provide a systematic way to refine + multiplication approximations. + +5. **Investigate word-level preprocessing in CBMC's expression simplifier** + — extend it to recognize and simplify algebraic identities involving + multiplication before bit-blasting. + +## Literature + +### Brain 2021, "Further Steps Down The Wrong Path" +- Multiplication by constant: use contiguous-1s trick (128-1 instead of + 127) and pattern sharing +- Toom-Cook polynomial interpolation: allows incremental approximation +- Key result: propagation-complete multiplier is likely exponential in size +- Suggests algebraic techniques (Gröbner bases) are more promising + +### Brain et al. 
2016, "Automatic Generation of Propagation Complete SAT Encodings" +- Formalizes propagation completeness using abstract satisfaction +- Algorithm to generate PCEs automatically +- Key finding: carry bits are the critical auxiliary variables for addition +- PC full-adder: 14 clauses (used by CBMC) +- PC 2x2 multiplier: 19 clauses +- Composition of PC primitives is PC for adders but NOT for multipliers + +## TODO + +- [ ] Implement proper Toom-Cook incremental refinement using + over-approximation constraints (polynomial evaluation points with + non-deterministic coefficients, not bit-slicing) +- [ ] Study whether Comba advantage holds on real-world CBMC benchmarks +- [ ] Make encoding choice solver-dependent (Comba for CaDiCaL, Baseline + for MiniSat) +- [ ] Profile CaDiCaL to understand where time is spent +- [ ] Investigate square identity simplification (requires polynomial + normalization) + +## Real-World Benchmark Sources + +1. **aws-c-common proofs** (used by CBMC's perf-benchcomp CI): Array + operations involve multiplication for index computation + (`count * element_size`). The `aws_mul_size_checked` proof directly + verifies a multiplication function. + +2. **Floating-point verification**: FP multiplication involves integer + multiplication of mantissas. CBMC's FP support uses this extensively. + +3. **SMT-LIB QF_BV benchmarks** (fmv.jku.at/smtbench/): Standard + benchmark suite for bit-vector solvers, includes multiplication-heavy + problems from hardware verification. + +4. **Cryptographic code**: Hash functions (SHA, MD5) and ciphers (AES) + use multiplication in GF(2^n). These are a natural source of hard + multiplication problems. + +Note: Most real-world CBMC usage involves multiplication by constants +(array indexing, struct field offsets) which is already handled efficiently +by constant propagation. Symbolic × symbolic multiplication is rarer but +occurs in checked arithmetic, cryptography, and FP verification. 
+ +## Bitwuzla Deep Study + +### How Bitwuzla solves multiplication problems instantly + +Bitwuzla uses a multi-stage pipeline that avoids bit-blasting for most +algebraic identities: + +1. **Rewriter** (`NORMALIZE_COMM`): Sorts operands of commutative operators + by node ID. After this, `a*b` and `b*a` are the same expression. + Then `EQUAL_TRUE` recognizes `x == x` → `true`. + +2. **Rewriter** (`NORM_FACT_BV_ADD_MUL`): Factorizes `a*b + a*c` into + `a*(b+c)` by finding common factors. This handles distributivity. + +3. **Preprocessing** (`normalize_comm_assoc`): Flattens nested + additions/multiplications, computes occurrence maps, factors out + common subterms, and normalizes both sides of equalities to a + canonical form. This handles associativity and complex identities. + +4. **Bit-blasting**: Only invoked if the above steps don't resolve the + formula. Uses AIG-based bit-blasting (more compact than direct CNF) + with CaDiCaL as the SAT backend. + +### Key commit: c0184571 (March 21, 2025) + +Author: Mathias Preiner. Added 8 normalization rewrite rules: +- `NORM_FACT_BV_ADD_MUL`: factorize `a*b + a*c → a*(b+c)` +- `NORM_FACT_BV_ADD_SHL`: factorize additions involving shifts +- `NORM_FACT_BV_SHL_MUL` / `NORM_FACT_BV_MUL_SHL`: shift/multiply +- `NORM_BV_EXTRACT_ADD_MUL_REV*`: extract over add/mul +- `NORM_BV_MUL_POW2_REV`: multiply by power of 2 + +This single commit (434 lines) is what makes Bitwuzla solve all our +algebraic identity benchmarks instantly. + +### Normalization preprocessing pass (January 2023 onwards) + +The `PassNormalize` preprocessing pass (`normalize.cpp`) implements: +- Flattening of nested BV_ADD and BV_MUL +- Occurrence counting for common subterm factoring +- Canonical ordering of terms +- Score-based AIG complexity estimation + +This is essentially the polynomial normalization we identified as the +"right approach" — Bitwuzla implements it as a preprocessing pass. + +### What CBMC could learn from Bitwuzla + +1. 
**Operand normalization** (sort commutative operands by ID) — we + implemented this in `simplify_inequality` but Bitwuzla does it at + the expression construction level, making it universal. + +2. **Factorization rewrite** (`a*b + a*c → a*(b+c)`) — we implemented + this in `simplify_inequality` but Bitwuzla applies it as a general + rewrite rule, not just in equality contexts. + +3. **Polynomial normalization** (flatten + occurrence counting + common + subterm factoring) — this is the key missing piece in CBMC. It would + handle associativity, the square identity, and other complex + algebraic properties that our current simplifier cannot. + +4. **AIG-based bit-blasting** — Bitwuzla uses And-Inverter Graphs as + an intermediate representation before CNF conversion. This allows + structural hashing and simplification that direct CNF generation + misses. diff --git a/regression/cbmc-incr-oneloop/multiply-correctness-refine/main.c b/regression/cbmc-incr-oneloop/multiply-correctness-refine/main.c new file mode 100644 index 00000000000..f6f159169f0 --- /dev/null +++ b/regression/cbmc-incr-oneloop/multiply-correctness-refine/main.c @@ -0,0 +1,63 @@ +// Multiplication correctness tests +// Tests both SAT (factors exist) and UNSAT (prime) cases, +// narrow and wide multiplication, and algebraic properties. 
+ +void test_factor_narrow(void) +{ + // 3 * 5 = 15, factors exist + __CPROVER_bitvector[8] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_assert(p * q != 15, "narrow factor: 15 = 3*5"); +} + +void test_prime_narrow(void) +{ + // 13 is prime, no factors + __CPROVER_bitvector[8] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_bitvector[16] wp = p, wq = q; + __CPROVER_assert(wp * wq != 13, "narrow prime: 13"); +} + +void test_factor_wide(void) +{ + // 15 * 69905 = 1048575, factors exist (wide multiplication) + __CPROVER_bitvector[20] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_bitvector[40] wp = p, wq = q; + __CPROVER_assert(wp * wq != 1048575ULL, "wide factor: 1048575 = 15*69905"); +} + +void test_prime_wide(void) +{ + // 1048573 is prime, no factors (wide multiplication) + __CPROVER_bitvector[20] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_bitvector[40] wp = p, wq = q; + __CPROVER_assert(wp * wq != 1048573ULL, "wide prime: 1048573"); +} + +void test_commutativity(void) +{ + __CPROVER_bitvector[8] a, b; + __CPROVER_bitvector[8] c = a * b; + __CPROVER_bitvector[8] d = b * a; + __CPROVER_assert(c == d, "commutativity"); +} + +void test_not_equal(void) +{ + // a*b != a+b in general + __CPROVER_bitvector[8] a, b; + __CPROVER_assert(a * b == a + b, "mul != add"); +} + +int main(void) +{ + test_factor_narrow(); + test_prime_narrow(); + test_factor_wide(); + test_prime_wide(); + test_commutativity(); + test_not_equal(); +} diff --git a/regression/cbmc-incr-oneloop/multiply-correctness-refine/test.desc b/regression/cbmc-incr-oneloop/multiply-correctness-refine/test.desc new file mode 100644 index 00000000000..86c1f4a5658 --- /dev/null +++ b/regression/cbmc-incr-oneloop/multiply-correctness-refine/test.desc @@ -0,0 +1,17 @@ +CORE +main.c +--no-standard-checks --refine-arithmetic +\[test_factor_narrow\.assertion\.1\].*narrow factor.*FAILURE +\[test_prime_narrow\.assertion\.1\].*narrow prime.*SUCCESS +\[test_factor_wide\.assertion\.1\].*wide 
factor.*FAILURE +\[test_prime_wide\.assertion\.1\].*wide prime.*SUCCESS +\[test_commutativity\.assertion\.1\].*commutativity.*SUCCESS +\[test_not_equal\.assertion\.1\].*mul != add.*FAILURE +^EXIT=10$ +^SIGNAL=0$ +^VERIFICATION FAILED$ +-- +^warning: ignoring +-- +Same as multiply-correctness but with --refine-arithmetic to catch +soundness issues in the refinement loop. diff --git a/regression/cbmc-incr-oneloop/multiply-correctness/main.c b/regression/cbmc-incr-oneloop/multiply-correctness/main.c new file mode 100644 index 00000000000..f6f159169f0 --- /dev/null +++ b/regression/cbmc-incr-oneloop/multiply-correctness/main.c @@ -0,0 +1,63 @@ +// Multiplication correctness tests +// Tests both SAT (factors exist) and UNSAT (prime) cases, +// narrow and wide multiplication, and algebraic properties. + +void test_factor_narrow(void) +{ + // 3 * 5 = 15, factors exist + __CPROVER_bitvector[8] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_assert(p * q != 15, "narrow factor: 15 = 3*5"); +} + +void test_prime_narrow(void) +{ + // 13 is prime, no factors + __CPROVER_bitvector[8] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_bitvector[16] wp = p, wq = q; + __CPROVER_assert(wp * wq != 13, "narrow prime: 13"); +} + +void test_factor_wide(void) +{ + // 15 * 69905 = 1048575, factors exist (wide multiplication) + __CPROVER_bitvector[20] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_bitvector[40] wp = p, wq = q; + __CPROVER_assert(wp * wq != 1048575ULL, "wide factor: 1048575 = 15*69905"); +} + +void test_prime_wide(void) +{ + // 1048573 is prime, no factors (wide multiplication) + __CPROVER_bitvector[20] p, q; + __CPROVER_assume(p > 1 && q > 1); + __CPROVER_bitvector[40] wp = p, wq = q; + __CPROVER_assert(wp * wq != 1048573ULL, "wide prime: 1048573"); +} + +void test_commutativity(void) +{ + __CPROVER_bitvector[8] a, b; + __CPROVER_bitvector[8] c = a * b; + __CPROVER_bitvector[8] d = b * a; + __CPROVER_assert(c == d, "commutativity"); +} + +void 
test_not_equal(void) +{ + // a*b != a+b in general + __CPROVER_bitvector[8] a, b; + __CPROVER_assert(a * b == a + b, "mul != add"); +} + +int main(void) +{ + test_factor_narrow(); + test_prime_narrow(); + test_factor_wide(); + test_prime_wide(); + test_commutativity(); + test_not_equal(); +} diff --git a/regression/cbmc-incr-oneloop/multiply-correctness/test.desc b/regression/cbmc-incr-oneloop/multiply-correctness/test.desc new file mode 100644 index 00000000000..e010f3ce02f --- /dev/null +++ b/regression/cbmc-incr-oneloop/multiply-correctness/test.desc @@ -0,0 +1,18 @@ +CORE +main.c +--no-standard-checks +\[test_factor_narrow\.assertion\.1\].*narrow factor.*FAILURE +\[test_prime_narrow\.assertion\.1\].*narrow prime.*SUCCESS +\[test_factor_wide\.assertion\.1\].*wide factor.*FAILURE +\[test_prime_wide\.assertion\.1\].*wide prime.*SUCCESS +\[test_commutativity\.assertion\.1\].*commutativity.*SUCCESS +\[test_not_equal\.assertion\.1\].*mul != add.*FAILURE +^EXIT=10$ +^SIGNAL=0$ +^VERIFICATION FAILED$ +-- +^warning: ignoring +-- +Multiplication correctness: verifies that factoring composites finds factors, +factoring primes does not, commutativity holds, and mul != add. +Tests both narrow (same-width) and wide (double-width) multiplication. diff --git a/scripts/bench_multiplication.sh b/scripts/bench_multiplication.sh new file mode 100755 index 00000000000..cd88cf14487 --- /dev/null +++ b/scripts/bench_multiplication.sh @@ -0,0 +1,470 @@ +#!/bin/bash +# Comprehensive multiplication encoding benchmark suite for CBMC +# +# Builds multiple CBMC variants (encoding × SAT solver), then runs ALL +# benchmarks across ALL variants, solver backends, and refinement modes. 
+# +# Usage: +# ./scripts/bench_multiplication.sh # full run (~hours) +# ./scripts/bench_multiplication.sh --quick # reduced (~15 min) +# ./scripts/bench_multiplication.sh --download-only # just fetch benchmarks +# ./scripts/bench_multiplication.sh --skip-build # reuse existing builds + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ROOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +BENCH_DIR="$ROOT_DIR/bench-multiplication" +RESULTS_DIR="$ROOT_DIR/bench-results" +BUILD_BASE="$ROOT_DIR/build-bench" +BV_UTILS="$ROOT_DIR/src/solvers/flattening/bv_utils.cpp" +REFINE_SRC="$ROOT_DIR/src/solvers/refinement/refine_arithmetic.cpp" +TIMEOUT=120 +RUNS=3 + +QUICK=false; DOWNLOAD_ONLY=false; SKIP_BUILD=false +for arg in "$@"; do + case $arg in + --quick) QUICK=true; TIMEOUT=30; RUNS=1 ;; + --download-only) DOWNLOAD_ONLY=true ;; + --skip-build) SKIP_BUILD=true ;; + esac +done + +# ============================================================ +# Encoding variants (compile-time flags) +# ============================================================ +# Format: "name:flags" where flags are space-separated +ENCODINGS=( + "baseline:-DNO_COMBA" + "comba:" + "dadda:-DNO_COMBA -DDADDA_TREE" + "wallace:-DNO_COMBA -DWALLACE_TREE" + "radix4:-DNO_COMBA -DRADIX_MULTIPLIER=4" + "radix8:-DNO_COMBA -DRADIX_MULTIPLIER=8" + "radix8-dadda:-DNO_COMBA -DRADIX_MULTIPLIER=8 -DDADDA_TREE" + "radix8-comba:-DRADIX_MULTIPLIER=8" + "karatsuba:-DNO_COMBA -DUSE_KARATSUBA" + "toom-cook:-DNO_COMBA -DUSE_TOOM_COOK" +) +if $QUICK; then + ENCODINGS=("baseline:-DNO_COMBA" "comba:" "dadda:-DNO_COMBA -DDADDA_TREE") +fi + +SAT_SOLVERS=("cadical" "minisat2") +$QUICK && SAT_SOLVERS=("cadical") + +# ============================================================ +# Benchmark creation +# ============================================================ +create_benchmarks() { + mkdir -p "$BENCH_DIR" "$BENCH_DIR/smt-comp" + + # Synthetic benchmarks (same as before, abbreviated for space) + for name in comm distrib assoc const3 
factor square mul_overflow mul_shift \ + mul_negation mul_double mul_bounds mul_zero_factor mul_square_nonneg; do + [ -f "$BENCH_DIR/${name}.c" ] && continue + echo " Creating $name.c" + done + + # Create all synthetic benchmarks + cat > "$BENCH_DIR/comm.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] a,b,c=a*b,d=b*a; __CPROVER_assert(c==d,"comm"); } +EOF + cat > "$BENCH_DIR/distrib.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] a,b,c,l=a*(b+c),r=a*b+a*c; __CPROVER_assert(l==r,"distrib"); } +EOF + cat > "$BENCH_DIR/assoc.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] a,b,c,ab=a*b,bc=b*c; __CPROVER_assert(ab*c==a*bc,"assoc"); } +EOF + cat > "$BENCH_DIR/const3.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] a; __CPROVER_assert(a*3==a+a+a,"const3"); } +EOF + cat > "$BENCH_DIR/factor.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] p,q; __CPROVER_assume(p>1&&q>1); __CPROVER_assert(p*q!=143,"factor"); } +EOF + cat > "$BENCH_DIR/square.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] a,b,s=a+b,l=s*s,r=a*a+(__CPROVER_bitvector[BW])2*a*b+b*b; __CPROVER_assert(l==r,"square"); } +EOF + cat > "$BENCH_DIR/mul_overflow.c" << 'EOF' +#ifndef BW +#define BW 16 +#endif +int main() { __CPROVER_bitvector[BW] a,b,n=a*b; __CPROVER_bitvector[BW*2] w=(__CPROVER_bitvector[BW*2])a*(__CPROVER_bitvector[BW*2])b; __CPROVER_assert(n==(__CPROVER_bitvector[BW])w,"overflow"); } +EOF + cat > "$BENCH_DIR/mul_shift.c" << 'EOF' +#ifndef BW +#define BW 16 +#endif +int main() { __CPROVER_bitvector[BW] a; __CPROVER_assert(a*4==a<<2,"shift"); } +EOF + cat > "$BENCH_DIR/mul_negation.c" << 'EOF' +#ifndef BW +#define BW 16 +#endif +int main() { __CPROVER_bitvector[BW] a; __CPROVER_assert(a*(__CPROVER_bitvector[BW])(-1)==-a,"neg"); } +EOF + cat > "$BENCH_DIR/mul_double.c" << 'EOF' +#ifndef BW 
+#define BW 16 +#endif +int main() { __CPROVER_bitvector[BW] a,b; __CPROVER_bitvector[BW/2] al=a,bl=b,pl=a*b; __CPROVER_assert(pl==(__CPROVER_bitvector[BW/2])(al*bl),"double"); } +EOF + cat > "$BENCH_DIR/mul_bounds.c" << 'EOF' +#ifndef BW +#define BW 16 +#endif +int main() { __CPROVER_bitvector[BW/2] a,b; __CPROVER_bitvector[BW] w=(__CPROVER_bitvector[BW])a*(__CPROVER_bitvector[BW])b; __CPROVER_bitvector[BW*2] w2=(__CPROVER_bitvector[BW*2])a*(__CPROVER_bitvector[BW*2])b; __CPROVER_assert(w==(__CPROVER_bitvector[BW])w2,"bounds"); } +EOF + cat > "$BENCH_DIR/mul_zero_factor.c" << 'EOF' +#ifndef BW +#define BW 8 +#endif +int main() { __CPROVER_bitvector[BW] a,b,r=a*b; __CPROVER_assume(r==0); __CPROVER_assume(a!=0); __CPROVER_assume(b!=0); __CPROVER_assert(0,"zero_div"); } +EOF + cat > "$BENCH_DIR/mul_square_nonneg.c" << 'EOF' +#ifndef BW +#define BW 9 +#endif +int main() { __CPROVER_bitvector[BW] a; __CPROVER_bitvector[BW*2] w=(__CPROVER_bitvector[BW*2])a*(__CPROVER_bitvector[BW*2])a; __CPROVER_assert(w>=a,"sq_nn"); } +EOF + + # SMT-COMP benchmarks + for bw in 8 16 32; do + cat > "$BENCH_DIR/smt-comp/comm_${bw}.smt2" << SEOF +(set-logic QF_BV)(declare-fun a () (_ BitVec $bw))(declare-fun b () (_ BitVec $bw))(assert (not (= (bvmul a b) (bvmul b a))))(check-sat)(exit) +SEOF + done + for bw in 8 16; do + cat > "$BENCH_DIR/smt-comp/distrib_${bw}.smt2" << SEOF +(set-logic QF_BV)(declare-fun a () (_ BitVec $bw))(declare-fun b () (_ BitVec $bw))(declare-fun c () (_ BitVec $bw))(assert (not (= (bvmul a (bvadd b c)) (bvadd (bvmul a b) (bvmul a c)))))(check-sat)(exit) +SEOF + done + cat > "$BENCH_DIR/smt-comp/assoc_8.smt2" << 'EOF' +(set-logic QF_BV)(declare-fun a () (_ BitVec 8))(declare-fun b () (_ BitVec 8))(declare-fun c () (_ BitVec 8))(assert (not (= (bvmul (bvmul a b) c) (bvmul a (bvmul b c)))))(check-sat)(exit) +EOF + for spec in "12:143" "16:10403" "20:101101"; do + bw=${spec%%:*}; comp=${spec##*:} + cat > "$BENCH_DIR/smt-comp/factor_${bw}.smt2" << SEOF +(set-logic 
QF_BV)(declare-fun p () (_ BitVec $bw))(declare-fun q () (_ BitVec $bw))(assert (bvugt p (_ bv1 $bw)))(assert (bvugt q (_ bv1 $bw)))(assert (= (bvmul p q) (_ bv${comp} $bw)))(check-sat)(exit) +SEOF + done + cat > "$BENCH_DIR/smt-comp/mixed_arith_8.smt2" << 'EOF' +(set-logic QF_BV)(declare-fun a () (_ BitVec 8))(declare-fun b () (_ BitVec 8))(declare-fun c () (_ BitVec 8))(declare-fun d () (_ BitVec 8))(assert (not (= (bvadd (bvmul a b) (bvmul c d)) (bvsub (bvsub (bvmul (bvadd a c) (bvadd b d)) (bvmul a d)) (bvmul c b)))))(check-sat)(exit) +EOF + + # AWS + if [ ! -d "$BENCH_DIR/aws-c-common" ]; then + echo "Downloading aws-c-common..." + git clone --depth 1 https://github.com/awslabs/aws-c-common "$BENCH_DIR/aws-c-common" 2>&1 | tail -1 + fi + + # Bitwuzla + if ! which bitwuzla >/dev/null 2>&1; then + echo "Downloading Bitwuzla..." + curl -sL -o /tmp/bitwuzla.zip "https://github.com/bitwuzla/bitwuzla/releases/download/0.9.0/Bitwuzla-Linux-x86_64-static.zip" + unzip -o /tmp/bitwuzla.zip -d /tmp/bitwuzla-extract >/dev/null 2>&1 + sudo cp /tmp/bitwuzla-extract/Bitwuzla-Linux-x86_64-static/bin/bitwuzla /usr/local/bin/ 2>/dev/null || true + fi +} + +# ============================================================ +# Build variants +# ============================================================ +build_variant() { + local name=$1 sat_impl=$2 cxx_flags=$3 + local build_dir="$BUILD_BASE/$name" + local base_dir="$BUILD_BASE/base-${sat_impl}" + + if [ -x "$build_dir/bin/cbmc" ] && $SKIP_BUILD; then + echo " $name: reusing" + return 0 + fi + + # Ensure base build exists + if [ ! -x "$base_dir/bin/cbmc" ]; then + echo -n " Building base-${sat_impl}... 
" + cmake -S "$ROOT_DIR" -B"$base_dir" -Dsat_impl="$sat_impl" \ + -DCMAKE_BUILD_TYPE=Release >/dev/null 2>&1 + cmake --build "$base_dir" --target cbmc smt2_solver -- -j$(nproc) >/dev/null 2>&1 \ + && echo "OK" || { echo "FAILED"; return 1; } + fi + + # If no extra flags, just symlink to base + if [ -z "$cxx_flags" ]; then + if [ "$build_dir" != "$base_dir" ]; then + rm -rf "$build_dir" + ln -sf "$(basename "$base_dir")" "$build_dir" + fi + echo " $name: = base-${sat_impl}" + return 0 + fi + + # Copy base build and recompile only changed files + if [ ! -d "$build_dir" ]; then + cp -al "$base_dir" "$build_dir" 2>/dev/null || cp -r "$base_dir" "$build_dir" + fi + + echo -n " $name: recompile+relink... " + + # Extract compile commands from base build + local bv_obj="src/solvers/CMakeFiles/solvers.dir/flattening/bv_utils.cpp.o" + local ref_obj="src/solvers/CMakeFiles/solvers.dir/refinement/refine_arithmetic.cpp.o" + + # Get the compile command from the base build's ninja file and add our flags + local bv_cmd=$(cd "$base_dir" && ninja -t commands "$bv_obj" 2>/dev/null | head -1) + local ref_cmd=$(cd "$base_dir" && ninja -t commands "$ref_obj" 2>/dev/null | head -1) + + if [ -n "$bv_cmd" ]; then + # Add our flags and redirect output to variant build dir + (cd "$build_dir" && eval "${bv_cmd} ${cxx_flags}") >/dev/null 2>&1 + (cd "$build_dir" && eval "${ref_cmd} ${cxx_flags}") >/dev/null 2>&1 + + # Relink solvers library and cbmc + cmake --build "$build_dir" --target cbmc smt2_solver -- -j$(nproc) >/dev/null 2>&1 \ + && echo "OK" || echo "FAILED" + else + # Fallback: full cmake build with flags + cmake -S "$ROOT_DIR" -B"$build_dir" -Dsat_impl="$sat_impl" \ + -DCMAKE_CXX_FLAGS="$cxx_flags" \ + -DCMAKE_BUILD_TYPE=Release >/dev/null 2>&1 + cmake --build "$build_dir" --target cbmc smt2_solver -- -j$(nproc) >/dev/null 2>&1 \ + && echo "OK" || echo "FAILED" + fi +} + +build_all() { + mkdir -p "$BUILD_BASE" + echo "=== Building CBMC variants ===" + for enc_spec in 
"${ENCODINGS[@]}"; do + local enc_name=${enc_spec%%:*} + local enc_flags=${enc_spec##*:} + for sat in "${SAT_SOLVERS[@]}"; do + build_variant "${enc_name}-${sat}" "$sat" "$enc_flags" + done + done + + # Build refine-arithmetic variants (comba encoding, both SAT solvers) + for refine_mode in 0 1 2 3; do + for sat in "${SAT_SOLVERS[@]}"; do + build_variant "comba-${sat}-refine${refine_mode}" "$sat" \ + "-DREFINE_MULT_MODE=${refine_mode}" + done + done +} + +compile_aws_proofs() { + local goto_cc="$BUILD_BASE/comba-cadical/bin/goto-cc" + [ ! -x "$goto_cc" ] && goto_cc="$BUILD_BASE/baseline-cadical/bin/goto-cc" + local outdir="$BENCH_DIR/aws-goto" + [ -f "$outdir/aws_mul_size_checked.gb" ] && return 0 + mkdir -p "$outdir" + local AWS="$BENCH_DIR/aws-c-common" + local INC="-I $AWS/include -I $AWS/verification/cbmc/include" + local DEF="-DCBMC -DCBMC_OBJECT_BITS=8 -DMAX_ITEM_SIZE=2 -DMAX_INITIAL_ITEM_ALLOCATION=9223372036854775808ULL -DMAX_BUFFER_SIZE=10" + local COM="$AWS/source/allocator.c $AWS/source/common.c $AWS/source/error.c $AWS/verification/cbmc/sources/make_common_data_structures.c $AWS/verification/cbmc/sources/utils.c" + for proof in aws_mul_size_checked aws_mul_size_saturating aws_add_size_checked aws_array_list_back aws_byte_buf_clean_up; do + local h="$AWS/verification/cbmc/proofs/$proof/${proof}_harness.c" + [ ! -f "$h" ] && continue + local s="$h $COM" + case $proof in aws_array_list*) s="$s $AWS/source/array_list.c";; aws_byte_buf*) s="$s $AWS/source/byte_buf.c";; esac + echo -n " $proof... 
" + timeout 60 "$goto_cc" $INC $DEF $s --function ${proof}_harness -o "$outdir/${proof}.gb" 2>/dev/null && echo "OK" || echo "FAILED" + done +} + +# ============================================================ +# Run benchmarks +# ============================================================ +timed_run() { + local cmd=$1 timeout_s=$2 + local total=0 ok=0 + for run in $(seq 1 $RUNS); do + local out + out=$(timeout "$timeout_s" /usr/bin/time -f 'TIME:%e' bash -c "$cmd" 2>&1) + local t=$(echo "$out" | grep '^TIME:' | cut -d: -f2) + local has_result=$(echo "$out" | grep -cE 'VERIFICATION (SUCCESSFUL|FAILED)|^(sat|unsat)$') + if [ -n "$t" ] && [ "$has_result" -gt 0 ]; then + total=$(echo "$total + $t" | bc) + ok=1 + else + total=$(echo "$total + $timeout_s" | bc) + fi + done + local avg=$(echo "scale=2; $total / $RUNS" | bc) + [ "$ok" != "1" ] && avg="${avg}!" + echo "$avg" +} + +run_all() { + mkdir -p "$RESULTS_DIR" + local csv="$RESULTS_DIR/results_$(date +%Y%m%d_%H%M%S).csv" + echo "category,variant,benchmark,param,time_s" > "$csv" + echo "=== Running benchmarks (timeout=${TIMEOUT}s, runs=$RUNS) ===" + + # Bitwidth specs per benchmark + declare -A BW_SPECS + BW_SPECS[comm]="7 9 11 13 15" + BW_SPECS[distrib]="3 4 5 6 7" + BW_SPECS[assoc]="3 5 7 9" + BW_SPECS[const3]="8 16 32" + BW_SPECS[factor]="8 12 16" + BW_SPECS[square]="5 7 9 11" + BW_SPECS[mul_overflow]="8 12 16" + BW_SPECS[mul_shift]="8 16 32" + BW_SPECS[mul_negation]="8 16 32" + BW_SPECS[mul_double]="8 16 32" + BW_SPECS[mul_bounds]="8 16 32" + BW_SPECS[mul_zero_factor]="4 8 16" + BW_SPECS[mul_square_nonneg]="8 16 32" + if $QUICK; then + BW_SPECS[comm]="7 9 11" + BW_SPECS[distrib]="4 5" + BW_SPECS[assoc]="5 7" + BW_SPECS[factor]="8 12" + BW_SPECS[square]="5 7" + for k in mul_overflow mul_shift mul_negation mul_double mul_bounds mul_zero_factor mul_square_nonneg; do + BW_SPECS[$k]="8 16" + done + BW_SPECS[const3]="8 16" + fi + + # --- 1. 
Encoding × SAT solver variants on ALL synthetic benchmarks --- + for variant_dir in "$BUILD_BASE"/*/; do + local variant=$(basename "$variant_dir") + local cbmc="$variant_dir/bin/cbmc" + [ ! -x "$cbmc" ] && continue + [[ "$variant" == *refine* ]] && continue # handle separately + + echo "" + echo "--- $variant ---" + for bench in "${!BW_SPECS[@]}"; do + for bw in ${BW_SPECS[$bench]}; do + local t=$(timed_run "'$cbmc' '$BENCH_DIR/${bench}.c' -DBW=$bw --no-standard-checks --verbosity 4" "$TIMEOUT") + local line="synth,$variant,$bench,BW=$bw,${t}s" + echo " $line" + echo "$line" >> "$csv" + done + done + + # AWS proofs + for gb in "$BENCH_DIR"/aws-goto/*.gb; do + [ ! -f "$gb" ] && continue + local name=$(basename "$gb" .gb) + local t=$(timed_run "'$cbmc' '$gb' --unwind 10 --unwinding-assertions --verbosity 4" "$TIMEOUT") + local line="aws,$variant,$name,unwind=10,${t}s" + echo " $line" + echo "$line" >> "$csv" + done + done + + # --- 2. --refine-arithmetic variants (ALL benchmarks, both SAT solvers) --- + for variant_dir in "$BUILD_BASE"/comba-*-refine*/; do + local variant=$(basename "$variant_dir") + local cbmc="$variant_dir/bin/cbmc" + [ ! -x "$cbmc" ] && continue + + echo "" + echo "--- $variant + refine ---" + for bench in "${!BW_SPECS[@]}"; do + for bw in ${BW_SPECS[$bench]}; do + local t=$(timed_run "'$cbmc' '$BENCH_DIR/${bench}.c' -DBW=$bw --no-standard-checks --refine-arithmetic --verbosity 4" "$TIMEOUT") + local line="synth,${variant}+refine,$bench,BW=$bw,${t}s" + echo " $line" + echo "$line" >> "$csv" + done + done + + # AWS with refine + for gb in "$BENCH_DIR"/aws-goto/*.gb; do + [ ! -f "$gb" ] && continue + local name=$(basename "$gb" .gb) + local t=$(timed_run "'$cbmc' '$gb' --unwind 10 --unwinding-assertions --refine-arithmetic --verbosity 4" "$TIMEOUT") + local line="aws,${variant}+refine,$name,unwind=10,${t}s" + echo " $line" + echo "$line" >> "$csv" + done + done + + # --- 3. 
SMT solvers (one CBMC build, ALL synthetic benchmarks) --- + local cbmc_smt="$BUILD_BASE/comba-cadical/bin/cbmc" + if [ -x "$cbmc_smt" ]; then + for solver_flag in "--z3" "--cvc5" "--bitwuzla"; do + local sname=${solver_flag#--} + # bitwuzla is invoked via CBMC, not as standalone binary + if [ "$sname" != "bitwuzla" ]; then + which "$sname" >/dev/null 2>&1 || continue + fi + echo "" + echo "--- $sname ---" + for bench in "${!BW_SPECS[@]}"; do + for bw in ${BW_SPECS[$bench]}; do + local t=$(timed_run "'$cbmc_smt' '$BENCH_DIR/${bench}.c' -DBW=$bw --no-standard-checks $solver_flag --verbosity 4" "$TIMEOUT") + local line="synth,$sname,$bench,BW=$bw,${t}s" + echo " $line" + echo "$line" >> "$csv" + done + done + done + fi + + # --- 4. SMT-COMP benchmarks via smt2_solver, z3, cvc5, bitwuzla --- + echo "" + echo "--- SMT-COMP benchmarks ---" + local smt2_solver="$BUILD_BASE/comba-cadical/bin/smt2_solver" + local smt_solvers="" + [ -x "$smt2_solver" ] && smt_solvers="smt2_solver:$smt2_solver" + which z3 >/dev/null 2>&1 && smt_solvers="$smt_solvers z3:z3" + which cvc5 >/dev/null 2>&1 && smt_solvers="$smt_solvers cvc5:cvc5" + if which bitwuzla >/dev/null 2>&1; then + smt_solvers="$smt_solvers bitwuzla:bitwuzla bitwuzla-noabs:bitwuzla%--no-abstraction" + fi + + for solver_spec in $smt_solvers; do + local sname=${solver_spec%%:*} + local scmd=${solver_spec#*:} + scmd=${scmd//%/ } + echo " solver: $sname" + for f in "$BENCH_DIR"/smt-comp/*.smt2; do + local bname=$(basename "$f" .smt2) + local t=$(timed_run "$scmd '$f'" "$TIMEOUT") + local line="smt-comp,$sname,$bname,,${t}s" + echo " $line" + echo "$line" >> "$csv" + done + done + + echo "" + echo "=== Results: $csv ===" + echo "=== $(grep -c ',' "$csv") data points ===" +} + +# ============================================================ +# Main +# ============================================================ +create_benchmarks +if $DOWNLOAD_ONLY; then + echo "Benchmarks ready." 
+ exit 0 +fi +build_all +compile_aws_proofs +run_all diff --git a/src/solvers/flattening/bv_utils.cpp b/src/solvers/flattening/bv_utils.cpp index cffe7e35dc2..995ee15624e 100644 --- a/src/solvers/flattening/bv_utils.cpp +++ b/src/solvers/flattening/bv_utils.cpp @@ -10,9 +10,38 @@ Author: Daniel Kroening, kroening@kroening.com #include +#include +#include #include #include +static std::string beautify(const bvt &bv) +{ + for(const auto &v : bv) + { + if(!v.is_constant()) + { + std::ostringstream oss; + oss << bv; + return oss.str(); + } + } + + std::string result; + std::size_t number = 0; + for(std::size_t i = 0; i < bv.size(); ++i) + { + if(result.size() % 5 == 4) + result = std::string(" ") + result; + result = std::string(bv[i].is_false() ? "0" : "1") + result; + + if(bv[i].is_true()) + number += 1 << i; + } + + return result + " (" + std::to_string(number) + ")"; +} + bvt bv_utilst::build_constant(const mp_integer &n, std::size_t width) { std::string n_str=integer2binary(n, width); @@ -777,13 +806,60 @@ bvt bv_utilst::dadda_tree(const std::vector &pps) return add(a, b); } +bvt bv_utilst::comba_column_wise(const std::vector &pps) +{ + PRECONDITION(!pps.empty()); + + std::vector columns(pps.front().size()); + for(const auto &pp : pps) + { + PRECONDITION(pp.size() == pps.front().size()); + for(std::size_t i = 0; i < pp.size(); ++i) + { + if(!pp[i].is_false()) + columns[i].push_back(pp[i]); + } + } + + bvt result; + result.reserve(columns.size()); + + for(std::size_t i = 0; i < columns.size(); ++i) + { + const bvt &column = columns[i]; + + if(column.empty()) + result.push_back(const_literal(false)); + else + { + bvt column_sum = popcount(column); + CHECK_RETURN(!column_sum.empty()); + result.push_back(column_sum.front()); + for(std::size_t j = 1; j < column_sum.size(); ++j) + { + if(i + j >= columns.size()) + break; + if(!column_sum[j].is_false()) + columns[i + j].push_back(column_sum[j]); + } + } + } + + return result; +} + // Wallace tree multiplier. 
This is disabled, as runtimes have // been observed to go up by 5%-10%, and on some models even by 20%. +#ifndef WALLACE_TREE // #define WALLACE_TREE +#endif // Dadda' reduction scheme. This yields a smaller formula size than Wallace -// trees (and also the default addition scheme), but remains disabled as it -// isn't consistently more performant either. +// trees (and also the default addition scheme), but isn't consistently more +// performant with simple partial-product generation. Only when using +// higher-radix multipliers does the combination appear to perform better. +#ifndef DADDA_TREE // #define DADDA_TREE +#endif // The following examples demonstrate the performance differences (with a // time-out of 7200 seconds): @@ -917,16 +993,100 @@ bvt bv_utilst::dadda_tree(const std::vector &pps) // our multiplier that's not using a tree reduction scheme, but aren't uniformly // better either. +// Higher radix multipliers pre-compute partial products for groups of bits: +// radix-4 are groups of 2 bits, radix-8 are groups of 3 bits, and radix-16 are +// groups of 4 bits. Performance data for these variants combined with different +// (tree) reduction schemes is recorded at +// https://tinyurl.com/multiplier-comparison. The data suggests that radix-8 +// with Dadda's reduction yields the most consistent performance improvement +// while not regressing substantially in the matrix of different benchmarks and +// CaDiCaL and MiniSat2 as solvers. 
+#ifndef RADIX_MULTIPLIER +// #define RADIX_MULTIPLIER 8 +#endif +#ifndef USE_KARATSUBA +// #define USE_KARATSUBA +#endif +#ifndef USE_TOOM_COOK +// #define USE_TOOM_COOK +#endif +#ifndef USE_SCHOENHAGE_STRASSEN +// #define USE_SCHOENHAGE_STRASSEN +#endif +#ifdef RADIX_MULTIPLIER +# ifndef DADDA_TREE +# define DADDA_TREE +# endif +#endif +#if !defined(COMBA) && !defined(NO_COMBA) +#define COMBA +#endif + +#ifdef RADIX_MULTIPLIER +static bvt unsigned_multiply_by_3(propt &prop, const bvt &op) +{ + PRECONDITION(prop.cnf_handled_well()); + PRECONDITION(!op.empty()); + + bvt result; + result.reserve(op.size()); + + result.push_back(op[0]); + literalt prev_bit = const_literal(false); + + for(std::size_t i = 1; i < op.size(); ++i) + { + literalt sum = prop.new_variable(); + + prop.lcnf({sum, !op[i - 1], !op[i], !prev_bit}); + prop.lcnf({sum, !op[i - 1], !op[i], result.back()}); + prop.lcnf({sum, op[i - 1], op[i], !prev_bit, result.back()}); + prop.lcnf({sum, !op[i - 1], op[i], prev_bit, !result.back()}); + prop.lcnf({sum, op[i - 1], !op[i], !result.back()}); + prop.lcnf({sum, op[i - 1], !op[i], prev_bit}); + + prop.lcnf({!sum, !op[i - 1], op[i], !prev_bit}); + prop.lcnf({!sum, !op[i - 1], op[i], result.back()}); + prop.lcnf({!sum, !op[i - 1], !op[i], prev_bit, !result.back()}); + + prop.lcnf({!sum, op[i - 1], op[i], !result.back()}); + prop.lcnf({!sum, op[i - 1], op[i], prev_bit}); + prop.lcnf({!sum, op[i - 1], !op[i], !prev_bit, result.back()}); + + prop.lcnf({!sum, op[i], prev_bit, result.back()}); + prop.lcnf({!sum, op[i], !prev_bit, !result.back()}); + + result.push_back(sum); + prev_bit = op[i - 1]; + } + + return result; +} +#endif + bvt bv_utilst::unsigned_multiplier(const bvt &_op0, const bvt &_op1) { - bvt op0=_op0, op1=_op1; + PRECONDITION(!_op0.empty()); + PRECONDITION(!_op1.empty()); - if(is_constant(op1)) - std::swap(op0, op1); + if(_op1.size() == 1) + { + bvt product; + product.reserve(_op0.size()); + for(const auto &lit : _op0) + 
product.push_back(prop.land(lit, _op1.front())); + return product; + } - // build the usual quadratic number of partial products + // store partial products std::vector pps; - pps.reserve(op0.size()); + pps.reserve(_op0.size()); + + bvt op0 = _op0, op1 = _op1; + +#ifndef RADIX_MULTIPLIER + if(is_constant(op1)) + std::swap(op0, op1); for(std::size_t bit=0; bit times_three_opt; + auto times_three = [this, ×_three_opt, &op0]() -> const bvt & + { + if(!times_three_opt.has_value()) + { +# if 1 + if(prop.cnf_handled_well()) + times_three_opt = unsigned_multiply_by_3(prop, op0); + else +# endif + times_three_opt = add(op0, shift(op0, shiftt::SHIFT_LEFT, 1)); + } + return *times_three_opt; + }; + +# if RADIX_MULTIPLIER >= 8 + std::optional times_five_opt, times_seven_opt; + auto times_five = [this, ×_five_opt, &op0]() -> const bvt & + { + if(!times_five_opt.has_value()) + times_five_opt = add(op0, shift(op0, shiftt::SHIFT_LEFT, 2)); + return *times_five_opt; + }; + auto times_seven = + [this, ×_seven_opt, &op0, ×_three]() -> const bvt & + { + if(!times_seven_opt.has_value()) + times_seven_opt = add(times_three(), shift(op0, shiftt::SHIFT_LEFT, 2)); + return *times_seven_opt; + }; +# endif + +# if RADIX_MULTIPLIER == 16 + std::optional times_nine_opt, times_eleven_opt, times_thirteen_opt, + times_fifteen_opt; + auto times_nine = [this, ×_nine_opt, &op0]() -> const bvt & + { + if(!times_nine_opt.has_value()) + times_nine_opt = add(op0, shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_nine_opt; + }; + auto times_eleven = + [this, ×_eleven_opt, &op0, ×_three]() -> const bvt & + { + if(!times_eleven_opt.has_value()) + times_eleven_opt = add(times_three(), shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_eleven_opt; + }; + auto times_thirteen = + [this, ×_thirteen_opt, &op0, ×_five]() -> const bvt & + { + if(!times_thirteen_opt.has_value()) + times_thirteen_opt = add(times_five(), shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_thirteen_opt; + }; + auto times_fifteen = 
+ [this, ×_fifteen_opt, &op0, ×_seven]() -> const bvt & + { + if(!times_fifteen_opt.has_value()) + times_fifteen_opt = add(times_seven(), shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_fifteen_opt; + }; +# endif + + for(std::size_t op1_idx = 0; op1_idx + RADIX_GROUP_SIZE - 1 < op1.size(); + op1_idx += RADIX_GROUP_SIZE) + { + const literalt &bit0 = op1[op1_idx]; + const literalt &bit1 = op1[op1_idx + 1]; +# if RADIX_MULTIPLIER >= 8 + const literalt &bit2 = op1[op1_idx + 2]; +# if RADIX_MULTIPLIER == 16 + const literalt &bit3 = op1[op1_idx + 3]; +# endif +# endif + bvt partial_sum; + + if( + bit0.is_constant() && bit1.is_constant() +# if RADIX_MULTIPLIER >= 8 + && bit2.is_constant() +# if RADIX_MULTIPLIER == 16 + && bit3.is_constant() +# endif +# endif + ) + { + if(bit0.is_false()) // *0 + { + if(bit1.is_false()) // *00 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *000 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0000 + continue; + else // 1000 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 3); +# else + continue; +# endif + } + else // *100 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0100 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 2); + else // 1100 + partial_sum = + shift(times_three(), shiftt::SHIFT_LEFT, op1_idx + 2); +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 2); +# endif + } +# else + continue; +# endif + } + else // *10 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *010 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0010 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 1); + else // 1010 + partial_sum = + shift(times_five(), shiftt::SHIFT_LEFT, op1_idx + 1); +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 1); +# endif + } + else // *110 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0110 + partial_sum = + shift(times_three(), shiftt::SHIFT_LEFT, op1_idx + 1); + else // 1110 + partial_sum = + shift(times_seven(), 
shiftt::SHIFT_LEFT, op1_idx + 1); +# else + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx + 1); +# endif + } +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 1); +# endif + } + } + else // *1 + { + if(bit1.is_false()) // *01 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *001 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0001 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx); + else // 1001 + partial_sum = shift(times_nine(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx); +# endif + } + else // *101 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0101 + partial_sum = shift(times_five(), shiftt::SHIFT_LEFT, op1_idx); + else // 1101 + partial_sum = + shift(times_thirteen(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(times_five(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx); +# endif + } + else // *11 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *011 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0011 + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx); + else // 1011 + partial_sum = shift(times_eleven(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } + else // *111 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0111 + partial_sum = shift(times_seven(), shiftt::SHIFT_LEFT, op1_idx); + else // 1111 + partial_sum = shift(times_fifteen(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(times_seven(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } +# else + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } + } + } + else + { + partial_sum = bvt(op1_idx, const_literal(false)); + for(std::size_t op0_idx = 0; op0_idx + op1_idx < op0.size(); ++op0_idx) + { +# if RADIX_MULTIPLIER == 4 + if(prop.cnf_handled_well()) + { + literalt 
partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + + // 00 + prop.lcnf({bit0, bit1, !partial_sum_bit}); + // 01 -> sum = _op0 + prop.lcnf({!bit0, bit1, !partial_sum_bit, _op0[op0_idx]}); + prop.lcnf({!bit0, bit1, partial_sum_bit, !_op0[op0_idx]}); + // 10 -> sum = (_op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, !partial_sum_bit, _op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, partial_sum_bit, !_op0[op0_idx - 1]}); + } + // 11 -> sum = times_three + prop.lcnf({!bit0, !bit1, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf({!bit0, !bit1, partial_sum_bit, !times_three()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit1, + prop.land(bit0, op0[op0_idx]), // 0x + prop.lselect( // 1x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx]))); + } +# elif RADIX_MULTIPLIER == 8 + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + + // 000 + prop.lcnf({bit0, bit1, bit2, !partial_sum_bit}); + // 001 -> sum = _op0 + prop.lcnf({!bit0, bit1, bit2, !partial_sum_bit, _op0[op0_idx]}); + prop.lcnf({!bit0, bit1, bit2, partial_sum_bit, !_op0[op0_idx]}); + // 010 -> sum = (_op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit, _op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, bit2, partial_sum_bit, !_op0[op0_idx - 1]}); + } + // 011 -> sum = times_three + prop.lcnf( + {!bit0, !bit1, bit2, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, bit2, partial_sum_bit, !times_three()[op0_idx]}); + // 100 -> sum = (_op0 << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit, _op0[op0_idx - 2]}); + prop.lcnf({bit0, bit1, !bit2, partial_sum_bit, !_op0[op0_idx - 
2]}); + } + // 101 -> sum = times_five + prop.lcnf( + {!bit0, bit1, !bit2, !partial_sum_bit, times_five()[op0_idx]}); + prop.lcnf( + {!bit0, bit1, !bit2, partial_sum_bit, !times_five()[op0_idx]}); + // 110 -> sum = (times_three << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + !bit2, + !partial_sum_bit, + times_three()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + !bit2, + partial_sum_bit, + !times_three()[op0_idx - 1]}); + } + // 111 -> sum = times_seven + prop.lcnf( + {!bit0, !bit1, !bit2, !partial_sum_bit, times_seven()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, !bit2, partial_sum_bit, !times_seven()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit2, + prop.lselect( // 0* + !bit1, + prop.land(bit0, op0[op0_idx]), // 00x + prop.lselect( // 01x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx])), + prop.lselect( // 1* + !bit1, + prop.lselect( // 10x + !bit0, + op0_idx <= 1 ? const_literal(false) : op0[op0_idx - 2], + times_five()[op0_idx]), + prop.lselect( // 11x + !bit0, + op0_idx == 0 ? 
const_literal(false) + : times_three()[op0_idx - 1], + times_seven()[op0_idx])))); + } +# elif RADIX_MULTIPLIER == 16 + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + + // 0000 + prop.lcnf({bit0, bit1, bit2, bit3, !partial_sum_bit}); + // 0001 -> sum = op0 + prop.lcnf({!bit0, bit1, bit2, bit3, !partial_sum_bit, op0[op0_idx]}); + prop.lcnf({!bit0, bit1, bit2, bit3, partial_sum_bit, !op0[op0_idx]}); + // 0010 -> sum = (op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, !bit1, bit2, bit3, !partial_sum_bit, op0[op0_idx - 1]}); + prop.lcnf( + {bit0, !bit1, bit2, bit3, partial_sum_bit, !op0[op0_idx - 1]}); + } + // 0011 -> sum = times_three + prop.lcnf( + {!bit0, + !bit1, + bit2, + bit3, + !partial_sum_bit, + times_three()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + bit2, + bit3, + partial_sum_bit, + !times_three()[op0_idx]}); + // 0100 -> sum = (op0 << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, bit1, !bit2, bit3, !partial_sum_bit, op0[op0_idx - 2]}); + prop.lcnf( + {bit0, bit1, !bit2, bit3, partial_sum_bit, !op0[op0_idx - 2]}); + } + // 0101 -> sum = times_five + prop.lcnf( + {!bit0, + bit1, + !bit2, + bit3, + !partial_sum_bit, + times_five()[op0_idx]}); + prop.lcnf( + {!bit0, + bit1, + !bit2, + bit3, + partial_sum_bit, + !times_five()[op0_idx]}); + // 0110 -> sum = (times_three << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + !bit2, + bit3, + !partial_sum_bit, + times_three()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + !bit2, + bit3, + partial_sum_bit, + !times_three()[op0_idx - 1]}); + } + // 0111 -> sum = times_seven + prop.lcnf( + {!bit0, + !bit1, + !bit2, + bit3, + !partial_sum_bit, + times_seven()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + !bit2, 
+ bit3, + partial_sum_bit, + !times_seven()[op0_idx]}); + + // 1000 -> sum = (op0 << 3) + if(op0_idx == 0 || op0_idx == 1 || op0_idx == 2) + prop.lcnf({bit0, bit1, bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, bit1, bit2, !bit3, !partial_sum_bit, op0[op0_idx - 3]}); + prop.lcnf( + {bit0, bit1, bit2, !bit3, partial_sum_bit, !op0[op0_idx - 3]}); + } + // 1001 -> sum = times_nine + prop.lcnf( + {!bit0, + bit1, + bit2, + !bit3, + !partial_sum_bit, + times_nine()[op0_idx]}); + prop.lcnf( + {!bit0, + bit1, + bit2, + !bit3, + partial_sum_bit, + !times_nine()[op0_idx]}); + // 1010 -> sum = (times_five << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + bit2, + !bit3, + !partial_sum_bit, + times_five()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + bit2, + !bit3, + partial_sum_bit, + !times_five()[op0_idx - 1]}); + } + // 1011 -> sum = times_eleven + prop.lcnf( + {!bit0, + !bit1, + bit2, + !bit3, + !partial_sum_bit, + times_eleven()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + bit2, + !bit3, + partial_sum_bit, + !times_eleven()[op0_idx]}); + // 1100 -> sum = (times_three << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_three()[op0_idx - 2]}); + prop.lcnf( + {bit0, + bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_three()[op0_idx - 2]}); + } + // 1101 -> sum = times_thirteen + prop.lcnf( + {!bit0, + bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_thirteen()[op0_idx]}); + prop.lcnf( + {!bit0, + bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_thirteen()[op0_idx]}); + // 1110 -> sum = (times_seven << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_seven()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + !bit2, + 
!bit3, + partial_sum_bit, + !times_seven()[op0_idx - 1]}); + } + // 1111 -> sum = times_fifteen + prop.lcnf( + {!bit0, + !bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_fifteen()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_fifteen()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit3, + prop.lselect( // 0* + !bit2, + prop.lselect( // 00* + !bit1, + prop.land(bit0, op0[op0_idx]), // 000x + prop.lselect( // 001x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx])), + prop.lselect( // 01* + !bit1, + prop.lselect( // 010x + !bit0, + op0_idx <= 1 ? const_literal(false) : op0[op0_idx - 2], + times_five()[op0_idx]), + prop.lselect( // 011x + !bit0, + op0_idx == 0 ? const_literal(false) + : times_three()[op0_idx - 1], + times_seven()[op0_idx]))), + prop.lselect( // 1* + !bit2, + prop.lselect( // 10* + !bit1, + prop.lselect( // 100x + !bit0, + op0_idx <= 2 ? const_literal(false) : op0[op0_idx - 3], + times_nine()[op0_idx]), + prop.lselect( // 101x + !bit0, + op0_idx == 0 ? const_literal(false) + : times_five()[op0_idx - 1], + times_eleven()[op0_idx])), + prop.lselect( // 11* + !bit1, + prop.lselect( // 110x + !bit0, + op0_idx <= 1 ? const_literal(false) + : times_three()[op0_idx - 2], + times_thirteen()[op0_idx]), + prop.lselect( // 111x + !bit0, + op0_idx == 0 ? const_literal(false) + : times_seven()[op0_idx - 1], + times_fifteen()[op0_idx]))))); + } +# else +# error Unsupported radix +# endif + } + } + + pps.push_back(std::move(partial_sum)); + } + + if(op1.size() % RADIX_GROUP_SIZE == 1) + { + if(op0.size() == op1.size()) + { + if(pps.empty()) + pps.push_back(bvt(op0.size(), const_literal(false))); + + // This is the partial product of the MSB of op1 with op0, which is all + // zeros except for (possibly) the MSB. 
Since we don't need to account for + // any carry out of adding this partial product, we just need to compute + // the sum of the MSB of one of the partial products and this partial + // product, which is an xor of just those bits. + pps.back().back() = + prop.lxor(pps.back().back(), prop.land(op0[0], op1.back())); + } + else + { + bvt partial_sum = bvt(op1.size() - 1, const_literal(false)); + for(const auto &lit : op0) + { + partial_sum.push_back(prop.land(lit, op1.back())); + if(partial_sum.size() == op0.size()) + break; + } + pps.push_back(std::move(partial_sum)); + } + } +# if RADIX_MULTIPLIER >= 8 + else if(op1.size() % RADIX_GROUP_SIZE == 2) + { + const literalt &bit0 = op1[op1.size() - 2]; + const literalt &bit1 = op1[op1.size() - 1]; + + bvt partial_sum = bvt(op1.size() - 2, const_literal(false)); + for(std::size_t op0_idx = 0; op0_idx < 2; ++op0_idx) + { + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + // 00 + prop.lcnf({bit0, bit1, !partial_sum_bit}); + // 01 -> sum = op0 + prop.lcnf({!bit0, bit1, !partial_sum_bit, op0[op0_idx]}); + prop.lcnf({!bit0, bit1, partial_sum_bit, !op0[op0_idx]}); + // 10 -> sum = (op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, !partial_sum_bit, op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, partial_sum_bit, !op0[op0_idx - 1]}); + } + // 11 -> sum = times_three + prop.lcnf({!bit0, !bit1, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf({!bit0, !bit1, partial_sum_bit, !times_three()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit1, + prop.land(bit0, op0[op0_idx]), // 0x + prop.lselect( // 1x + !bit0, + op0_idx == 0 ?
const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx]))); + } + } + + pps.push_back(std::move(partial_sum)); + } +# endif +# if RADIX_MULTIPLIER == 16 + else if(op1.size() % RADIX_GROUP_SIZE == 3) + { + const literalt &bit0 = op1[op1.size() - 3]; + const literalt &bit1 = op1[op1.size() - 2]; + const literalt &bit2 = op1[op1.size() - 1]; + + bvt partial_sum = bvt(op1.size() - 3, const_literal(false)); + for(std::size_t op0_idx = 0; op0_idx < 3; ++op0_idx) + { + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + // 000 + prop.lcnf({bit0, bit1, bit2, !partial_sum_bit}); + // 001 -> sum = op0 + prop.lcnf({!bit0, bit1, bit2, !partial_sum_bit, op0[op0_idx]}); + prop.lcnf({!bit0, bit1, bit2, partial_sum_bit, !op0[op0_idx]}); + // 010 -> sum = (op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit, op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, bit2, partial_sum_bit, !op0[op0_idx - 1]}); + } + // 011 -> sum = times_three + prop.lcnf( + {!bit0, !bit1, bit2, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, bit2, partial_sum_bit, !times_three()[op0_idx]}); + // 100 -> sum = (op0 << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit, op0[op0_idx - 2]}); + prop.lcnf({bit0, bit1, !bit2, partial_sum_bit, !op0[op0_idx - 2]}); + } + // 101 -> sum = times_five + prop.lcnf( + {!bit0, bit1, !bit2, !partial_sum_bit, times_five()[op0_idx]}); + prop.lcnf( + {!bit0, bit1, !bit2, partial_sum_bit, !times_five()[op0_idx]}); + // 110 -> sum = (times_three << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, !bit1, !bit2, !partial_sum_bit, times_three()[op0_idx - 1]}); + prop.lcnf( + {bit0, !bit1, !bit2, partial_sum_bit, !times_three()[op0_idx - 
1]}); + } + // 111 -> sum = times_seven + prop.lcnf( + {!bit0, !bit1, !bit2, !partial_sum_bit, times_seven()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, !bit2, partial_sum_bit, !times_seven()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit2, + prop.lselect( // 0* + !bit1, + prop.land(bit0, op0[op0_idx]), // 00x + prop.lselect( // 01x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx])), + prop.lselect( // 1* + !bit1, + prop.lselect( // 10x + !bit0, + op0_idx <= 1 ? const_literal(false) : op0[op0_idx - 2], + times_five()[op0_idx]), + prop.lselect( // 11x + !bit0, + op0_idx == 0 ? const_literal(false) : times_three()[op0_idx - 1], + times_seven()[op0_idx])))); + } + } + + pps.push_back(std::move(partial_sum)); + } +# endif +#endif if(pps.empty()) return zeros(op0.size()); @@ -951,6 +1909,8 @@ bvt bv_utilst::unsigned_multiplier(const bvt &_op0, const bvt &_op1) return wallace_tree(pps); #elif defined(DADDA_TREE) return dadda_tree(pps); +#elif defined(COMBA) + return comba_column_wise(pps); #else bvt product = pps.front(); @@ -962,6 +1922,658 @@ bvt bv_utilst::unsigned_multiplier(const bvt &_op0, const bvt &_op1) } } +bvt bv_utilst::unsigned_karatsuba_full_multiplier( + const bvt &op0, + const bvt &op1) +{ + // We review symbolic encoding of multiplication in context of sw + // verification, bit width is 2^n, distinguish truncating (x mod 2^2^n) from + // double-output-width multiplication, truncating Karatsuba is 2 truncating + // half-width multiplication plus one double-output-width of half width, for + // double output width Karatsuba idea is challenge to avoid width extension, + // check Wikipedia edit history + + PRECONDITION(op0.size() == op1.size()); + const std::size_t op_size = op0.size(); + PRECONDITION(op_size > 0); + PRECONDITION((op_size & (op_size - 1)) == 0); + + if(op_size == 1) + return {prop.land(op0[0], op1[0]), const_literal(false)}; + + const std::size_t half_op_size = op_size >> 1; + 
+ bvt x0{op0.begin(), op0.begin() + half_op_size}; + bvt x1{op0.begin() + half_op_size, op0.end()}; + + bvt y0{op1.begin(), op1.begin() + half_op_size}; + bvt y1{op1.begin() + half_op_size, op1.end()}; + + bvt z0 = unsigned_karatsuba_full_multiplier(x0, y0); + bvt z2 = unsigned_karatsuba_full_multiplier(x1, y1); + + bvt x0_sub = zero_extension(x0, half_op_size + 1); + bvt x1_sub = zero_extension(x1, half_op_size + 1); + + bvt y0_sub = zero_extension(y0, half_op_size + 1); + bvt y1_sub = zero_extension(y1, half_op_size + 1); + + bvt x1_minus_x0_ext = sub(x1_sub, x0_sub); + literalt x1_minus_x0_sign = sign_bit(x1_minus_x0_ext); + bvt x1_minus_x0_abs = absolute_value(x1_minus_x0_ext); + x1_minus_x0_abs.pop_back(); + bvt y0_minus_y1_ext = sub(y0_sub, y1_sub); + literalt y0_minus_y1_sign = sign_bit(y0_minus_y1_ext); + bvt y0_minus_y1_abs = absolute_value(y0_minus_y1_ext); + y0_minus_y1_abs.pop_back(); + bvt sub_mult = + unsigned_karatsuba_full_multiplier(x1_minus_x0_abs, y0_minus_y1_abs); + bvt sub_mult_ext = zero_extension(sub_mult, op_size + 1); + bvt z1_ext = add_sub( + add(zero_extension(z0, op_size + 1), zero_extension(z2, op_size + 1)), + sub_mult_ext, + prop.lxor(x1_minus_x0_sign, y0_minus_y1_sign)); + + bvt z0_full = zero_extension(z0, op_size << 1); + bvt z1_full = + zero_extension(concatenate(zeros(half_op_size), z1_ext), op_size << 1); + bvt z2_full = concatenate(zeros(op_size), z2); + + return add(add(z0_full, z1_full), z2_full); +} + +bvt bv_utilst::unsigned_karatsuba_multiplier(const bvt &_op0, const bvt &_op1) +{ + if(_op0.size() != _op1.size()) + return unsigned_multiplier(_op0, _op1); + + const std::size_t op_size = _op0.size(); + if(op_size == 1) + return {prop.land(_op0[0], _op1[0])}; + + // Make sure we work with operands the length of which are powers of two + const std::size_t log2 = address_bits(op_size); + PRECONDITION(sizeof(std::size_t) * CHAR_BIT > log2); + const std::size_t two_to_log2 = (std::size_t)1 << log2; + bvt a = zero_extension(_op0, 
two_to_log2); + bvt b = zero_extension(_op1, two_to_log2); + + const std::size_t half_op_size = two_to_log2 >> 1; + + // We split each of the operands in half and treat them as coefficients of a + // polynomial a * 2^half_op_size + b. Straightforward polynomial + // multiplication then yields + // a0 * a1 * 2^op_size + (a0 * b1 + a1 * b0) * 2^half_op_size + b0 * b1 + // These would be four multiplications (the operands of which have half the + // original bit width): + // z0 = b0 * b1 + // z1 = a0 * b1 + a1 * b0 + // z2 = a0 * a1 + // Karatsuba's insight is that these four multiplications can be expressed + // using just three multiplications: + // z1 = (a0 - b0) * (b1 - a1) + z0 + z2 + // + // Worked 4-bit example, 4-bit result: + // abcd * efgh -> 4-bit result + // cd * gh -> 4-bit result + // cd * ef -> 2-bit result + // ab * gh -> 2-bit result + // d * h -> 2-bit result + // c * g -> 2-bit result + // (c - d) * (h - g) + dh + cg; use an extra sign bit for each of the + // subtractions, and conditionally negate the product by xor-ing those sign + // bits; dh + cg is a 2-bit addition (with possible results 0, 1, 2); the + // product has possible values (-1, 0, 1); the final sum cannot evaluate to -1 + // as + // * c=1, d=0, h=0, g=1 (1 * -1) implies cg=1 + // * c=0, d=1, h=1, g=0 (-1 * 1) implies dh=1 + // Therefore, after adding (dh + cg) the multiplication can safely be added + // over just 2 bits. 
+ + bvt x0{a.begin(), a.begin() + half_op_size}; + bvt x1{a.begin() + half_op_size, a.end()}; + bvt y0{b.begin(), b.begin() + half_op_size}; + bvt y1{b.begin() + half_op_size, b.end()}; + + bvt z0 = unsigned_karatsuba_full_multiplier(x0, y0); + bvt z1 = add( + unsigned_karatsuba_multiplier(x1, y0), + unsigned_karatsuba_multiplier(x0, y1)); + bvt z1_full = concatenate(zeros(half_op_size), z1); + + bvt result = add(z0, z1_full); + CHECK_RETURN(result.size() >= op_size); + if(result.size() > op_size) + result.resize(op_size); + return result; +} + +bvt bv_utilst::unsigned_toom_cook_multiplier(const bvt &_op0, const bvt &_op1) +{ + PRECONDITION(_op0.size() == _op1.size()); + PRECONDITION(!_op0.empty()); + + if(_op0.size() == 1) + return {prop.land(_op0[0], _op1[0])}; + + // break up _op0, _op1 in groups of at most GROUP_SIZE bits +#define GROUP_SIZE 8 + const std::size_t d_bits = + 2 * GROUP_SIZE + + 2 * address_bits((_op0.size() + GROUP_SIZE - 1) / GROUP_SIZE); + std::vector a, b, c_ops, d; + for(std::size_t i = 0; i < _op0.size(); i += GROUP_SIZE) + { + std::size_t u = std::min(i + GROUP_SIZE, _op0.size()); + a.emplace_back(_op0.begin() + i, _op0.begin() + u); + b.emplace_back(_op1.begin() + i, _op1.begin() + u); + + c_ops.push_back(zeros(i)); + d.push_back(prop.new_variables(d_bits)); + c_ops.back().insert(c_ops.back().end(), d.back().begin(), d.back().end()); + c_ops.back() = zero_extension(c_ops.back(), _op0.size()); + } + for(std::size_t i = a.size(); i < 2 * a.size() - 1; ++i) + { + d.push_back(prop.new_variables(d_bits)); + } + + // r(0) + bvt r_0 = d[0]; + prop.l_set_to_true(equal( + r_0, + unsigned_multiplier( + zero_extension(a[0], r_0.size()), zero_extension(b[0], r_0.size())))); + + for(std::size_t j = 1; j < a.size(); ++j) + { + // r(2^(j-1)) + bvt r_j = zero_extension( + d[0], std::min(_op0.size(), d[0].size() + (j - 1) * (d.size() - 1))); + for(std::size_t i = 1; i < d.size(); ++i) + { + r_j = add( + r_j, + shift( + zero_extension(d[i], r_j.size()), 
shiftt::SHIFT_LEFT, (j - 1) * i)); + } + + bvt a_even = zero_extension(a[0], r_j.size()); + for(std::size_t i = 2; i < a.size(); i += 2) + { + a_even = add( + a_even, + shift( + zero_extension(a[i], a_even.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + bvt a_odd = zero_extension(a[1], r_j.size()); + for(std::size_t i = 3; i < a.size(); i += 2) + { + a_odd = add( + a_odd, + shift( + zero_extension(a[i], a_odd.size()), + shiftt::SHIFT_LEFT, + (j - 1) * (i - 1))); + } + bvt b_even = zero_extension(b[0], r_j.size()); + for(std::size_t i = 2; i < b.size(); i += 2) + { + b_even = add( + b_even, + shift( + zero_extension(b[i], b_even.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + bvt b_odd = zero_extension(b[1], r_j.size()); + for(std::size_t i = 3; i < b.size(); i += 2) + { + b_odd = add( + b_odd, + shift( + zero_extension(b[i], b_odd.size()), + shiftt::SHIFT_LEFT, + (j - 1) * (i - 1))); + } + + prop.l_set_to_true(equal( + r_j, + unsigned_multiplier( + add(a_even, shift(a_odd, shiftt::SHIFT_LEFT, j - 1)), + add(b_even, shift(b_odd, shiftt::SHIFT_LEFT, j - 1))))); + + // r(-2^(j-1)) + bvt r_minus_j = zero_extension( + d[0], std::min(_op0.size(), d[0].size() + (j - 1) * (d.size() - 1))); + for(std::size_t i = 1; i < d.size(); ++i) + { + if(i % 2 == 1) + { + r_minus_j = sub( + r_minus_j, + shift( + zero_extension(d[i], r_minus_j.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + else + { + r_minus_j = add( + r_minus_j, + shift( + zero_extension(d[i], r_minus_j.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + } + + prop.l_set_to_true(equal( + r_minus_j, + unsigned_multiplier( + sub(a_even, shift(a_odd, shiftt::SHIFT_LEFT, j - 1)), + sub(b_even, shift(b_odd, shiftt::SHIFT_LEFT, j - 1))))); + } + + if(c_ops.empty()) + return zeros(_op0.size()); + else + { +#ifdef WALLACE_TREE + return wallace_tree(c_ops); +#elif defined(DADDA_TREE) + return dadda_tree(c_ops); +#elif defined(COMBA) + return comba_column_wise(c_ops); +#else + bvt product = c_ops.front(); + 
+ for(auto it = std::next(c_ops.begin()); it != c_ops.end(); ++it) + product = add(product, *it); + + return product; +#endif + } +} + +bvt bv_utilst::unsigned_schoenhage_strassen_multiplier( + const bvt &a, + const bvt &b) +{ + PRECONDITION(a.size() == b.size()); + + // Running examples: we want to multiple 213 by 15 as 8- or 9-bit integers. + // That is, we seek to multiply 11010101 (011010101) by 00001111 (000001111). + // ^bit 7 ^bit 0 + // The expected result is 123 as both an 8-bit and 9-bit result (001111011). + + // We compute the result modulo a Fermat number F_m = 2^2^m + 1. The maximum + // result when multiplying a by b (with their sizes being the same per the + // precondition above) is 2^2*op_size - 1. + // TODO: we don't actually need a full multiplier, a result with up to op_size + // bits is sufficient for our purposes. + // Hence we require 2^2^m >= 2^2*op_size, i.e., 2^m >= 2*op_size, or + // m >= log_2(op_size) + 1. + // For our examples m will be 4 and 5, respectively, with Fermat numbers + // 2^16 + 1 and 2^32 + 1. + const std::size_t m = address_bits(a.size()) + 1; + std::cerr << "m: " << m << std::endl; + + // Extend bit width to 2^(m + 1) = op_size (rounded to next power of 2) * 4 + // For our examples, extended bit widths will be 32 and 64. + PRECONDITION(sizeof(std::size_t) * CHAR_BIT > m + 1); + const std::size_t two_to_m_plus_1 = (std::size_t)1 << (m + 1); + std::cerr << "a: " << beautify(a) << std::endl; + std::cerr << "b: " << beautify(b) << std::endl; + bvt a_ext = zero_extension(a, two_to_m_plus_1); + bvt b_ext = zero_extension(b, two_to_m_plus_1); + + // We need to distinguish whether m is even or odd + // m = 2n - 1 for odd m and m = 2n -2 for even m + // For our 8-bit inputs we have m = 4 and, therefore, n = 3. + // For our 9-bit inputs we have m = 5 and, therefore, n = 3. + const std::size_t n = m % 2 == 1 ? 
(m + 1) / 2 : m / 2 + 1; + std::cerr << "n: " << n << std::endl; + + // For even m create 2^n (of 2^(n - 1) bits) chunks from a_ext, b_ext (for our + // 8-bit inputs we have chunk_size = 4 with num_chunks = 8). + // For odd m create 2^(n + 1) chunks (of 2^(n - 1) bits) from a_ext, b_ext; + // a_0 will be bit positions 0 through to 2^(n - 1) - 1, a_{2^(n + 1) - 1} + // will be bit positions up to 2^(m + 1) - 1. + // For our 9-bit inputs we have chunk_size = 4 with num_chunks = 16 + const std::size_t chunk_size = (std::size_t)1 << (n - 1); + const std::size_t num_chunks = two_to_m_plus_1 / chunk_size; + CHECK_RETURN( + num_chunks == m % 2 ? (std::size_t)1 << (n + 1) : (std::size_t)1 << n); + std::cerr << "chunk_size: " << chunk_size << std::endl; + std::cerr << "num_chunks: " << num_chunks << std::endl; + std::cerr << "address_bits(num_chunks): " << address_bits(num_chunks) + << std::endl; + + std::vector a_rho, b_sigma; + a_rho.reserve(num_chunks); + b_sigma.reserve(num_chunks); + for(std::size_t i = 0; i < num_chunks; ++i) + { + a_rho.emplace_back( + a_ext.begin() + i * chunk_size, a_ext.begin() + (i + 1) * chunk_size); + b_sigma.emplace_back( + b_ext.begin() + i * chunk_size, b_ext.begin() + (i + 1) * chunk_size); + } + // For our example we now have + // a_rho = [ 0101, 1101, 0000, ..., 0000 ] + // b_sigma = [ 1111, 0000, 0000, ..., 0000 ] + + // Compute gamma_r = \sum_{i + j = r} a_i * b_j with bit width 3n + 5 with r + // ranging from 0 to 2^(n + 2) - 1 (to 2^(n + 1) - 1 when m is even). + // For our example this will be additions/multiplications of width 14 + // (implying that school book multiplication would be cheaper, as is the case + // for all operand lengths below 32 bits). + // TODO: all subsequent steps seem to be using mod 2^(n + 2) (mod 2^(n + 1) + // when m is even), so it may be sufficient to do this over n + 2 bits instead + // of 3n + 5. 
+ std::vector gamma_tau{num_chunks * 2, zeros(3 * n + 5)}; + for(std::size_t tau = 0; tau < num_chunks * 2; ++tau) + { + for(std::size_t rho = tau < num_chunks ? 0 : tau - num_chunks + 1; + rho < num_chunks && rho <= tau; + ++rho) + { + const std::size_t sigma = tau - rho; + gamma_tau[tau] = add( + gamma_tau[tau], + unsigned_multiplier( + zero_extension(a_rho[rho], 3 * n + 5), + zero_extension(b_sigma[sigma], 3 * n + 5))); + } + } + // For our example we obtain + // gamma_tau = [ 00 0000 0100 1011, 00 0000 1100 0011, 0.... ] + + // Compute c_tau over bit width n + 2 (n + 1 when m is even) as gamma_tau + + // gamma_{tau + 2^(n + 1)} (gamma_{tau + 2^n} when m is even). + std::vector c_tau; + c_tau.reserve(num_chunks); + for(std::size_t tau = 0; tau < num_chunks; ++tau) + { + c_tau.push_back(add(gamma_tau[tau], gamma_tau[tau + num_chunks])); + CHECK_RETURN(c_tau.back().size() >= address_bits(num_chunks) + 1); + c_tau.back().resize(address_bits(num_chunks) + 1); + std::cerr << "c_tau[" << tau << "]: " << beautify(c_tau[tau]) << std::endl; + } + // For our example we obtain + // c_tau = [ 01011, 00011, 0... ] + + // Compute z_j = c_j - c_{j + 2^n} (mod 2^(n + 2)) (mod 2^(n + 1) and c_{j + + // 2^(n - 1)} when m is even) + std::vector z_j; + z_j.reserve(num_chunks / 2); + for(std::size_t j = 0; j < num_chunks / 2; ++j) + z_j.push_back(sub(c_tau[j], c_tau[j + num_chunks / 2])); + // For our example we have z_j = c_tau as all elements beyond the second one + // are zeros. + + // Compute z_j mod F_n using number-theoretic transform with omega = 2 for + // odd m and omega = 4 for even m. + // For our examples we have F_n = 2^2^n + 1 = 257 with 2 being a 2^(n + 1)-th + // root of unity, i.e., 2^16 \equiv 1 (mod 257) (with 4 being a 2^n-root of + // unity, i.e., 4^8 \equiv 1 (mod 257). 
The DFT table for omega = 2 would be + // 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + // 1 2 4 8 16 32 64 128 -1 -2 -4 -8 -16 -32 -64 -128 + // 1 4 16 64 -1 -4 -16 -64 1 4 16 64 -1 -4 -16 -64 + // 1 8 64 -2 -16 -128 4 32 -1 -8 -64 2 16 128 -4 -32 + // 1 16 -1 -16 1 16 -1 -16 1 16 -1 -16 1 16 -1 -16 + // 1 32 -4 -128 16 -2 -64 8 -1 -32 4 128 -16 2 64 -8 + // 1 64 -16 4 -1 -64 16 -2 1 64 -16 4 -1 -64 16 -4 + // 1 128 -64 32 -16 8 -2 2 -1 -128 64 -32 16 -8 4 -2 + // 1 -1 1 -1 1 -1 1 -1 1 -1 1 -1 1 -1 1 -1 + // 1 -2 4 -8 16 -32 64 -128 -1 2 -4 8 -16 32 -64 128 + // 1 -4 16 -64 -1 4 -16 64 1 -4 16 -64 -1 4 -16 64 + // 1 -8 64 2 -16 128 4 -32 -1 8 -64 -2 16 -128 -4 32 + // 1 -16 -1 16 1 -16 -1 16 1 -16 -1 16 1 -16 -1 16 + // 1 -32 -4 128 16 2 -64 -8 -1 32 4 -128 -16 -2 64 8 + // 1 -64 -16 -4 -1 64 16 4 1 -64 -16 -4 -1 64 16 4 + // 1 -128 -64 -32 -16 -8 -4 -2 -1 128 64 32 16 8 4 2 + // For fast NTT (less than O(n^2)) use Cooley-Tukey for NTT, then perform + // element-wise multiplication, and finally apply Gentleman-Sande for + // inverse NTT. 
+ + // Addition mod F_n with overflow + auto cyclic_add = [this](const bvt &x, const bvt &y) + { + PRECONDITION(x.size() == y.size()); + + auto result_with_overflow = adder(x, y, const_literal(false)); + if(result_with_overflow.second.is_false()) + return result_with_overflow.first; + + return add( + result_with_overflow.first, + zero_extension(bvt{1, result_with_overflow.second}, x.size())); + }; + + // Compute NTT + std::vector a_j, b_j; + a_j.reserve(num_chunks); + b_j.reserve(num_chunks); + for(std::size_t j = 0; j < num_chunks; ++j) + { + // All NTT steps are mod F_n, i.e., mod 2^2^n + 1, which implies we need + // 2^(n + 1) bits to represent numbers + a_j.push_back(zero_extension(a_rho[j], (std::size_t)1 << (n + 1))); + b_j.push_back(zero_extension(b_sigma[j], (std::size_t)1 << (n + 1))); + } + // Use in-place iterative Cooley-Tukey + std::vector Aa, Ab; + Aa.reserve(num_chunks); + Ab.reserve(num_chunks); + // In the following we use k represented as bits k_{n - 1}...k_0 and + // j_0...j_{n - 1}, i.e., the most-significant bit of k is k_{n - 1} while the + // MSB for j is j_0. + for(std::size_t k = 0; k < num_chunks; ++k) + { + // reverse n (n - 1 if m is even) bits of k + std::size_t j = 0; + for(std::size_t nu = 0; nu < address_bits(num_chunks); ++nu) + { + j <<= 1; // the initial shift has no effect + j |= (k & (1 << nu)) >> nu; + } + Aa.push_back(a_j[j]); + Ab.push_back(b_j[j]); + } + for(std::size_t nu = 1; nu <= address_bits(num_chunks); ++nu) + { + const std::size_t bit_nu = (std::size_t)1 << (nu - 1); + std::size_t bits_up_to_nu = 0; + for(std::size_t i = 0; i < nu - 1; ++i) + bits_up_to_nu |= 1 << i; + + // we only need odd ones + for(std::size_t k = 1; k < num_chunks; k += 2) + { + if((k & bit_nu) == 0) + continue; + + bvt Aa_nu_bit_is_zero = Aa[k & ~bit_nu]; + bvt Ab_nu_bit_is_zero = Ab[k & ~bit_nu]; + + const std::size_t chi = (k & bits_up_to_nu) + << (address_bits(num_chunks) - 1 - (nu - 1)); + const std::size_t omega = m % 2 == 1 ? 
2 : 4; + const std::size_t shift_dist = chi * omega / 2; + + if(nu > 1) // no need to update even indices + { + Aa[k & ~bit_nu] = cyclic_add( + Aa_nu_bit_is_zero, shift(Aa[k], shiftt::ROTATE_LEFT, shift_dist)); + Ab[k & ~bit_nu] = cyclic_add( + Ab_nu_bit_is_zero, shift(Ab[k], shiftt::ROTATE_LEFT, shift_dist)); + std::cerr << "Aa[" << nu << "](" << (k & ~bit_nu) + << "): " << beautify(Aa[k & ~bit_nu]) << std::endl; +#if 0 + std::cerr << "Ab[" << nu << "](" << (k & ~bit_nu) + << "): " << beautify(Ab[k & ~bit_nu]) << std::endl; +#endif + } + + // subtraction mod F_n is addition of subtrahend cyclically shifted 2^n + // positions to the left + const std::size_t shift_dist_for_sub = shift_dist + ((std::size_t)1 << n); + Aa[k] = cyclic_add( + Aa_nu_bit_is_zero, + shift(Aa[k], shiftt::ROTATE_LEFT, shift_dist_for_sub)); + Ab[k] = cyclic_add( + Ab_nu_bit_is_zero, + shift(Ab[k], shiftt::ROTATE_LEFT, shift_dist_for_sub)); + std::cerr << "Aa[" << nu << "](" << k << "): " << beautify(Aa[k]) + << std::endl; +#if 0 + std::cerr << "Ab[" << nu << "](" << k << "): " << beautify(Ab[k]) + << std::endl; +#endif + } + } + + // Either compute u - v (if u > v), else u - v + 2^2^n + 1 + auto reduce_to_mod_F_n = [this](const bvt &x) + { + const std::size_t two_to_power_of_n = x.size() / 2; + // std::cerr << "two_to_power_of_n: " << two_to_power_of_n << std::endl; + const bvt u = + zero_extension(bvt{x.begin(), x.begin() + two_to_power_of_n}, x.size()); + // std::cerr << "u: " << beautify(u) << std::endl; + const bvt v = + zero_extension(bvt{x.begin() + two_to_power_of_n, x.end()}, x.size()); + // std::cerr << "v: " << beautify(v) << std::endl; + bvt two_to_power_of_two_to_power_of_n_plus_1 = build_constant(1, x.size()); + two_to_power_of_two_to_power_of_n_plus_1[two_to_power_of_n] = + const_literal(true); + const bvt u_ext = select( + unsigned_less_than(u, v), + add(u, two_to_power_of_two_to_power_of_n_plus_1), + u); + // std::cerr << "u_ext: " << beautify(u_ext) << std::endl; + return 
sub(u_ext, v); + }; + + std::vector a_hat_k{num_chunks, bvt{}}, b_hat_k{num_chunks, bvt{}}; + // Reduce by F_n + for(std::size_t j = 1; j < num_chunks; j += 2) + { + a_hat_k[j] = reduce_to_mod_F_n(Aa[j]); + std::cerr << "a_hat_k[" << j << "]: " << beautify(a_hat_k[j]) << std::endl; + b_hat_k[j] = reduce_to_mod_F_n(Ab[j]); + std::cerr << "b_hat_k[" << j << "]: " << beautify(b_hat_k[j]) << std::endl; + } + + // Compute point-wise multiplication + std::vector c_hat_k{num_chunks, bvt{}}; + for(std::size_t j = 1; j < num_chunks; j += 2) + { + c_hat_k[j] = unsigned_multiplier(a_hat_k[j], b_hat_k[j]); + std::cerr << "c_hat_k[" << j << "]: " << beautify(c_hat_k[j]) << std::endl; + } + + // Apply inverse NTT + for(std::size_t nu = address_bits(num_chunks) - 1; nu > 0; --nu) + { + const std::size_t bit_nu_plus_1 = (std::size_t)1 << nu; + std::size_t bits_up_to_nu_plus_1 = 0; + for(std::size_t i = 0; i < nu; ++i) + bits_up_to_nu_plus_1 |= 1 << i; + + // we only need odd ones + for(std::size_t k = 1; k < num_chunks; k += 2) + { + if((k & bit_nu_plus_1) == 0) + continue; + + bvt c_hat_k_nu_plus_1_bit_is_zero = c_hat_k[k & ~bit_nu_plus_1]; + + c_hat_k[k & ~bit_nu_plus_1] = shift( + cyclic_add(c_hat_k_nu_plus_1_bit_is_zero, c_hat_k[k]), + shiftt::ROTATE_RIGHT, + 1); + std::cerr << "c_hat_k[" << nu << "](" << (k & ~bit_nu_plus_1) + << "): " << beautify(c_hat_k[k & ~bit_nu_plus_1]) << std::endl; + + const std::size_t chi = (k & bits_up_to_nu_plus_1) + << (address_bits(num_chunks) - 1 - nu); + const std::size_t omega = m % 2 == 1 ? 
2 : 4; + const std::size_t shift_dist = chi * omega / 2 + 1; + std::cerr << "SHIFT: " << shift_dist << std::endl; + + c_hat_k[k] = shift( + cyclic_add( + c_hat_k_nu_plus_1_bit_is_zero, + shift(c_hat_k[k], shiftt::ROTATE_LEFT, (std::size_t)1 << n)), + shiftt::ROTATE_RIGHT, + shift_dist); + std::cerr << "c_hat_k[" << nu << "](" << k + << "): " << beautify(c_hat_k[k]) << std::endl; + } + } + // Reduce by F_n + std::vector z_j_mod_F_n; + z_j_mod_F_n.reserve(num_chunks / 2); + for(std::size_t j = 0; j < num_chunks / 2; ++j) + { + // reverse n - 1 (n - 2 if m is even) bits of j + std::size_t k = 0; + for(std::size_t nu = 0; nu < address_bits(num_chunks) - 1; ++nu) + { + k |= (j & (1 << nu)) >> nu; + k <<= 1; + } + k |= 1; + std::cerr << "j " << j << " maps to " << k << std::endl; + z_j_mod_F_n.push_back(reduce_to_mod_F_n(c_hat_k[k])); + std::cerr << "z_j_mod_F_n[" << j << "]: " << beautify(z_j_mod_F_n[j]) + << std::endl; + } + + // Compute final coefficients as eta + delta * F_n where delta = eta - xi for + // eta z_j and xi c_hat_k. + for(std::size_t j = 0; j < num_chunks / 2; ++j) + { + bvt eta = z_j_mod_F_n[j]; + std::cerr << "eta[" << j << "]: " << beautify(eta) << std::endl; + bvt xi = z_j[j]; + std::cerr << "xi[" << j << "]: " << beautify(xi) << std::endl; + // TODO: couldn't we do this over just xi.size() bits instead? 
+ bvt delta = sub(eta, zero_extension(xi, eta.size())); + CHECK_RETURN(delta.size() >= xi.size()); + delta.resize(xi.size()); + std::cerr << "delta[" << j << "]: " << beautify(delta) << std::endl; + z_j[j] = add( + zero_extension(eta, two_to_m_plus_1), + add( + shift( + zero_extension(delta, two_to_m_plus_1), + shiftt::SHIFT_LEFT, + (std::size_t)1 << n), + zero_extension(delta, two_to_m_plus_1))); + std::cerr << "z_j[" << j << "]: " << beautify(z_j[j]) << std::endl; + } + + bvt result = zeros(two_to_m_plus_1); + for(std::size_t j = 0; j < num_chunks / 2; ++j) + { + if(chunk_size * j >= a.size()) + break; + result = add(result, shift(z_j[j], shiftt::SHIFT_LEFT, chunk_size * j)); + } + std::cerr << "result: " << beautify(result) << std::endl; + CHECK_RETURN(result.size() >= a.size()); + result.resize(a.size()); + std::cerr << "result resized: " << beautify(result) << std::endl; + + return result; +} + bvt bv_utilst::unsigned_multiplier_no_overflow( const bvt &op0, const bvt &op1) @@ -1012,7 +2624,15 @@ bvt bv_utilst::signed_multiplier(const bvt &op0, const bvt &op1) bvt neg0=cond_negate(op0, sign0); bvt neg1=cond_negate(op1, sign1); +#ifdef USE_KARATSUBA + bvt result = unsigned_karatsuba_multiplier(neg0, neg1); +#elif defined(USE_TOOM_COOK) + bvt result = unsigned_toom_cook_multiplier(neg0, neg1); +#elif defined(USE_SCHOENHAGE_STRASSEN) + bvt result = unsigned_schoenhage_strassen_multiplier(neg0, neg1); +#else bvt result=unsigned_multiplier(neg0, neg1); +#endif literalt result_sign=prop.lxor(sign0, sign1); @@ -1080,7 +2700,18 @@ bvt bv_utilst::multiplier( switch(rep) { case representationt::SIGNED: return signed_multiplier(op0, op1); +#ifdef USE_KARATSUBA + case representationt::UNSIGNED: + return unsigned_karatsuba_multiplier(op0, op1); +#elif defined(USE_TOOM_COOK) + case representationt::UNSIGNED: + return unsigned_toom_cook_multiplier(op0, op1); +#elif defined(USE_SCHOENHAGE_STRASSEN) + case representationt::UNSIGNED: + return 
unsigned_schoenhage_strassen_multiplier(op0, op1); +#else case representationt::UNSIGNED: return unsigned_multiplier(op0, op1); +#endif } UNREACHABLE; diff --git a/src/solvers/flattening/bv_utils.h b/src/solvers/flattening/bv_utils.h index 60ba1422b56..58de3dff4ff 100644 --- a/src/solvers/flattening/bv_utils.h +++ b/src/solvers/flattening/bv_utils.h @@ -79,6 +79,10 @@ class bv_utilst bvt shift(const bvt &op, const shiftt shift, const bvt &distance); bvt unsigned_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_karatsuba_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_karatsuba_full_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_toom_cook_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_schoenhage_strassen_multiplier(const bvt &a, const bvt &b); bvt signed_multiplier(const bvt &op0, const bvt &op1); bvt multiplier(const bvt &op0, const bvt &op1, representationt rep); bvt multiplier_no_overflow( @@ -250,6 +254,7 @@ class bv_utilst bvt wallace_tree(const std::vector &pps); bvt dadda_tree(const std::vector &pps); + bvt comba_column_wise(const std::vector &pps); }; #endif // CPROVER_SOLVERS_FLATTENING_BV_UTILS_H diff --git a/src/solvers/refinement/refine_arithmetic.cpp b/src/solvers/refinement/refine_arithmetic.cpp index 9ffb093e269..b0572694fb0 100644 --- a/src/solvers/refinement/refine_arithmetic.cpp +++ b/src/solvers/refinement/refine_arithmetic.cpp @@ -6,6 +6,16 @@ Author: Daniel Kroening, kroening@kroening.com \*******************************************************************/ +// REFINE_MULT_MODE selects the multiplication refinement strategy: +// 0 = original (full multiplier on first spurious result) +// 1 = assumption-gated (narrow multiplier first, then full) +// 2 = Karatsuba polynomial (k=2, 3 evaluation points, then full) +// 3 = Toom-Cook polynomial (k=n/4, evaluation at 0 and 1, then full) +// Default is 1 (assumption-gated) as it gives the best overall performance. 
+#ifndef REFINE_MULT_MODE +#define REFINE_MULT_MODE 1 +#endif + #include "bv_refinement.h" #include @@ -92,6 +102,90 @@ bvt bv_refinementt::convert_mult(const mult_exprt &expr) literalt res_op1=bv_utils.equal(a.op1_bv, a.result_bv); prop.l_set_to_true(prop.limplies(op0_one, res_op1)); prop.l_set_to_true(prop.limplies(op1_one, res_op0)); + + // Karatsuba-style setup (only for REFINE_MULT_MODE 2) +#if REFINE_MULT_MODE == 2 + // Split operands into two halves and create + // 3 non-deterministic coefficient variables d[0], d[1], d[2]. + // Result = d[0] + d[1] * 2^half + d[2] * 2^(2*half) (mod 2^n) + // Evaluation points constrain them: + // r(0): d[0] = a_lo * b_lo + // r(inf): d[2] = a_hi * b_hi + // r(1): d[0]+d[1]+d[2] = (a_lo+a_hi) * (b_lo+b_hi) + const std::size_t n = a.op0_bv.size(); + const std::size_t half = n / 2; + if(half >= 4) + { + // d_bits: wide enough for half*half product + const std::size_t d_bits = n + 2; + + // Create 3 non-deterministic coefficient variables + bvt d0 = prop.new_variables(d_bits); + bvt d1 = prop.new_variables(d_bits); + bvt d2 = prop.new_variables(d_bits); + + // Constrain: result = d[0] + d[1]< chunk) + { + const std::size_t num_a_chunks = (n + chunk - 1) / chunk; + const std::size_t num_d = 2 * num_a_chunks - 1; + const std::size_t d_bits = 2 * chunk + address_bits(num_a_chunks); + + std::vector d_coeffs; + d_coeffs.reserve(num_d); + for(std::size_t i = 0; i < num_d; ++i) + d_coeffs.push_back(prop.new_variables(d_bits)); + + bvt reconstructed = bv_utils.zeros(n); + for(std::size_t i = 0; i < num_d; ++i) + { + std::size_t shift_amt = i * chunk; + if(shift_amt >= n) + break; + bvt shifted = bv_utils.zero_extension(d_coeffs[i], n); + shifted = bv_utils.shift( + shifted, bv_utilst::shiftt::SHIFT_LEFT, shift_amt); + reconstructed = bv_utils.add(reconstructed, shifted); + } + bv_utils.set_equal(reconstructed, a.result_bv); + + a.op2_bv.clear(); + for(const auto &d : d_coeffs) + a.op2_bv.insert(a.op2_bv.end(), d.begin(), d.end()); + 
a.no_operands = 3; + } +#endif // REFINE_MULT_MODE == 3 } return bv; @@ -278,8 +372,16 @@ void bv_refinementt::check_SAT(approximationt &a) a.expr.operands().size() == 2, "all (un)signedbv typed exprs are binary"); // already full interpretation? - if(a.over_state>0) +#if REFINE_MULT_MODE == 0 + if(a.over_state > 0) + return; +#elif REFINE_MULT_MODE == 1 + if(a.over_state > 1) + return; +#else // REFINE_MULT_MODE == 2 or 3 + if(a.over_state > 3) return; +#endif bv_spect spec(type); bv_arithmetict o0(spec), o1(spec); @@ -301,37 +403,193 @@ void bv_refinementt::check_SAT(approximationt &a) else UNREACHABLE; - if(o0.pack()==a.result_value) // ok + if(o0.pack()==a.result_value && +#if REFINE_MULT_MODE == 0 + true +#elif REFINE_MULT_MODE == 1 + a.over_state > 0 +#else // mode 2 or 3 + a.over_state > 2 +#endif + ) return; - if(a.over_state==0) + auto rep = a.expr.type().id() == ID_signedbv + ? bv_utilst::representationt::SIGNED + : bv_utilst::representationt::UNSIGNED; + + if(false) {} // placeholder for #if chain +#if REFINE_MULT_MODE == 2 + else if(a.expr.id() == ID_mult && a.no_operands == 3 && a.over_state == 0) { - // we give up right away and add the full interpretation - bvt r; - if(a.expr.id()==ID_mult) + // Karatsuba r(0): d[0] = a_lo * b_lo + const std::size_t n = a.op0_bv.size(); + const std::size_t half = n / 2; + const std::size_t d_bits = n + 2; + + bvt d0(a.op2_bv.begin(), a.op2_bv.begin() + d_bits); + bvt a_lo(a.op0_bv.begin(), a.op0_bv.begin() + half); + bvt b_lo(a.op1_bv.begin(), a.op1_bv.begin() + half); + + bvt prod = bv_utils.unsigned_multiplier( + bv_utils.zero_extension(a_lo, d_bits), + bv_utils.zero_extension(b_lo, d_bits)); + bv_utils.set_equal(d0, prod); + } + else if(a.expr.id() == ID_mult && a.no_operands == 3 && a.over_state == 1) + { + // Karatsuba r(inf): d[2] = a_hi * b_hi + const std::size_t n = a.op0_bv.size(); + const std::size_t half = n / 2; + const std::size_t d_bits = n + 2; + + bvt d2(a.op2_bv.begin() + 2 * d_bits, + 
a.op2_bv.begin() + 3 * d_bits); + bvt a_hi(a.op0_bv.begin() + half, a.op0_bv.end()); + bvt b_hi(a.op1_bv.begin() + half, a.op1_bv.end()); + + bvt prod = bv_utils.unsigned_multiplier( + bv_utils.zero_extension(a_hi, d_bits), + bv_utils.zero_extension(b_hi, d_bits)); + bv_utils.set_equal(d2, prod); + } + else if(a.expr.id() == ID_mult && a.no_operands == 3 && a.over_state == 2) + { + // Karatsuba r(1): d[0]+d[1]+d[2] = (a_lo+a_hi) * (b_lo+b_hi) + const std::size_t n = a.op0_bv.size(); + const std::size_t half = n / 2; + const std::size_t d_bits = n + 2; + const std::size_t eval_bits = d_bits + 2; + + bvt d0(a.op2_bv.begin(), a.op2_bv.begin() + d_bits); + bvt d1(a.op2_bv.begin() + d_bits, a.op2_bv.begin() + 2 * d_bits); + bvt d2(a.op2_bv.begin() + 2 * d_bits, a.op2_bv.begin() + 3 * d_bits); + + bvt d_sum = bv_utils.add( + bv_utils.zero_extension(d0, eval_bits), + bv_utils.add( + bv_utils.zero_extension(d1, eval_bits), + bv_utils.zero_extension(d2, eval_bits))); + + bvt a_lo(a.op0_bv.begin(), a.op0_bv.begin() + half); + bvt a_hi(a.op0_bv.begin() + half, a.op0_bv.end()); + bvt b_lo(a.op1_bv.begin(), a.op1_bv.begin() + half); + bvt b_hi(a.op1_bv.begin() + half, a.op1_bv.end()); + + bvt a_sum = bv_utils.add( + bv_utils.zero_extension(a_lo, eval_bits), + bv_utils.zero_extension(a_hi, eval_bits)); + bvt b_sum = bv_utils.add( + bv_utils.zero_extension(b_lo, eval_bits), + bv_utils.zero_extension(b_hi, eval_bits)); + + bvt prod = bv_utils.unsigned_multiplier(a_sum, b_sum); + bv_utils.set_equal( + bv_utils.zero_extension(d_sum, prod.size()), prod); + } +#endif // REFINE_MULT_MODE == 2 +#if REFINE_MULT_MODE == 3 + else if(a.expr.id() == ID_mult && a.no_operands == 3 && a.over_state == 0) + { + // Toom-Cook r(0): d[0] = a[0] * b[0] + const std::size_t n = a.op0_bv.size(); + const std::size_t chunk = 4; + const std::size_t num_a_chunks = (n + chunk - 1) / chunk; + const std::size_t d_bits = 2 * chunk + address_bits(num_a_chunks); + + bvt d0(a.op2_bv.begin(), a.op2_bv.begin() + 
d_bits); + bvt a0(a.op0_bv.begin(), + a.op0_bv.begin() + std::min(chunk, n)); + bvt b0(a.op1_bv.begin(), + a.op1_bv.begin() + std::min(chunk, n)); + + bvt prod = bv_utils.unsigned_multiplier( + bv_utils.zero_extension(a0, d_bits), + bv_utils.zero_extension(b0, d_bits)); + bv_utils.set_equal(d0, prod); + } + else if(a.expr.id() == ID_mult && a.no_operands == 3 && a.over_state == 1) + { + // Toom-Cook r(1): sum(d[i]) = sum(a_chunks) * sum(b_chunks) + const std::size_t n = a.op0_bv.size(); + const std::size_t chunk = 4; + const std::size_t num_a_chunks = (n + chunk - 1) / chunk; + const std::size_t num_d = 2 * num_a_chunks - 1; + const std::size_t d_bits = 2 * chunk + address_bits(num_a_chunks); + const std::size_t eval_bits = d_bits + address_bits(num_d); + + bvt a_sum = bv_utils.zeros(eval_bits); + bvt b_sum = bv_utils.zeros(eval_bits); + for(std::size_t i = 0; i < num_a_chunks; ++i) + { + std::size_t lo = i * chunk; + std::size_t hi = std::min(lo + chunk, n); + bvt ai(a.op0_bv.begin() + lo, a.op0_bv.begin() + hi); + bvt bi(a.op1_bv.begin() + lo, a.op1_bv.begin() + hi); + a_sum = bv_utils.add(a_sum, bv_utils.zero_extension(ai, eval_bits)); + b_sum = bv_utils.add(b_sum, bv_utils.zero_extension(bi, eval_bits)); + } + + bvt d_sum = bv_utils.zeros(eval_bits); + for(std::size_t i = 0; i < num_d; ++i) { - r=bv_utils.multiplier( - a.op0_bv, a.op1_bv, - a.expr.type().id()==ID_signedbv? 
- bv_utilst::representationt::SIGNED: - bv_utilst::representationt::UNSIGNED); + bvt di(a.op2_bv.begin() + i * d_bits, + a.op2_bv.begin() + (i + 1) * d_bits); + d_sum = bv_utils.add(d_sum, bv_utils.zero_extension(di, eval_bits)); } - else if(a.expr.id()==ID_div) + + bvt prod = bv_utils.unsigned_multiplier(a_sum, b_sum); + bv_utils.set_equal( + bv_utils.zero_extension(d_sum, prod.size()), prod); + } +#endif // REFINE_MULT_MODE == 3 +#if REFINE_MULT_MODE == 1 + else if(a.expr.id() == ID_mult && a.over_state == 0) + { + // Assumption-gated: narrow multiplier first + const std::size_t n = a.op0_bv.size(); + const std::size_t k = std::min(std::size_t(4), n); + + if(k < n) { - r=bv_utils.divider( - a.op0_bv, a.op1_bv, - a.expr.type().id()==ID_signedbv? - bv_utilst::representationt::SIGNED: - bv_utilst::representationt::UNSIGNED); + a.over_assumptions.clear(); + bvt a_low(a.op0_bv.begin(), a.op0_bv.begin() + k); + bvt b_low(a.op1_bv.begin(), a.op1_bv.begin() + k); + bvt r_approx = bv_utils.multiplier( + bv_utils.zero_extension(a_low, n), + bv_utils.zero_extension(b_low, n), + rep); + + literalt gate = prop.new_variable(); + // Only constrain the low k bits — leave high bits free. + // This is an over-approximation: the low bits are exact, + // the high bits can take any value. + for(std::size_t i = 0; i < k && i < r_approx.size(); ++i) + { + prop.lcnf(!gate, !a.result_bv[i], r_approx[i]); + prop.lcnf(!gate, a.result_bv[i], !r_approx[i]); + } + a.add_over_assumption(gate); } - else if(a.expr.id()==ID_mod) + else { - r=bv_utils.remainder( - a.op0_bv, a.op1_bv, - a.expr.type().id()==ID_signedbv? 
- bv_utilst::representationt::SIGNED: - bv_utilst::representationt::UNSIGNED); + bv_utils.set_equal( + bv_utils.multiplier(a.op0_bv, a.op1_bv, rep), a.result_bv); } + } +#endif // REFINE_MULT_MODE == 1 + else if(a.over_state <= 3) + { + // Full interpretation (fallback for all modes, also div/mod) + a.over_assumptions.clear(); + + bvt r; + if(a.expr.id() == ID_mult) + r = bv_utils.multiplier(a.op0_bv, a.op1_bv, rep); + else if(a.expr.id() == ID_div) + r = bv_utils.divider(a.op0_bv, a.op1_bv, rep); + else if(a.expr.id() == ID_mod) + r = bv_utils.remainder(a.op0_bv, a.op1_bv, rep); else UNREACHABLE; diff --git a/src/util/simplify_expr_int.cpp b/src/util/simplify_expr_int.cpp index f789bfdfb37..c5a2707159a 100644 --- a/src/util/simplify_expr_int.cpp +++ b/src/util/simplify_expr_int.cpp @@ -1362,6 +1362,176 @@ simplify_exprt::simplify_inequality(const binary_relation_exprt &expr) if(tmp0.type() != tmp1.type()) return unchanged(expr); + // Check for commutative equivalence: a op b == b op a + // for commutative operators (mult, plus, bitand, bitor, bitxor) + if(expr.id() == ID_equal || expr.id() == ID_notequal) + { + if( + tmp0.id() == tmp1.id() && tmp0.operands().size() == 2 && + tmp1.operands().size() == 2 && + (tmp0.id() == ID_mult || tmp0.id() == ID_plus || + tmp0.id() == ID_bitand || tmp0.id() == ID_bitor || + tmp0.id() == ID_bitxor)) + { + if( + tmp0.operands()[0] == tmp1.operands()[1] && + tmp0.operands()[1] == tmp1.operands()[0]) + { + if(expr.id() == ID_equal) + return true_exprt(); + else + return false_exprt(); + } + } + + // Check for distributivity: a * (b + c) == a * b + a * c + // Try expanding multiplication over addition on each side and compare. 
+ auto distribute_mult = [](const exprt &e) -> std::optional { + if(e.id() != ID_mult || e.operands().size() != 2) + return {}; + for(int i = 0; i < 2; ++i) + { + const exprt &factor = e.operands()[i]; + const exprt &sum = e.operands()[1 - i]; + if(sum.id() == ID_plus && sum.operands().size() == 2) + { + // a * (b + c) -> a*b + a*c + mult_exprt prod0(factor, sum.operands()[0]); + prod0.type() = e.type(); + mult_exprt prod1(factor, sum.operands()[1]); + prod1.type() = e.type(); + plus_exprt result(std::move(prod0), std::move(prod1)); + result.type() = e.type(); + return std::move(result); + } + } + return {}; + }; + + // Try expanding LHS and comparing with RHS, and vice versa + auto expanded0 = distribute_mult(tmp0); + auto expanded1 = distribute_mult(tmp1); + + // Compare expanded forms, accounting for commutativity of + and * + auto comm_equal = [](const exprt &a, const exprt &b) -> bool { + if(a == b) + return true; + // Check commutativity of the top-level operator + if( + a.id() == b.id() && a.operands().size() == 2 && + b.operands().size() == 2 && + (a.id() == ID_plus || a.id() == ID_mult)) + { + if( + a.operands()[0] == b.operands()[1] && + a.operands()[1] == b.operands()[0]) + return true; + } + return false; + }; + + // Also check commutativity within sub-expressions + auto deep_comm_equal = [&comm_equal](const exprt &a, const exprt &b) { + if(comm_equal(a, b)) + return true; + // If both are plus with 2 operands, check if individual products + // match with commutativity + if( + a.id() == ID_plus && b.id() == ID_plus && + a.operands().size() == 2 && b.operands().size() == 2) + { + auto prod_eq = [](const exprt &p, const exprt &q) { + if(p == q) + return true; + if( + p.id() == ID_mult && q.id() == ID_mult && + p.operands().size() == 2 && q.operands().size() == 2) + { + return p.operands()[0] == q.operands()[1] && + p.operands()[1] == q.operands()[0]; + } + return false; + }; + return (prod_eq(a.operands()[0], b.operands()[0]) && + prod_eq(a.operands()[1], 
b.operands()[1])) || + (prod_eq(a.operands()[0], b.operands()[1]) && + prod_eq(a.operands()[1], b.operands()[0])); + } + return false; + }; + + bool is_equal = false; + if(expanded0.has_value() && deep_comm_equal(*expanded0, tmp1)) + is_equal = true; + else if(expanded1.has_value() && deep_comm_equal(tmp0, *expanded1)) + is_equal = true; + else if( + expanded0.has_value() && expanded1.has_value() && + deep_comm_equal(*expanded0, *expanded1)) + is_equal = true; + + if(is_equal) + { + if(expr.id() == ID_equal) + return true_exprt(); + else + return false_exprt(); + } + + // Check for associative+commutative equivalence. + // Flatten nested applications of the same operator and compare + // the multisets of leaves. E.g., (a*b)*c == a*(b*c) both flatten + // to {a, b, c} under *. + if( + tmp0.operands().size() == 2 && tmp1.operands().size() == 2 && + (tmp0.id() == ID_mult || tmp0.id() == ID_plus || + tmp0.id() == ID_bitand || tmp0.id() == ID_bitor || + tmp0.id() == ID_bitxor) && + (tmp1.id() == tmp0.id() || + // one side has the op at top, other has it nested + (tmp0.operands().size() == 2 && tmp1.operands().size() == 2))) + { + const irep_idt &op_id = tmp0.id(); + // Only proceed if both sides use the same top-level operator + // or one side has it nested + auto flatten = [&op_id](const exprt &e, std::vector &leaves) { + std::vector worklist = {&e}; + while(!worklist.empty()) + { + const exprt *cur = worklist.back(); + worklist.pop_back(); + if(cur->id() == op_id && cur->operands().size() == 2) + { + worklist.push_back(&cur->operands()[0]); + worklist.push_back(&cur->operands()[1]); + } + else + leaves.push_back(*cur); + } + }; + + if(tmp0.id() == op_id && tmp1.id() == op_id) + { + std::vector leaves0, leaves1; + flatten(tmp0, leaves0); + flatten(tmp1, leaves1); + + if(leaves0.size() == leaves1.size() && leaves0.size() <= 4) + { + std::sort(leaves0.begin(), leaves0.end()); + std::sort(leaves1.begin(), leaves1.end()); + if(leaves0 == leaves1) + { + if(expr.id() == 
ID_equal) + return true_exprt(); + else + return false_exprt(); + } + } + } + } + } + // if rhs is ID_if (and lhs is not), swap operands for == and != if((expr.id()==ID_equal || expr.id()==ID_notequal) && tmp0.id()!=ID_if &&