
[Feature] Use wraparound math for intermediate vals when desired #140

Open
matth2k opened this issue Mar 13, 2024 · 0 comments
Labels
enhancement New feature or request

Comments


matth2k commented Mar 13, 2024

Is your feature request related to a problem? Please describe.
Yes. Arbitrary-precision types for intermediate values are causing hitches in the AMC flow. For example, in #121 the typecasting introduced extra buffers and copy loops in the IR, which increased design latency. In this latest case, I am trying to evaluate a polynomial, and the majority of the IR becomes dedicated to extending numbers so that the arith ops are wide enough to hold the overflow bits. It would be nice if we could toggle this behavior and simply use wraparound arithmetic when the lval and rval types match.

Here is an example of the current behavior:

import allo
from allo.ir.types import int32


def test_tay_approximation(printResEst=False):
    N = 16

    def kernel_approx(x: int32[N]) -> int32[N]:
        y: int32[N] = 0
        for i in allo.grid(N):
            y[i] = (
                x[i] * 10000
                - (x[i] * x[i] * x[i] * 17) * 10
                - (x[i] * x[i] * x[i] * x[i] * x[i] * 11)
            )
        return y

    s = allo.customize(kernel_approx)
    print(s.module)

The output is

module {
  func.func @kernel_approx(%arg0: memref<16xi32>) -> memref<16xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "y"} : memref<16xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<16xi32>)
    affine.for %arg1 = 0 to 16 {
      %0 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %1 = arith.extsi %0 : i32 to i64
      %c10000_i32 = arith.constant 10000 : i32
      %2 = arith.extsi %c10000_i32 : i32 to i64
      %3 = arith.muli %1, %2 : i64
      %4 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %5 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %6 = arith.extsi %4 : i32 to i64
      %7 = arith.extsi %5 : i32 to i64
      %8 = arith.muli %6, %7 : i64
      %9 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %10 = arith.extsi %8 : i64 to i96
      %11 = arith.extsi %9 : i32 to i96
      %12 = arith.muli %10, %11 : i96
      %13 = arith.extsi %12 : i96 to i128
      %c17_i32 = arith.constant 17 : i32
      %14 = arith.extsi %c17_i32 : i32 to i128
      %15 = arith.muli %13, %14 : i128
      %16 = arith.extsi %15 : i128 to i160
      %c10_i32 = arith.constant 10 : i32
      %17 = arith.extsi %c10_i32 : i32 to i160
      %18 = arith.muli %16, %17 : i160
      %19 = arith.extsi %3 : i64 to i161
      %20 = arith.extsi %18 : i160 to i161
      %21 = arith.subi %19, %20 : i161
      %22 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %23 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %24 = arith.extsi %22 : i32 to i64
      %25 = arith.extsi %23 : i32 to i64
      %26 = arith.muli %24, %25 : i64
      %27 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %28 = arith.extsi %26 : i64 to i96
      %29 = arith.extsi %27 : i32 to i96
      %30 = arith.muli %28, %29 : i96
      %31 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %32 = arith.extsi %30 : i96 to i128
      %33 = arith.extsi %31 : i32 to i128
      %34 = arith.muli %32, %33 : i128
      %35 = affine.load %arg0[%arg1] {from = "x"} : memref<16xi32>
      %36 = arith.extsi %34 : i128 to i160
      %37 = arith.extsi %35 : i32 to i160
      %38 = arith.muli %36, %37 : i160
      %39 = arith.extsi %38 : i160 to i192
      %c11_i32 = arith.constant 11 : i32
      %40 = arith.extsi %c11_i32 : i32 to i192
      %41 = arith.muli %39, %40 : i192
      %42 = arith.extsi %21 : i161 to i193
      %43 = arith.extsi %41 : i192 to i193
      %44 = arith.subi %42, %43 : i193
      %45 = arith.trunci %44 : i193 to i32
      affine.store %45, %alloc[%arg1] {to = "y"} : memref<16xi32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<16xi32>
  }
}
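The widening chain in the dump follows from requiring every intermediate to be exact: a signed multiply of an `a`-bit and a `b`-bit value needs `a + b` bits, and a signed subtract needs `max(a, b) + 1`. A quick sketch (helper names are mine) reproduces the i193 width of the final `arith.subi` above:

```python
# Model of the (assumed) exactness rules behind the extsi chain:
# signed mul needs a+b bits, signed sub needs max(a,b)+1 bits.
def mul_w(a, b):
    return a + b

def sub_w(a, b):
    return max(a, b) + 1

W = 32  # x[i] and every constant start as i32

term1 = mul_w(W, W)               # x * 10000              -> i64
cube = mul_w(mul_w(W, W), W)      # x * x * x              -> i96
term2 = mul_w(mul_w(cube, W), W)  # (cube * 17) * 10       -> i160
acc = sub_w(term1, term2)         # term1 - term2          -> i161
quint = mul_w(mul_w(cube, W), W)  # x^5                    -> i160
term3 = mul_w(quint, W)           # quint * 11             -> i192
result = sub_w(acc, term3)        # final subtract         -> i193

print(result)  # width before the single trunci back to i32
```

Every width in the sketch matches the dump (i64, i96, i128, i160, i161, i192, i193), so the growth is purely a consequence of keeping intermediates exact right up to the final `arith.trunci`.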

Describe the solution you'd like
In the end, I would expect Vivado to optimize away many of these wires. However, toggling this behavior has practical benefits for us: it reduces the number of nets in waveform dumps during debugging, and Calyx has trouble supporting ap types wider than 64 bits.
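Wrapping is also semantics-preserving here: since add, sub, and mul commute with truncation modulo 2^32, wrapping every intermediate to i32 yields the same final stored value as the wide computation truncated once at the end. A quick check in plain Python (`wrap32` is my helper, not an Allo API):

```python
MASK32 = (1 << 32) - 1

def wrap32(v):
    """Reduce an arbitrary-precision int to its signed 32-bit value."""
    v &= MASK32
    return v - (1 << 32) if v >= (1 << 31) else v

def poly_wide(x):
    # Arbitrary-precision intermediates, truncated only at the final
    # store (what the current IR computes).
    return wrap32(x * 10000 - (x * x * x * 17) * 10
                  - (x * x * x * x * x * 11))

def poly_wrapped(x):
    # Every intermediate op wraps to i32 (the requested behavior).
    cube = wrap32(wrap32(x * x) * x)
    t1 = wrap32(x * 10000)
    t2 = wrap32(wrap32(cube * 17) * 10)
    t3 = wrap32(wrap32(wrap32(cube * x) * x) * 11)
    return wrap32(wrap32(t1 - t2) - t3)

for x in (0, 1, -1, 123, -45678, 2**31 - 1, -(2**31)):
    assert poly_wide(x) == poly_wrapped(x)
print("ok")
```

So when the lval and rval types already match, keeping everything at i32 cannot change the observable result of this kernel.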

@matth2k matth2k added the enhancement New feature or request label Mar 13, 2024