Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactory/fix of sampled softmax to add logQ correction #1051

Merged
merged 5 commits into from
Apr 18, 2023

Conversation

gabrielspmoreira
Copy link
Member

@gabrielspmoreira gabrielspmoreira commented Apr 7, 2023

Goals ⚽

Sampled softmax is a popular technique to deal with multi-class classification with a very large number of classes.
It has been used to train retrieval models with contrastive learning by using a subset of candidate negative items during training, instead of using all other items as negatives.
This PR adds the important logQ sampling correction proposed by sampled softmax, in order to better approximate it to the full softmax.

Implementation Details 🚧

In general, sampled softmax was implemented before by using ContrastiveOutput as model output and the PopularityBasedSamplerV2 as a sampler (like example below), that uses the log-uniform distribution to approximate the long-tail items frequency (assuming that categorical ids are sorted decreasingly by frequency).

mm.ContrastiveOutput(
            schema["item_id"],
            negative_samplers=PopularityBasedSamplerV2(max_id=10000, max_num_samples=100),
        ),

However, the logQ correction was not implemented as proposed by sampled softmax, to fix the overpenalization of popular items as they are sampled more often as negatives. The logQ correction can be used by PopularityLogitsCorrection, but it requires providing the items frequency distribution. As our sampled softmax implemetation (PopularityBasedSamplerV2 ) uses an approximated log-uniform distribution for sampling, I implemented the corresponding sampling probability of the positive and negative items to allow for the sampling correction (with replacement or not). You can use sampled softmax as in the folllowing example.

mm.ContrastiveOutput(
            schema["item_id"],
            negative_samplers=PopularityBasedSamplerV2(max_id=10000, max_num_samples=100),
            logq_sampling_correction=True,
        ),

It was created a new logq_sampling_correction=False arg to the ContrastiveOutput, that should be set to True when the sampler supports returning the items' sampling probs (like PopularityBasedSamplerV2 does). If it is used, then PopularityLogitsCorrection doesn't need to be used.

Summary of main API changes

  • Changed Candidate to optionally store the sampling prob. of each item
  • Changed the CandidateSampler abstract class to have a with_sampling_probs(items) method, that allows returning the probability of the provided items according to the sampler distribution.
  • Changed the PopularityBasedSamplerV2 to support both unique and not-unique samples, to compute its distribution in the constructor (get_sampling_distribution()) as it is based on an log-uniform approximation of items' long-tail frequency distribution
  • Changed ContrastiveOutput to set the sampling probs for both positive and negative stores. Created the arg logq_sampling_correction, that if enabled subtracts the logQ sampled probs based on the Candidate.sampling_prob

Testing Details 🔍

  • Added the test_contrastive_output_with_sampled_softmax to test and showcase how sampled softmax can be used in Merlin Models

@gabrielspmoreira gabrielspmoreira self-assigned this Apr 7, 2023
@gabrielspmoreira gabrielspmoreira added this to the Merlin 23.04 milestone Apr 7, 2023
@gabrielspmoreira gabrielspmoreira added bug Something isn't working enhancement New feature or request labels Apr 7, 2023
@github-actions
Copy link

github-actions bot commented Apr 7, 2023

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1051

@marcromeyn marcromeyn self-requested a review April 18, 2023 07:02
@@ -132,6 +147,7 @@ def __init__(
query_name: str = "query",
candidate_name: str = "candidate",
store_negative_ids: bool = False,
logq_sampling_correction: Optional[bool] = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reasoning behind making this false by default? Are there down-sides of doing LogQ correction?

@gabrielspmoreira gabrielspmoreira merged commit 2a20547 into main Apr 18, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants