[BUG] FFN memory usage not optimized due to difficulty of fusing linalg ops #90

Open
mmengjiadai opened this issue Oct 12, 2023 · 0 comments
Labels: bug (Something isn't working)

@mmengjiadai (Contributor) commented:

Describe the bug
A linear layer or activation function should allocate at most one memory block, and there are opportunities to reuse blocks between layers. However, the current builder allocates a separate memory block for each operation such as transpose or broadcast. This overhead could presumably be eliminated by fusing linalg.transpose with linalg.fill and linalg.matmul. I tried rewriting mlp.py using for loops and succeeded in using only 2 allocations.
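
For intuition, here is a rough NumPy sketch (not taken from mlp.py; the exact op sequence is an assumption inferred from the allocation table below) of the same FFN written op by op. Each named intermediate stands for one buffer that a one-allocation-per-linalg-op lowering would materialize; the transpose and broadcast temporaries are the ones fusion should be able to eliminate.

import numpy as np

def ffn_op_by_op(A, w1, b1, w2, b2):
    # Illustration only: each intermediate maps to one buffer in an unfused,
    # one-allocation-per-op lowering (roughly one per allocation reported below).
    t0 = A.T                            # transpose      -> avoidable temporary
    t1 = t0 @ w1                        # fill + matmul  -> layer-1 output
    t2 = np.broadcast_to(b1, t1.shape)  # broadcast      -> avoidable temporary
    t3 = t1 + t2                        # bias add       -> another buffer
    t4 = t3.T                           # transpose      -> avoidable temporary
    t5 = t4 @ w2                        # fill + matmul  -> layer-2 output
    t6 = np.broadcast_to(b2, t5.shape)  # broadcast      -> avoidable temporary
    t7 = t5 + t6                        # bias add       -> another buffer
    return np.maximum(t7, 0.0)          # ReLU           -> output buffer

Fusing the transposes and broadcasts into the matmul and add loops removes those temporaries, which is exactly what the hand-written loop version under "Expected behavior" does.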

To Reproduce
Run mlp.py with monitor_memory and without enable_tensor. The total number of allocations is ten.
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| name | shape | dtype | mem(bits) | BRAM(18K) | store counts | data storage |
+===========+==========+=========+=============+=============+================+============================================================================+
| %alloc | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %0[%arg1, %arg2] : memref<30x30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_3 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %8 = arith.addf %6, %7 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_10 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %1[%arg2] : memref<30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_14 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %6 = arith.addf %4, %5 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_21 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %2[%arg1, %arg2] : memref<30x30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_28 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %8 = arith.addf %6, %7 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_35 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %4 = memref.load %3[%arg2] : memref<30xf32> |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_39 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %6 = arith.addf %4, %5 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_46 | [30, 30] | f32 | 28800 | 1.6384e+06 | 1 | %6 = arith.maxf %4, %5 : f32 |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| %alloc_50 | [30, 30] | f32 | 28800 | 1.6384e+06 | 0 | |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
| Total(10) | | | 288000 | 1.6384e+07 | | *data storage: data stored into an allocated memory. Doesn't include init. |
+-----------+----------+---------+-------------+-------------+----------------+----------------------------------------------------------------------------+
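
For reference, a hedged sketch of the invocation behind the table above; mlp_kernel is a placeholder name for the kernel defined in mlp.py, and monitor_memory_usage / intermediate_module are taken from the test shown under "Expected behavior" below.

import allo

s = allo.customize(mlp_kernel)  # mlp_kernel: placeholder for the kernel in mlp.py; default flow, i.e. without enable_tensor
mod = s.build()
print(monitor_memory_usage(mod.intermediate_module))  # prints the allocation table above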

Expected behavior
If the FFN is rewritten with explicit for loops:

import allo
import numpy as np
from allo.ir.types import float32
# monitor_memory_usage is the memory-reporting utility from this repo
# (the same one used to produce the table above).


def test_for_FFN():
    w1_const = np.float32(np.random.uniform(size=(30, 30)))
    w2_const = np.float32(np.random.uniform(size=(30, 30)))
    b1_const = np.float32(np.random.uniform(size=(30)))
    b2_const = np.float32(np.random.uniform(size=(30)))

    def kernel(A: float32[30, 30]) -> float32[30, 30]:
        w1: float32[30, 30] = w1_const
        w2: float32[30, 30] = w2_const
        b1: float32[30] = b1_const
        b2: float32[30] = b2_const
        # First linear layer: B = A^T @ w1 + b1
        B: float32[30, 30] = 0
        for i, j in allo.grid(30, 30):
            for k in allo.reduction(30):
                B[i, j] += A[k, i] * w1[k, j]
            B[i, j] += b1[j]  # add the bias once per output element
        # Second linear layer: C = B^T @ w2 + b2
        C: float32[30, 30] = 0
        for i, j in allo.grid(30, 30):
            for k in allo.reduction(30):
                C[i, j] += B[k, i] * w2[k, j]
            C[i, j] += b2[j]  # add the bias once per output element
        # ReLU activation
        for i, j in allo.grid(30, 30):
            C[i, j] = allo.max(C[i, j], 0.0)
        return C

    s = allo.customize(kernel, verbose=True)
    print(s.module)
    mod = s.build()
    monitor_memory_table = monitor_memory_usage(mod.intermediate_module)
    print(monitor_memory_table)

Only 2 allocations are reported, which shows that the current builder is far from optimal.
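
As a sanity check (not part of the original report), a few lines that could be appended inside test_for_FFN to compare the loop kernel against a plain NumPy reference; this assumes the transposed-operand convention of the loops above and that the built allo module can be called directly on a NumPy array, as in other allo tests.

    # Reference: relu((A^T @ w1 + b1)^T @ w2 + b2), matching the loop nest above.
    np_A = np.float32(np.random.uniform(size=(30, 30)))
    ref = np.maximum((np_A.T @ w1_const + b1_const).T @ w2_const + b2_const, 0.0)
    np.testing.assert_allclose(mod(np_A), ref, rtol=1e-4, atol=1e-4)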
