Refactor/fix of sampled softmax to add logQ correction #1051
Goals ⚽
Sampled softmax is a popular technique to deal with multi-class classification with a very large number of classes.
It has been used to train retrieval models with contrastive learning by using a subset of candidate negative items during training, instead of using all other items as negatives.
This PR adds the important logQ sampling correction proposed for sampled softmax, so that it better approximates the full softmax.
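For context, the logQ correction subtracts the log of each candidate's sampling probability from its logit before the softmax, compensating for the fact that popular items are drawn as negatives more often. A minimal NumPy sketch of the idea (the function name and values here are illustrative, not the Merlin Models API):

```python
import numpy as np

def logq_corrected_logits(logits: np.ndarray, sampling_probs: np.ndarray) -> np.ndarray:
    """Subtract log(Q) from each sampled candidate's logit.

    Without this correction, frequently sampled (popular) items are
    over-penalized as negatives; subtracting log(q_i) compensates, so the
    sampled softmax better approximates the full softmax.
    """
    return logits - np.log(sampling_probs)

# Toy batch: 1 positive followed by 3 sampled negatives
logits = np.array([2.0, 1.5, 1.0, 0.5])
probs = np.array([0.01, 0.4, 0.2, 0.1])  # sampling probability of each item
corrected = logq_corrected_logits(logits, probs)
```

Note how the most popular negative (sampling prob 0.4) receives the smallest upward correction, so it is penalized less than under the uncorrected logits.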
Implementation Details 🚧
In general, sampled softmax was previously implemented by using `ContrastiveOutput` as the model output and `PopularityBasedSamplerV2` as a sampler (like the example below), which uses the log-uniform distribution to approximate the long-tail item frequency (assuming that categorical ids are sorted decreasingly by frequency). However, the logQ correction was not implemented as proposed for sampled softmax, which fixes the over-penalization of popular items, as they are sampled more often as negatives. The logQ correction can be applied with `PopularityLogitsCorrection`, but it requires providing the item frequency distribution. As our sampled softmax implementation (`PopularityBasedSamplerV2`) uses an approximated log-uniform distribution for sampling, I implemented the corresponding sampling probability of the positive and negative items to allow for the sampling correction (with or without replacement). You can use sampled softmax as in the following example.

A new `logq_sampling_correction=False` arg was added to `ContrastiveOutput`, which should be set to `True` when the sampler supports returning the items' sampling probs (like `PopularityBasedSamplerV2` does). If it is enabled, `PopularityLogitsCorrection` doesn't need to be used.

Summary of main API changes
- `Candidate` to optionally store the sampling prob. of each item
- `CandidateSampler` abstract class to have a `with_sampling_probs(items)` method that allows returning the probability of the provided items according to the sampler distribution
- `PopularityBasedSamplerV2` to support both unique and non-unique samples, and to compute its distribution in the constructor (`get_sampling_distribution()`), as it is based on a log-uniform approximation of the items' long-tail frequency distribution
- `ContrastiveOutput` to set the sampling probs for both positive and negative candidates. Created the arg `logq_sampling_correction`, which, if enabled, subtracts the logQ sampled probs based on `Candidate.sampling_prob`
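To illustrate the distribution such a sampler exposes, the sketch below computes the standard log-uniform (Zipf-like) sampling probabilities over ids sorted by decreasing frequency, plus a lookup helper in the spirit of `with_sampling_probs(items)`. The function names are hypothetical, mirroring but not reproducing the Merlin Models API:

```python
import numpy as np

def log_uniform_distribution(num_items: int) -> np.ndarray:
    """Standard log-uniform sampling distribution:
    P(item_id = i) = (log(i + 2) - log(i + 1)) / log(num_items + 1),
    a common approximation of a long-tail item frequency distribution
    when ids are sorted by decreasing popularity."""
    ids = np.arange(num_items)
    return (np.log(ids + 2) - np.log(ids + 1)) / np.log(num_items + 1)

dist = log_uniform_distribution(1000)

def sampling_probs(item_ids: np.ndarray) -> np.ndarray:
    """Look up the sampling probability of the given item ids
    (a toy analog of `with_sampling_probs(items)`)."""
    return dist[item_ids]
```

The log terms telescope, so the probabilities sum exactly to 1, and lower (more popular) ids get higher sampling probability, which is what the logQ correction later compensates for.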
Testing Details 🔍
Added `test_contrastive_output_with_sampled_softmax` to test and showcase how sampled softmax can be used in Merlin Models.
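A framework-free sketch of the kind of invariant such a test can check: with the correction flag enabled, each logit shifts by exactly -log(q), and with it disabled the scores pass through unchanged. The function below is a toy stand-in, not the actual `ContrastiveOutput` or its test:

```python
import numpy as np

def contrastive_logits(scores: np.ndarray,
                       sampling_probs: np.ndarray,
                       logq_sampling_correction: bool = False) -> np.ndarray:
    """Toy stand-in for an output layer with an optional logQ correction."""
    if logq_sampling_correction:
        return scores - np.log(sampling_probs)
    return scores

scores = np.array([3.0, 2.0, 1.0])
probs = np.array([0.5, 0.25, 0.05])

plain = contrastive_logits(scores, probs)
corrected = contrastive_logits(scores, probs, logq_sampling_correction=True)

# Disabled flag is a no-op; enabled flag shifts each logit by -log(q)
assert np.allclose(plain, scores)
assert np.allclose(corrected - plain, -np.log(probs))
```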