From 6a5d464b2a36f96e6f92ca8f4815cc71f1a0cff1 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 30 Apr 2024 14:59:31 -0700
Subject: [PATCH 1/2] Update the quantization op and the dequantization kernel

---
 mlx/backend/metal/kernels/quantized.metal | 26 ++++++++------------------
 mlx/ops.cpp                               | 24 ++++++++++++++++--------
 2 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 72ef5f103..58f165faa 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -205,13 +205,10 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   }
 
   else if (bits == 4) {
-    const thread uint16_t* ws = (const thread uint16_t*)w;
-    U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
-    for (int i = 0; i < (values_per_thread / 4); i++) {
-      result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias);
-      result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
-      result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
-      result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias);
+    U s[2] = {scale, scale / 16.0f};
+    for (int i = 0; i < (values_per_thread / 2); i++) {
+      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
+      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
     }
   }
 
@@ -244,17 +241,10 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   }
 
   else if (bits == 4) {
-    const device uint16_t* ws = (const device uint16_t*)w;
-    U s[4] = {
-        scale,
-        scale / static_cast<U>(16.0f),
-        scale / static_cast<U>(256.0f),
-        scale / static_cast<U>(4096.0f)};
-    for (int i = 0; i < (N / 4); i++) {
-      w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
-      w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
-      w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
-      w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
+    U s[2] = {scale, scale / static_cast<U>(16.0f)};
+    for (int i = 0; i < (N / 2); i++) {
+      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
+      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
     }
   }
 
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 8c56601e2..3ad012932 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3275,7 +3275,9 @@ std::tuple<array, array, array> quantize(
   }
 
   // Compute some constants used for the quantization
-  int n_bins = (1 << bits) - 1; // 2**bits - 1
+  array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
+  array eps(1e-7, w.dtype());
+  array zero(0, w.dtype());
   int el_per_int = 32 / bits;
   array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
   shifts = reshape(shifts, {1, 1, -1}, s);
@@ -3299,16 +3301,22 @@ std::tuple<array, array, array> quantize(
       reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
   array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
   array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
-  array scales = maximum(
-      divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
-      array(1e-7, w.dtype()),
-      s);
-  // making sure that 0 is represented exactly in the resulting quantization
-  array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
+
+  array mask = greater(abs(w_min, s), abs(w_max, s), s);
+  array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
+  scales = where(mask, scales, negative(scales), s);
+  array edge = where(mask, w_min, w_max, s);
+  array q0 = round(divide(edge, scales, s), s);
+  scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
+  array biases = where(equal(q0, zero, s), zero, edge);
 
   // Quantize and pack w
   packed_w = astype(
-      round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
+      clip(
+          round(divide(subtract(packed_w, biases, s), scales, s), s),
+          zero,
+          n_bins),
+      uint32);
   packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
   packed_w = sum(
       multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);

From 9a92f6fb235a500b9e3576bbc87557db706b4c3d Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 30 Apr 2024 17:12:53 -0700
Subject: [PATCH 2/2] Update the test again to match the new quantization

---
 python/tests/test_quantized.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 32026c321..2c214abbd 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -16,7 +16,7 @@ def test_quantize_dequantize(self):
                 w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                 errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                 eps = 1e-6
-                self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
+                self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())
 
         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
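
The updated quantize() in PATCH 1/2 gives each group's scale a sign: the group edge with the larger magnitude (w_min or w_max) becomes the bias, the scale is then snapped to edge / q0 so that 0.0 falls exactly on the quantization grid, and the quantized values are clipped to [0, n_bins]. Below is a minimal NumPy sketch of that per-group logic with hypothetical helper names (quantize_group, dequantize_group); it is not the MLX API, and Python-side rounding ties may differ slightly from the kernels'.

import numpy as np

def quantize_group(w, bits=4, eps=1e-7):
    """Quantize one group of weights; returns (q, scale, bias)."""
    n_bins = (1 << bits) - 1
    w_max, w_min = float(w.max()), float(w.min())

    # Give the scale the sign of the dominant edge: negative when the
    # positive side dominates, positive when the negative side does.
    scale = max((w_max - w_min) / n_bins, eps)
    if abs(w_min) <= abs(w_max):
        scale = -scale
    edge = w_min if abs(w_min) > abs(w_max) else w_max

    # Snap the scale so 0.0 lands exactly on a grid point: afterwards
    # q = -q0 dequantizes to scale * (-q0) + edge == 0.
    q0 = int(np.round(edge / scale))
    if q0 != 0:
        scale = edge / q0
        bias = edge
    else:
        bias = 0.0  # all-zero group: scale stays at +-eps

    q = np.clip(np.round((w - bias) / scale), 0, n_bins).astype(np.uint32)
    return q, scale, bias

def dequantize_group(q, scale, bias):
    return scale * q + bias

w = np.random.randn(64).astype(np.float32)
q, scale, bias = quantize_group(w)
w_hat = dequantize_group(q, scale, bias)
# The bound PATCH 2/2 now checks: at most one full |scale| step of error,
# since snapping the scale to hit zero exactly can widen the grid spacing.
assert np.all(np.abs(w - w_hat) <= abs(scale) + 1e-6)

This also shows why the test relaxes its bound from half a step (2 * errors <= scales) to a full step taken in absolute value: scales can now be negative, and the edge/q0 adjustment can enlarge the spacing within a group.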
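
On the read side, the Metal kernel hunks unpack 4-bit values byte by byte instead of through uint16 loads, and the high nibble is never shifted down: the >> 4 is folded into a second scale, s[1] = scale / 16. The same unpacking sketched in NumPy, with a hypothetical dequantize_4bit helper (illustrative only, not the kernel code):

import numpy as np

def dequantize_4bit(packed, scale, bias):
    """packed: uint8 array, two 4-bit values per byte, low nibble first."""
    out = np.empty(2 * packed.size, dtype=np.float32)
    s = (scale, scale / 16.0)  # s[1] absorbs the high nibble's >> 4
    out[0::2] = s[0] * (packed & 0x0F) + bias
    out[1::2] = s[1] * (packed & 0xF0) + bias
    return out

packed = np.array([0x21, 0x43], dtype=np.uint8)  # holds the values 1, 2, 3, 4
print(dequantize_4bit(packed, scale=0.5, bias=0.0))  # [0.5 1.  1.5 2. ]

Masking in place trades one shift per element for one extra multiplier, and reading bytes rather than casting to uint16_t also drops the 2-byte alignment assumption the old pointer casts made.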