float16 support for GPU als model (#661)
This adds support for using float16 factors in the GPU version of the ALS model. This reduces the memory needed for the ALS model embeddings by half - while causing only a small increase in training time, and virtually no difference in the accuracy of the learned model.
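To illustrate how this might be used, here is a minimal sketch (hypothetical data; it assumes the top-level implicit.als.AlternatingLeastSquares factory forwards the new dtype argument to the GPU model when use_gpu=True):

    import numpy as np
    import scipy.sparse as sp
    import implicit

    # hypothetical example data: a random user-item confidence matrix
    user_items = sp.random(1000, 500, density=0.01, format="csr", dtype=np.float32)

    # sketch only: store the GPU factors in half precision to halve embedding memory
    model = implicit.als.AlternatingLeastSquares(
        factors=64,
        dtype=np.float16,
        use_gpu=True,
    )
    model.fit(user_items)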

All computations are still performed using float32 - including both training and inference. This is done by using mixed-precision matrix multiplications during inference: the fp16 factors are multiplied together, with the results accumulated as fp32. During training, the factors are converted from fp16 to fp32, and the updates are calculated in 32-bit precision before being stored back as fp16.
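A rough NumPy sketch of that scheme (purely illustrative - this is not the library's CUDA code, and the update shown is a stand-in for the real ALS step):

    import numpy as np

    # factors are stored in half precision
    user_factors = np.random.rand(4, 32).astype(np.float16)
    item_factors = np.random.rand(8, 32).astype(np.float16)

    # inference: scores computed in fp32 (the CUDA kernels accumulate fp16 products as fp32)
    scores = user_factors.astype(np.float32) @ item_factors.astype(np.float32).T

    # training: promote a row to fp32, compute the update in 32-bit, store back as fp16
    x32 = user_factors[0].astype(np.float32)
    x32 -= 0.01 * x32  # placeholder for the real ALS update of one row
    user_factors[0] = x32.astype(np.float16)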
benfred committed May 30, 2023
1 parent 295a7c4 commit ec36f33
Showing 17 changed files with 401 additions and 145 deletions.
3 changes: 2 additions & 1 deletion implicit/als.py
@@ -37,7 +37,7 @@ def AlternatingLeastSquares(
     alpha : float, optional
         The weight to give to positive examples.
     dtype : data-type, optional
-        Specifies whether to generate 64 bit or 32 bit floating point factors
+        Specifies whether to generate 64 bit or 32 bit or 16 bit floating point factors
     use_native : bool, optional
         Use native extensions to speed up model fitting
     use_cg : bool, optional
@@ -61,6 +61,7 @@ def AlternatingLeastSquares(
             factors,
             regularization,
             alpha,
+            dtype=dtype,
             iterations=iterations,
             calculate_training_loss=calculate_training_loss,
             random_state=random_state,
11 changes: 11 additions & 0 deletions implicit/cpu/_als.pyx
@@ -55,7 +55,16 @@ cdef inline void gesv(int * n, int * nrhs, floating * a, int * lda, int * piv, f
         cython_lapack.sgesv(n, nrhs, a, lda, piv, b, ldb, info)


+def _check_als_dtype(X):
+    _ALLOWED_DTYPES = (np.float32, np.float64)
+    if X.dtype not in _ALLOWED_DTYPES:
+        raise ValueError(f"Invalid dtype {X.dtype} for cpu ALS model. "
+                         f"Allowed dtypes are: {_ALLOWED_DTYPES}")
+
+
 def least_squares(Cui, X, Y, regularization, num_threads=0):
+    _check_als_dtype(X)
+    _check_als_dtype(Y)
     YtY = np.dot(np.transpose(Y), Y)
     _least_squares(YtY, Cui.indptr, Cui.indices, Cui.data.astype('float32'),
                    X, Y, regularization, num_threads)
@@ -132,6 +141,8 @@ def _least_squares(YtY, integral[:] indptr, integral[:] indices, float[:] data,


 def least_squares_cg(Cui, X, Y, regularization, num_threads=0, cg_steps=3):
+    _check_als_dtype(X)
+    _check_als_dtype(Y)
     return _least_squares_cg(Cui.indptr, Cui.indices, Cui.data.astype('float32'),
                              X, Y, regularization, num_threads, cg_steps)
1 change: 1 addition & 0 deletions implicit/cpu/als.py
@@ -433,6 +433,7 @@ def to_gpu(self):
             factors=self.factors,
             regularization=self.regularization,
             alpha=self.alpha,
+            dtype=self.dtype,
             iterations=self.iterations,
             calculate_training_loss=self.calculate_training_loss,
             random_state=self.random_state,
72 changes: 52 additions & 20 deletions implicit/gpu/_cuda.pyx
@@ -5,6 +5,7 @@ import numpy as np
 from cython.operator import dereference

 from cython cimport view
+from libc.stdint cimport uint16_t
 from libcpp cimport bool
 from libcpp.utility cimport move, pair

@@ -15,7 +16,6 @@ from .matrix cimport COOMatrix as CppCOOMatrix
 from .matrix cimport CSRMatrix as CppCSRMatrix
 from .matrix cimport Matrix as CppMatrix
 from .matrix cimport Vector as CppVector
-from .matrix cimport calculate_norms as cpp_calculate_norms
 from .random cimport RandomState as CppRandomState
 from .utils cimport get_device_count as cpp_get_device_count

@@ -49,38 +49,37 @@ cdef class KnnQuery(object):
     def __dealloc__(self):
         del self.c_knn

-    @cython.boundscheck(False)
     def topk(self, Matrix items, Matrix m, int k, Matrix item_norms=None,
              COOMatrix query_filter=None, IntVector item_filter=None):
         cdef CppMatrix * queries = m.c_matrix
         cdef CppCOOMatrix * c_query_filter = NULL
         cdef CppVector[int] * c_item_filter = NULL
         cdef size_t rows = queries.rows
-        cdef int[:, :] x
-        cdef float[:, :] y
+        cdef int[:, :] indices_view
+        cdef float[:, :] distances_view

-        cdef float * c_item_norms = NULL
+        cdef CppMatrix * c_item_norms = NULL
         if item_norms is not None:
-            c_item_norms = item_norms.c_matrix.data
+            c_item_norms = item_norms.c_matrix

         if query_filter is not None:
             c_query_filter = query_filter.c_matrix

         if item_filter is not None:
             c_item_filter = item_filter.c_vector

         indices = np.zeros((rows, k), dtype="int32")
         distances = np.zeros((rows, k), dtype="float32")
-        x = indices
-        y = distances
+        indices_view = indices
+        distances_view = distances

         with nogil:
             self.c_knn.topk(dereference(items.c_matrix), dereference(queries), k,
-                            &x[0, 0], &y[0, 0], c_item_norms, c_query_filter, c_item_filter)
+                            &indices_view[0, 0], &distances_view[0, 0], c_item_norms, c_query_filter, c_item_filter)

         return indices, distances


 cdef class Matrix(object):
     cdef CppMatrix * c_matrix

@@ -89,24 +88,33 @@ cdef class Matrix(object):
             self.c_matrix = NULL
             return

-        cdef float[:, :] c_X
+        cdef float[:, :] temp_float
+        cdef uint16_t[:, :] temp_half
         cdef long data

         # see if the input support CAI (cupy/pytorch/cudf etc)
         cai = getattr(X, "__cuda_array_interface__", None)
         if cai:
             shape = cai["shape"]
             data = cai["data"][0]
-            self.c_matrix = new CppMatrix(shape[0], shape[1], <float*>data, False)
+            itemsize = int(cai["typestr"][2])
+            self.c_matrix = new CppMatrix(shape[0], shape[1], <void*>data, False, itemsize)
         else:
             # otherwise assume we're a buffer on host
-            c_X = X
-            self.c_matrix = new CppMatrix(X.shape[0], X.shape[1], &c_X[0, 0], True)
+            if X.dtype.char == "f":
+                temp_float = X
+                self.c_matrix = new CppMatrix(X.shape[0], X.shape[1], &temp_float[0, 0], True, 4)
+            elif X.dtype.char == "e":
+                temp_half = X.view(np.uint16)
+                self.c_matrix = new CppMatrix(X.shape[0], X.shape[1], &temp_half[0, 0], True, 2)
+            else:
+                raise ValueError(f"unhandled dtype for GPU Matrix {X.dtype}")

     @classmethod
     def zeros(cls, rows, cols):
         ret = Matrix(None)
-        ret.c_matrix = new CppMatrix(rows, cols, NULL, True)
+        ret.c_matrix = new CppMatrix(rows, cols, NULL, True, 4)
         return ret

     @property
@@ -154,14 +162,37 @@ cdef class Matrix(object):
         rows = IntVector(np.array(rowids).astype("int32"))
         self.c_matrix.assign_rows(dereference(rows.c_vector), dereference(other.c_matrix))

+    def astype(self, dtype):
+        dtype = np.dtype(dtype)
+        _ALLOWED_DTYPES = (np.float16, np.float32)
+        if dtype not in _ALLOWED_DTYPES:
+            raise ValueError(f"Invalid dtype '{dtype}' for GPU model. "
+                             f"Allowed dtypes are: {_ALLOWED_DTYPES}")
+
+        cdef int itemsize = dtype.itemsize
+        ret = Matrix(None)
+        ret.c_matrix = new CppMatrix(self.c_matrix.astype(itemsize))
+        return ret
+
     def resize(self, size_t rows, size_t cols):
         self.c_matrix.resize(rows, cols)

     def to_numpy(self):
-        ret = np.zeros((self.c_matrix.rows, self.c_matrix.cols), dtype="float32")
-        cdef float[:, :] temp = ret
-        self.c_matrix.to_host(&temp[0, 0])
-        return ret
+        cdef float[:, :] temp_float
+        cdef uint16_t[:, :] temp_half
+
+        if self.c_matrix.itemsize == 4:
+            ret = np.zeros((self.c_matrix.rows, self.c_matrix.cols), dtype="float32")
+            temp_float = ret
+            self.c_matrix.to_host(&temp_float[0, 0])
+            return ret
+        elif self.c_matrix.itemsize == 2:
+            ret = np.zeros((self.c_matrix.rows, self.c_matrix.cols), dtype="float16")
+            temp_half = ret.view(np.uint16)
+            self.c_matrix.to_host(&temp_half[0, 0])
+            return ret
+        else:
+            raise ValueError(f"Invalid itemsize {self.c_matrix.itemsize}")

     def __repr__(self):
         return f"Matrix({str(self.to_numpy())})"
@@ -174,6 +205,7 @@ cdef class Matrix(object):
         del self.c_matrix


+
 cdef class IntVector(object):
     cdef CppVector[int] * c_vector

@@ -241,7 +273,7 @@ cdef class LeastSquaresSolver(object):

 def calculate_norms(Matrix items):
     ret = Matrix(None)
-    ret.c_matrix = new CppMatrix(cpp_calculate_norms(dereference(items.c_matrix)))
+    ret.c_matrix = new CppMatrix(items.c_matrix.calculate_norms())
     return ret
