Skip to content

Commit

Permalink
q5_1 support
Browse files Browse the repository at this point in the history
  • Loading branch information
syx committed Mar 31, 2024
1 parent 0c66508 commit 7a04b46
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5345,6 +5345,14 @@ static void dequantize_mul_mat_batch_q5_0_cuda_sparse(const void * vx, const dfl
dequantize_mul_mat_batch_sparse<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);
}
static void dequantize_mul_mat_batch_q5_1_cuda_sparse(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_ncols, int dst_ne0, cudaStream_t stream, int *lst, float *idx) {
    // Host wrapper: launch the sparse batched dequantize-matmul kernel instantiated
    // for Q5_1 quantization (block size QK5_1, quant ratio QR5_1, dequantize_q5_1).
    // Precondition: ncols must be a multiple of GGML_CUDA_DMMV_X.
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    // Grid: one block-row per GGML_CUDA_MMV_Y matrix rows (ceil division).
    // Block: one warp wide, GGML_CUDA_MMV_Y rows tall.
    const int grid_rows = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 grid(1, grid_rows, 1);
    const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_batch_sparse<QK5_1, QR5_1, dequantize_q5_1><<<grid, block, 0, stream>>>(
        vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);
}
static void dequantize_mul_mat_batch_q8_0_cuda_sparse(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_ncols, int dst_ne0, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
Expand Down Expand Up @@ -5506,6 +5514,14 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, lst, idx);
}
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream, const int *lst, const float *idx) {
    // Host wrapper: launch the quantized mat-vec kernel for a Q5_1 weight matrix
    // against a Q8_1-quantized activation vector, with sparse row selection
    // driven by lst (row lookup) and idx (sparsity predictor values).
    // Precondition: ncols must be a multiple of QK5_1.
    GGML_ASSERT(ncols % QK5_1 == 0);
    // Ceil-divide the rows across blocks of GGML_CUDA_MMV_Y rows each;
    // each block is one warp wide.
    const int grid_rows = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 grid(grid_rows, 1, 1);
    const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1><<<grid, block, 0, stream>>>(
        vx, vy, dst, ncols, nrows, lst, idx);
}
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream, const int *lst, const float *idx) {
GGML_ASSERT(ncols % QK8_0 == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
Expand Down Expand Up @@ -5678,6 +5694,15 @@ static void dequantize_axpy_sparse_vec_q5_0_cuda(const void * vx, const dfloat *
dequantize_mul_mat_axpy_sparse<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}
static void dequantize_axpy_sparse_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
    // Host wrapper: launch the sparse dequantize-axpy kernel (single-vector case)
    // instantiated for Q5_1 quantization.
    // Precondition: ncols must be a multiple of GGML_CUDA_DMMV_X.
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    // Grid is 1 x ceil(nrows / GGML_CUDA_MMV_Y); blocks are one warp wide.
    const int grid_rows = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 grid(1, grid_rows, 1);
    const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    // Dynamic shared memory: one float per column (ncols * sizeof(float)).
    dequantize_mul_mat_axpy_sparse<QK5_1, QR5_1, dequantize_q5_1><<<grid, block, ncols*sizeof(float), stream>>>(
        vx, y, dst, ncols, nrows, lst, idx);
}



static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
Expand All @@ -5704,6 +5729,14 @@ static void dequantize_axpy_sparse_batch_q5_0_cuda(const void * vx, const dfloat
dequantize_mul_mat_axpy_sparse_batch<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
static void dequantize_axpy_sparse_batch_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
    // Host wrapper: launch the sparse dequantize-axpy kernel (batched case)
    // instantiated for Q5_1 quantization.
    // Precondition: ncols must be a multiple of GGML_CUDA_DMMV_X.
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    // Grid is 1 x ceil(nrows / GGML_CUDA_MMV_Y); blocks are one warp wide.
    const int grid_rows = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 grid(1, grid_rows, 1);
    const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    // Dynamic shared memory: one float per column (ncols * sizeof(float)).
    dequantize_mul_mat_axpy_sparse_batch<QK5_1, QR5_1, dequantize_q5_1><<<grid, block, ncols*sizeof(float), stream>>>(
        vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
static void dequantize_axpy_sparse_batch_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
Expand Down Expand Up @@ -7054,6 +7087,9 @@ inline void ggml_cuda_op_mul_mat_batch_sparse(
case GGML_TYPE_Q5_0:
dequantize_mul_mat_batch_q5_0_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_Q5_1:
dequantize_mul_mat_batch_q5_1_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_Q8_0:
dequantize_mul_mat_batch_q8_0_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
Expand Down Expand Up @@ -7317,6 +7353,12 @@ inline void ggml_cuda_op_mul_mat_vec_sparse_q(
row_lookup, sparse_idx
);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_cuda(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream,
row_lookup, sparse_idx
);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_cuda(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream,
Expand Down Expand Up @@ -7705,6 +7747,15 @@ inline void ggml_cuda_op_dequantize_axpy(
dequantize_axpy_sparse_batch_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_Q5_1:
if (sparse_idx == NULL) {
GGML_ASSERT(false);
} else if (ne11 == 1) {
dequantize_axpy_sparse_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, row_lookup, sparse_idx);
} else {
dequantize_axpy_sparse_batch_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_Q8_0:
if (sparse_idx == NULL) {
GGML_ASSERT(false);
Expand Down Expand Up @@ -8876,6 +8927,7 @@ static void ggml_cuda_mul_mat_sparse(const ggml_tensor * src0, const ggml_tensor
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_sparse_q, true);
// ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
Expand Down

0 comments on commit 7a04b46

Please sign in to comment.