
[ hgemm ] Add experimental kernel #2693

Merged

Conversation


@skykongkong8 skykongkong8 commented Aug 2, 2024

  • According to a recent paper, accumulating up to 64 ~ 128 values along the K-direction is fine.
  • Since both the conventional error metric and the newly introduced metric (max component relative error) remain fine, introduce an experimental kernel that is faster but less accurate (see the sketch after this list).
  • The build option -Dhgemm-experimental-kernel=true enables this kernel for Android builds.
  • The performance of this kernel is shown in [ hgemm ] Improve transposed B matrix computation and matrix padding #2655.
  • The experimental kernel build includes the experimental kernel and excludes the previous one -> no newly added #ifdefs in the files.
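
For context, a minimal sketch of how the "max component relative error" metric mentioned above could be computed; the function name, the reference/output naming, and the epsilon guard are illustrative assumptions, not the actual nntrainer code.

#include <cmath>
#include <cstddef>

// Largest |out[i] - ref[i]| / |ref[i]| over all output components, with a small
// epsilon guard so exact zeros in the reference do not divide by zero.
float max_component_relative_error(const float *ref, const float *out, size_t n) {
  float max_err = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    float denom = std::fabs(ref[i]) > 1e-12f ? std::fabs(ref[i]) : 1e-12f;
    float err = std::fabs(out[i] - ref[i]) / denom;
    if (err > max_err)
      max_err = err;
  }
  return max_err;
}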

- Forcibly adding zero-padding made the small-dimension indexing quite clumsy and redundant.
- Implement an explicit hgemm small function to cover the M < 8, N < 16, K < 16 case (see the sketch below).
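
A minimal sketch of the small-dimension fallback idea described above: when M < 8, N < 16, or K < 16, skip packing/padding and run a plain triple loop, accumulating in fp32 for accuracy. The function and parameter names are illustrative, not the actual nntrainer implementation; it assumes an AArch64 toolchain where __fp16 is available.

// Naive row-major hgemm fallback for tiny dimensions: C = alpha * A * B + beta * C,
// where A is MxK, B is KxN, and C is MxN, all too small to feed the 8x16 kernels.
static void hgemm_small_sketch(const __fp16 *A, const __fp16 *B, float *C,
                               unsigned int M, unsigned int N, unsigned int K,
                               float alpha, float beta) {
  for (unsigned int m = 0; m < M; ++m) {
    for (unsigned int n = 0; n < N; ++n) {
      float acc = 0.0f; // accumulate in fp32 to bound rounding error
      for (unsigned int k = 0; k < K; ++k)
        acc += static_cast<float>(A[m * K + k]) * static_cast<float>(B[k * N + n]);
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}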

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>

taos-ci commented Aug 2, 2024

📝 TAOS-CI Version: 1.5.20200925. Thank you for submitting PR #2693. Please follow the 1 commit / 1 PR (one commit per PR) policy to get comments from reviewers quickly. Your PR must pass all verification processes of cibot before the review process by reviewers can start. If you are a new member joining this project, please read the manuals in the documentation folder and the wiki page. To monitor the progress of your PR in more detail, visit http://ci.nnstreamer.ai/.

@skykongkong8 skykongkong8 changed the title from "[ hgemm ] Addd experimental kernel" to "[ hgemm ] Add experimental kernel" on Aug 2, 2024
@taos-ci taos-ci left a comment

@skykongkong8, 💯 All CI checkers are successfully verified. Thanks.

@skykongkong8 skykongkong8 force-pushed the pr/hgemm/experimental_kernel branch 3 times, most recently from 584e79f to e0d1bb4 Compare August 2, 2024 04:05
@taos-ci taos-ci left a comment

@skykongkong8, 💯 All CI checkers are successfully verified. Thanks.

@djeong20 djeong20 left a comment

Overall, awesome work! One thing I want to point out is the newly added macros.

  1. There's room for the code to be reduced.
  2. I think four macros can be combined into one.

How about replacing them as follows?

// N would be 1, 4, 8, and 16
#define KERNEL_8x16_ACC(N)                               \
  do {                                                   \
    for (int i = 0; i < N; ++i) {                        \
      va0 = vld1q_f16(a + 8 * i);                        \
      vb1 = vld1q_f16(b + 16 * i);                       \
      vb2 = vld1q_f16(b + 16 * i + 8);                   \
      v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
      v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
      v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
      v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
      v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
      v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
      v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
      v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
      v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
      v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
      v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
      v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
      v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
      v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
      v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
      v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
    }                                                    \
                                                         \
    l += N;                                              \
    __builtin_prefetch(b + 16 * N, 0, 3);                \
    __builtin_prefetch(a + 8 * N, 0, 3);                 \
    b += 16 * N;                                         \
    a += 8 * N;                                          \
  } while (0);

If this works, it can reduce about 600 lines of code. Would this be valid?
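
For reference, a hypothetical call-site sketch of how the single parameterized macro could replace the four fixed ones; the dispatch structure and the surrounding variables (a, b, l, K, and the v* accumulators) are assumed from the suggestion above, not taken from the actual hgemm code. Note the macro already ends in a semicolon because of the trailing `} while (0);`, so none is added after each invocation.

/* Consume the K dimension in the largest chunk the remaining length allows. */
for (l = 0; l < K;) {
  if (K - l >= 16)
    KERNEL_8x16_ACC(16)
  else if (K - l >= 8)
    KERNEL_8x16_ACC(8)
  else if (K - l >= 4)
    KERNEL_8x16_ACC(4)
  else
    KERNEL_8x16_ACC(1)
}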

Review thread on nntrainer/tensor/hgemm/hgemm.h (resolved).
} while (0)

// 1. Partial sum 2048 digits
#define KERNEL_8x16_ACC16() \
@djeong20 djeong20 Aug 6, 2024


I think this entire macro can be reduced.

Suggested change
#define KERNEL_8x16_ACC16() \
for (int i = 0; i < 16; ++i) {
va0 = vld1q_f16(a + 8 * n);
vb1 = vld1q_f16(b + 8 * n * 2);
vb2 = vld1q_f16(b + 8 * (n * 2 + 1));
v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);
v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);
v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);
v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);
v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);
v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);
v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);
v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);
v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);
v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);
v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);
v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);
v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);
v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5);
v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6);
v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7);
}

} while (0)

// 2. Partial sum 1024 digits
#define KERNEL_8x16_ACC8() \

same here

Review thread on nntrainer/tensor/hgemm/hgemm_util.cpp (resolved).
Review thread on tools/package_android.sh (outdated, resolved).

skykongkong8 commented Aug 7, 2024

> Overall, awesome work! One thing I want to point out is the newly added macros. There's room for the code to be reduced, and I think the four macros can be combined into one [the KERNEL_8x16_ACC(N) suggestion quoted in full above]. If this works, it can reduce about 600 lines of code. Would this be valid?

According to my research, this would be invalid. Many kernels are deliberately written as hard-coded lines rather than loops; some report that looping can change kernel efficiency from run to run, so I coded it this way.
Looking at the GEMM kernels in XNNPACK or OpenBLAS may help clarify this.

- According to a recent paper, accumulating up to 64 ~ 128 values along the K-direction is fine.
- Since both the conventional error metric and the newly introduced metric (max component relative error) remain fine, introduce an experimental kernel.
- The build option -Dhgemm-experimental-kernel=true enables this kernel for Android builds.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
- Add missing doxygen tags: transpose boolean params
- Error message: emit an error when trying to use the full-fp16 kernel with the experimental kernel build (see the sketch after this commit message)

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
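
A purely illustrative sketch of the guard described in the commit above: reject the full-fp16 path when the experimental (reduced-accuracy) kernels are the ones built in. The function name, flag names, and error text are assumptions, not the actual nntrainer code.

#include <stdexcept>

// Emit an explicit error instead of silently running a full-fp16 computation that
// the experimental kernel build does not support.
void check_full_fp16_supported(bool experimental_kernel_build, bool full_fp16) {
  if (experimental_kernel_build && full_fp16)
    throw std::invalid_argument(
      "full-fp16 hgemm is not supported with the experimental kernel build");
}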
@skykongkong8 skykongkong8 force-pushed the pr/hgemm/experimental_kernel branch 2 times, most recently from 54f059e to 8a57bfb Compare August 7, 2024 01:33

taos-ci commented Aug 7, 2024

:octocat: cibot: @skykongkong8, nntrainer/tensor/hgemm/hgemm_util.h does not include Doxygen tags such as @file @brief @author @bug. You must include the Doxygen tags in the source code. Please refer to a Doxygen manual at http://github.com/nnstreamer/TAOS-CI/blob/main/ci/doc/doxygen-documentation.md
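
For reference, a sketch of the kind of header block the bot is asking for, using the tags it lists; the brief text below is illustrative, and the author field is taken from the sign-off in this PR.

/**
 * @file   hgemm_util.h
 * @brief  Utility functions shared by half-precision GEMM kernels
 * @author skykongkong8 <ss.kong@samsung.com>
 * @bug    No known bugs except for NYI items
 */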

- get_prev_mltpl_of_2p_n is frequently used in many hgemm kernels (a sketch of its likely behavior follows this commit message).

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
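
On the helper named in the commit above: a guess at its behavior based only on its name (the actual nntrainer implementation may differ). It appears to return the previous multiple of 2^n not exceeding x, which is handy for splitting a dimension into a SIMD-friendly main loop and a leftover tail.

// Sketch only: the largest multiple of 2^n that is <= x, i.e. x with its low n bits cleared.
static inline unsigned int get_prev_mltpl_of_2p_n_sketch(unsigned int x, unsigned int n) {
  return (x >> n) << n;
}

For example, with x = 1000 and n = 4 (multiples of 16) this yields 992, so a kernel can process 992 elements in 16-wide steps and handle the remaining 8 separately.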

djeong20 commented Aug 7, 2024

> According to my research, this would be invalid. Many kernels are deliberately written as hard-coded lines rather than loops; some report that looping can change kernel efficiency from run to run, so I coded it this way. Looking at the GEMM kernels in XNNPACK or OpenBLAS may help clarify this.

I don't quite understand why hard-coded kernels would make a difference. For this particular kernel, the hard-coded operations are exactly the same as the operations processed by a loop. Could you explain why it would make kernel efficiency different?


skykongkong8 commented Aug 7, 2024

> According to my research, this would be invalid. Many kernels are deliberately written as hard-coded lines rather than loops; some report that looping can change kernel efficiency from run to run, so I coded it this way. Looking at the GEMM kernels in XNNPACK or OpenBLAS may help clarify this.
>
> I don't quite understand why hard-coded kernels would make a difference. For this particular kernel, the hard-coded operations are exactly the same as the operations processed by a loop. Could you explain why it would make kernel efficiency different?

  1. It shows a difference in the latency measured by the unit test.
  2. A fixed-size macro helps the compiler optimize memory access when walking through the GEMM matrix memory area (a toy illustration follows this list).
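
A toy illustration of point 2 above (not the hgemm kernel itself): with a compile-time-constant trip count the compiler can fully unroll and schedule the loads and FMAs, whereas a runtime-variable count usually keeps the loop, and its exit branch, in the hot path. Names are illustrative.

// Trip count is a literal constant: the compiler is free to fully unroll this.
#define ACC4_FIXED(dst, a, b)       \
  do {                              \
    for (int i = 0; i < 4; ++i)     \
      (dst) += (a)[i] * (b)[i];     \
  } while (0)

// Trip count is only known at run time: the loop (and its branch) stays.
static inline void acc_n_runtime(float *dst, const float *a, const float *b, int n) {
  for (int i = 0; i < n; ++i)
    *dst += a[i] * b[i];
}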

And many libraries actually do things like this. For example:

.macro KERNEL4x4_I
	pld	[ BO , #B_PRE ]
	fldd	d8 , [ BO ]
	fldd	d0 , [ AO ]
	pld	[ AO , #A_PRE ]

	fldd	d1 , [ AO, #8 ]
	fmuld	d16  , d0,  d8
	fldd	d2 , [ AO, #16 ]
	fmuld	d17  , d1,  d8
	fldd	d3 , [ AO, #24 ]
	fmuld	d18  , d2,  d8
	fldd	d9 , [ BO, #8 ]
	fmuld	d19  , d3,  d8

	fldd	d10, [ BO, #16 ]
	fmuld	d20  , d0,  d9
	fldd	d11, [ BO, #24 ]
	fmuld	d21  , d1,  d9
	add	BO , BO, #32
	add	AO , AO, #32
	fmuld	d22  , d2,  d9

	pld	[ BO , #B_PRE ]
	fldd	d12, [ BO ]
	fmuld	d23  , d3,  d9

	pld	[ AO , #A_PRE ]
	fldd	d4 , [ AO, #0 ]
	fmuld	d24  , d0,  d10
	fldd	d5 , [ AO, #8 ]
	fmuld	d25  , d1,  d10
	fldd	d6 , [ AO, #16 ]
	fmuld	d26  , d2,  d10
	fldd	d7 , [ AO, #24 ]
	fmuld	d27  , d3,  d10

	fldd	d13, [ BO, #8 ]
	fmuld	d28  , d0,  d11
	fldd	d14, [ BO, #16 ]
	fmuld	d29  , d1,  d11
	fldd	d15, [ BO, #24 ]
	fmuld	d30  , d2,  d11
	fmuld	d31  , d3,  d11

.endm



.macro KERNEL4x4_M2

	fmacd	d16  , d4,  d12
	pld	[ AO , #A_PRE+32 ]
	fmacd	d17  , d5,  d12
	fldd	d0 , [ AO , #32 ]
	fmacd	d18  , d6,  d12
	pld	[ BO , #B_PRE+32 ]
	fmacd	d19  , d7,  d12

	fldd	d8 , [ BO , #32 ]
	fmacd	d20  , d4,  d13
	fldd	d1 , [ AO, #40 ]
	fmacd	d21  , d5,  d13
	fldd	d2 , [ AO, #48 ]
	fmacd	d22  , d6,  d13
	fldd	d3 , [ AO, #56 ]
	fmacd	d23  , d7,  d13

	fmacd	d24  , d4,  d14
	fmacd	d25  , d5,  d14
	fldd	d9 , [ BO, #40 ]
	fmacd	d26  , d6,  d14
	fldd	d10, [ BO, #48 ]
	fmacd	d27  , d7,  d14

	fldd	d11, [ BO, #56 ]
	fmacd	d28  , d4,  d15
	fmacd	d29  , d5,  d15
	add	AO , AO, #64
	fmacd	d30  , d6,  d15
	add	BO , BO, #64
	fmacd	d31  , d7,  d15

.endm

They use fixed-sized kernels for optimal performance.

@taos-ci taos-ci left a comment

@skykongkong8, 💯 All CI checkers are successfully verified. Thanks.


djeong20 commented Aug 7, 2024

I think this is a slightly different story. KERNEL4x4_I and KERNEL4x4_M2 both perform multiplication, but the logic differs: KERNEL4x4_M2 performs the floating multiply-add with fmacd, while KERNEL4x4_I does it manually with fldd and fmuld. That's why I think they need to be defined separately even though they share the same 4x4 size.
However, my point is: why can't the same kernel be reused when the logic is shared?
In XNNPACK, for instance, the same f16_gemm() is shared by passing M and N.

static void f16_gemm_1x16__asm_aarch64_neonfp16arith_ld32(benchmark::State& state, const char* net) {
    f16_gemm(state,
      xnn_f16_gemm_minmax_ukernel_1x16__asm_aarch64_neonfp16arith_ld32,
      xnn_init_f16_minmax_fp16arith_params,
      /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1,
      benchmark::utils::CheckNEONFP16ARITH);
  }
static void f16_gemm_4x16__asm_aarch64_neonfp16arith_ld32(benchmark::State& state, const char* net) {
    f16_gemm(state,
      xnn_f16_gemm_minmax_ukernel_4x16__asm_aarch64_neonfp16arith_ld32,
      xnn_init_f16_minmax_fp16arith_params,
      /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1,
      benchmark::utils::CheckNEONFP16ARITH);
  }
  static void f16_gemm_6x16__asm_aarch64_neonfp16arith_ld32(benchmark::State& state, const char* net) {
    f16_gemm(state,
      xnn_f16_gemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_ld32,
      xnn_init_f16_minmax_fp16arith_params,
      /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1,
      benchmark::utils::CheckNEONFP16ARITH);
  }

If you look inside f16_gemm(), it uses M x N as variables in the function while using the same GEMM kernel.

static void f16_gemm(benchmark::State& state,
...
  for (uint32_t m = 0; m < mc; m += mr) {
      const uint32_t mb = min(mc - m, mr);
      for (uint32_t n = 0; n < nc; n += nr) {
        const uint32_t nb = min(nc - n, nr);
        gemm(
          mb, nb, kc * sizeof(uint16_t),
          a.data() + m * kc, kc * sizeof(uint16_t),
          w.data() + (nc_stride * buffer_index + n) * (kc_stride + 1),
          c.data() + (mc * buffer_index + m) * nc + n, nc * sizeof(uint16_t), nr * sizeof(uint16_t),
          &params);
      }
  }


djeong20 commented Aug 7, 2024

Also, if using a for loop impacts performance, why not take i as an input?

// N would be 1, 4, 8, and 16
#define KERNEL_8x16_ACC(N, i)                               \
  do {                                                   \
      va0 = vld1q_f16(a + 8 * i);                        \
      vb1 = vld1q_f16(b + 16 * i);                       \
      vb2 = vld1q_f16(b + 16 * i + 8);                   \
      v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
      v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
      v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
      v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
      v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
      v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
      v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
      v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
      v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
      v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
      v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
      v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
      v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
      v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
      v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
      v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
                                                         \
                                                         \
    l += N;                                              \
    __builtin_prefetch(b + 16 * N, 0, 3);                \
    __builtin_prefetch(a + 8 * N, 0, 3);                 \
    b += 16 * N;                                         \
    a += 8 * N;                                          \
  } while (0);
KERNEL_8x16_ACC(N, 0)
KERNEL_8x16_ACC(N, 1)
/// and so on


skykongkong8 commented Aug 7, 2024

> Also, if using a for loop impacts performance, why not take i as an input? [the KERNEL_8x16_ACC(N, i) suggestion quoted in full above]

This is a good point! I will try it locally to check whether it hinders kernel performance.


skykongkong8 commented Aug 7, 2024

But still, regarding the kernel design here as I observe it (this wasn't about the difference between KERNEL4x4_I and KERNEL4x4_M2): the KERNEL4x4_M2 macro quoted above also uses fixed-size K-direction accumulation.
I mean, there are mainly two reasons why we have explicit kernels like:

#define KERNEL_8x16_ACC16()
...
#define KERNEL_8x16_ACC8()
...
#define KERNEL_8x16_ACC4()
...
#define KERNEL_8x16_ACC1()
...
  1. To regulate the proportion of __fp16 accumulation inside vfma versus float accumulation into C. On this point you are right: it can be regulated more easily with your suggestion (see the sketch after this list).
  2. But in the kernel above, isn't it coding the K-direction accumulation four steps at a time? Maybe we should keep a certain amount of fixed-size kernels. Clearly, we need to run a latency check for that.
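
A minimal sketch of point 1, assuming the intent described in this thread (fp16 partial sums flushed into the fp32 output every `acc` steps along K); the function, the variable names, and the ACC16/ACC1 stand-in dispatch are illustrative, not the actual kernel.

#include <arm_neon.h>

// Larger `acc` means more half-precision vfma per fp32 flush: faster, less accurate.
void acc_tradeoff_sketch(const __fp16 *a, const __fp16 *b, float *c8, unsigned int K) {
  for (unsigned int k = 0; k < K;) {
    float16x8_t v0_7 = vdupq_n_f16(0);         // fp16 partial-sum register
    unsigned int acc = (K - k >= 16) ? 16 : 1; // stands in for the ACC16 / ACC1 choice
    for (unsigned int i = 0; i < acc; ++i, ++k) {
      float16x8_t va = vld1q_dup_f16(a + k);   // illustrative 1x8 loads
      float16x8_t vb = vld1q_f16(b + 8 * k);
      v0_7 = vfmaq_f16(v0_7, vb, va);          // accumulate in fp16
    }
    // Flush: widen the fp16 partial sums and add them into the fp32 output tile.
    vst1q_f32(c8, vaddq_f32(vld1q_f32(c8), vcvt_f32_f16(vget_low_f16(v0_7))));
    vst1q_f32(c8 + 4, vaddq_f32(vld1q_f32(c8 + 4), vcvt_f32_f16(vget_high_f16(v0_7))));
  }
}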

And, not related to this topic: I don't really prefer following the XNNPACK implementation (though I previously told you to refer to it; sorry for that).
Their fp16 GEMM benchmark is slower than their fp32 GEMM, and they allocate far too much memory for the GEMM output, which makes it hard to carry their way of thinking over to nntrainer.

Anyway, thanks for your kind review!


skykongkong8 commented Aug 7, 2024

@djeong20
Hmm... I've run the FP16 GEMM unit test, and the result is a bit disappointing.
In short, there is a subtle but clear difference: the smallest gap in terms of GFLOPS is less than 1%, but in the worst case it is about 8%.
The TC results follow for your information.

Before
Try #1
[INFO] Latency : 4095.28490000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 93.32829372
[INFO] Latency : 6330.02291000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 95.88099516
[INFO] Latency : 9644.96927000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 93.93183520
[INFO] Latency : 14042.25833000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 91.86165485
[INFO] Latency : 18921.98801000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 93.51406412
[INFO] Latency : 26311.99010000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 89.50927783

Try #2
[INFO] Latency : 4032.19531000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 94.78855130
[INFO] Latency : 6488.52344000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 93.53883077
[INFO] Latency : 9809.80469000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 92.35348640
[INFO] Latency : 14213.26614000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 90.75641554
[INFO] Latency : 19063.69687000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 92.81893287
[INFO] Latency : 26564.81249000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 88.65740095

Try #3
[INFO] Latency : 4079.07395000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 93.69919660
[INFO] Latency : 6545.82968000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 92.71993401
[INFO] Latency : 9888.91458000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 91.61467183
[INFO] Latency : 14376.37395000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 89.72673447
[INFO] Latency : 19377.54166000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 91.31560809
[INFO] Latency : 26971.74739000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 87.31978681

Try #4
[INFO] Latency : 4142.87708000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 92.25616513
[INFO] Latency : 6629.20677000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 91.55377363
[INFO] Latency : 10061.58437000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 90.04244567
[INFO] Latency : 14656.08801000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 88.01428370
[INFO] Latency : 19676.60885000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 89.92769097
[INFO] Latency : 27258.33019000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 86.40174272

After
Try #1
[INFO] Latency : 4141.33907000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 92.29042721
[INFO] Latency : 6470.47916000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 93.79968330
[INFO] Latency : 9858.99531000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 91.89269652
[INFO] Latency : 14423.61198000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 89.43287505
[INFO] Latency : 19305.43697000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 91.65666660
[INFO] Latency : 26863.52916000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 87.67154970

Try #2
[INFO] Latency : 4083.25104000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 93.60334406
[INFO] Latency : 6561.86666000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 92.49332963
[INFO] Latency : 10031.08906000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 90.31618188
[INFO] Latency : 14521.63593000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 88.82918524
[INFO] Latency : 19558.44113000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 90.47101393
[INFO] Latency : 27245.39114000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 86.44277558

Try #3
[INFO] Latency : 4223.33125000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 90.49869152
[INFO] Latency : 6735.30781000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 90.11153063
[INFO] Latency : 10125.71198000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 89.47219374
[INFO] Latency : 14758.81614000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 87.40166391
[INFO] Latency : 19903.89791000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 88.90077753
[INFO] Latency : 27369.10885000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 86.05202474


Try #4
[INFO] Latency : 4289.84010000 ms | Dim: 576 | Error: 0.00140162 | Gflops: 89.09561734
[INFO] Latency : 6892.05364000 ms | Dim: 672 | Error: 0.00197730 | Gflops: 88.06212599
[INFO] Latency : 10414.10313000 ms | Dim: 768 | Error: 0.00310214 | Gflops: 86.99449705
[INFO] Latency : 14986.09218000 ms | Dim: 864 | Error: 0.00367473 | Gflops: 86.07614797
[INFO] Latency : 20215.67447000 ms | Dim: 960 | Error: 0.01159486 | Gflops: 87.52970388
[INFO] Latency : 27385.74374000 ms | Dim: 1056 | Error: 0.00553330 | Gflops: 85.99975426


skykongkong8 commented Aug 7, 2024

@djeong20
I tested a mixed version of the kernel that takes N as an input but unrolls the computation only four steps at a time.
This showed almost the same performance as the previous version, while still reducing the line count a lot!
I think this kernel should be the new standard for implementing GEMM ukernels. Thanks for pointing this issue out!

FYI, the new ukernel:

#define KERNEL_8x16_ACC_N4(N)                            \
  do {                                                   \
    for (int i = 0; i < N; i += 4) {                     \
      va0 = vld1q_f16(a + 8 * i);                        \
      vb1 = vld1q_f16(b + 16 * i);                       \
      vb2 = vld1q_f16(b + 16 * i + 8 * 1);               \
      v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
      v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
      v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
      v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
      v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
      v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
      v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
      v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
      v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
      v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
      v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
      v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
      v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
      v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
      v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
      v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
      va0 = vld1q_f16(a + 8 * i + 8 * 1);                \
      vb1 = vld1q_f16(b + 16 * i + 8 * 2);               \
      vb2 = vld1q_f16(b + 16 * i + 8 * 3);               \
      v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
      v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
      v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
      v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
      v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
      v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
      v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
      v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
      v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
      v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
      v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
      v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
      v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
      v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
      v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
      v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
      va0 = vld1q_f16(a + 8 * i + 8 * 2);                \
      vb1 = vld1q_f16(b + 16 * i + 8 * 4);               \
      vb2 = vld1q_f16(b + 16 * i + 8 * 5);               \
      v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
      v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
      v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
      v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
      v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
      v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
      v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
      v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
      v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
      v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
      v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
      v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
      v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
      v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
      v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
      v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
      va0 = vld1q_f16(a + 8 * i + 8 * 3);                \
      vb1 = vld1q_f16(b + 16 * i + 8 * 6);               \
      vb2 = vld1q_f16(b + 16 * i + 8 * 7);               \
      v0_7 = vfmaq_laneq_f16(v0_7, vb1, va0, 0);         \
      v8_15 = vfmaq_laneq_f16(v8_15, vb1, va0, 1);       \
      v16_23 = vfmaq_laneq_f16(v16_23, vb1, va0, 2);     \
      v24_31 = vfmaq_laneq_f16(v24_31, vb1, va0, 3);     \
      v32_39 = vfmaq_laneq_f16(v32_39, vb1, va0, 4);     \
      v40_47 = vfmaq_laneq_f16(v40_47, vb1, va0, 5);     \
      v48_55 = vfmaq_laneq_f16(v48_55, vb1, va0, 6);     \
      v56_63 = vfmaq_laneq_f16(v56_63, vb1, va0, 7);     \
      v64_71 = vfmaq_laneq_f16(v64_71, vb2, va0, 0);     \
      v72_79 = vfmaq_laneq_f16(v72_79, vb2, va0, 1);     \
      v80_87 = vfmaq_laneq_f16(v80_87, vb2, va0, 2);     \
      v88_95 = vfmaq_laneq_f16(v88_95, vb2, va0, 3);     \
      v96_103 = vfmaq_laneq_f16(v96_103, vb2, va0, 4);   \
      v104_111 = vfmaq_laneq_f16(v104_111, vb2, va0, 5); \
      v112_119 = vfmaq_laneq_f16(v112_119, vb2, va0, 6); \
      v120_127 = vfmaq_laneq_f16(v120_127, vb2, va0, 7); \
    }                                                    \
    l += N;                                              \
    __builtin_prefetch(b + 16 * N, 0, 3);                \
    __builtin_prefetch(a + 8 * N, 0, 3);                 \
    b += 16 * N;                                         \
    a += 8 * N;                                          \
  } while (0)


I think this conversation has gone on too long, so I opened a new PR #2700 for this.

@djeong20 djeong20 left a comment

With #2700, LGTM!

@jijoongmoon jijoongmoon left a comment

LGTM

@jijoongmoon jijoongmoon merged commit 05f7b4a into nnstreamer:main Aug 9, 2024
44 checks passed
@skykongkong8 skykongkong8 deleted the pr/hgemm/experimental_kernel branch August 16, 2024 01:24