Skip to content

Commit

Permalink
Clear cached properties on partial_fit_* (#685)
Browse files Browse the repository at this point in the history
The item_norms/user_norms were incorrect after calling the partial_fit methods, especially for new items. Fix.
  • Loading branch information
benfred authored Aug 23, 2023
1 parent 00623e6 commit fd351da
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
8 changes: 8 additions & 0 deletions implicit/cpu/als.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,10 @@ def partial_fit_users(self, userids, user_items):
# update the stored factors with the newly calculated values
self.user_factors[userids] = user_factors

# clear any cached properties that are invalidated by this update
self._user_norms = None
self._XtX = None

def partial_fit_items(self, itemids, item_users):
"""Incrementally updates item factors
Expand Down Expand Up @@ -339,6 +343,10 @@ def partial_fit_items(self, itemids, item_users):
# update the stored factors with the newly calculated values
self.item_factors[itemids] = item_factors

# clear any cached properties that are invalidated by this update
self._item_norms = None
self._YtY = None

def explain(self, userid, user_items, itemid, user_weights=None, N=10):
"""Provides explanations for why the item is liked by the user.
Expand Down
8 changes: 8 additions & 0 deletions implicit/gpu/als.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def partial_fit_users(self, userids, user_items):

self.user_factors.assign_rows(userids, user_factors)

# clear any cached properties that are invalidated by this update
self._user_norms = self._user_norms_host = None
self._XtX = None

def partial_fit_items(self, itemids, item_users):
"""Incrementally updates item factors
Expand Down Expand Up @@ -266,6 +270,10 @@ def partial_fit_items(self, itemids, item_users):
# update the stored factors with the newly calculated values
self.item_factors.assign_rows(itemids, item_factors)

# clear any cached properties that are invalidated by this update
self._item_norms = self._item_norms_host = None
self._YtY = None

@property
def solver(self):
if self._solver is None:
Expand Down

0 comments on commit fd351da

Please sign in to comment.