
Implemented sampled softmax for NextItemPredictionTask #671

Merged
merged 8 commits into from
Apr 17, 2023

Conversation

gabrielspmoreira
Member

@gabrielspmoreira gabrielspmoreira commented Apr 7, 2023

Goals ⚽

Implements sampled softmax for NextItemPredictionTask. It allows for faster training and evaluation.

Implementation Details 🚧

  • Refactored NextItemPredictionTask to have a standard output layer op (a dot product) whether weight_tying is enabled or not.
  • Added a sampled_softmax option to NextItemPredictionTask (a fuller usage sketch follows this list):
tr.NextItemPredictionTask(weight_tying=True, sampled_softmax=True, max_n_samples=1000)
  • Implemented a LogUniformSampler that is able to return sampling probabilities for both unique_sampling=True and False
  • Implemented a generic logQ correction for NextItemPredictionTask
  • Changed LabelSmoothCrossEntropyLoss to be just an alias of torch.nn.CrossEntropyLoss(label_smoothing=...), as PyTorch added label_smoothing support in a recent version. Added a DeprecationWarning to LabelSmoothCrossEntropyLoss
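A hedged usage sketch of the new option (not the test code from this PR; schema is assumed to be a Merlin schema loaded elsewhere, and the surrounding model-building calls follow the usual transformers4rec.torch API):

import transformers4rec.torch as tr

# Input block built from the dataset schema (schema loading omitted here)
inputs = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=20,
    d_output=320,
    masking="mlm",
)

# Next-item prediction head with weight tying and sampled softmax
prediction_task = tr.NextItemPredictionTask(
    weight_tying=True,     # reuse the item id embedding table as the output layer
    sampled_softmax=True,  # score the positive against a sampled set of negatives
    max_n_samples=1000,    # number of sampled negatives
)

# Any supported transformer config works; XLNet is just an example here
transformer_config = tr.XLNetConfig.build(
    d_model=320, n_head=8, n_layer=2, total_seq_length=20
)
model = transformer_config.to_torch_model(inputs, prediction_task)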

Testing Details 🔍

  • Created a test to check and demonstrate the usage of sampled softmax: test_with_next_item_pred_sampled_softmax

Benchmark 🔍

I have performed a benchmark of sampled softmax in different configurations (weight tying enabled and disabled, and different numbers of samples) to understand the impact of sampled softmax on training throughput and accuracy.

Setup

The experiments were performed using the T4Rec paper reproducibility script (changed to accept the new CLI args --sampled_softmax and --sampled_softmax_max_n_samples) and the preprocessed REES46 dataset.

The benchmark was done using the Merlin PyTorch 23.02 container, with a manual update of the core, dataloader and models folders to pull and install their latest versions from GitHub.

Command line
The script performs incremental training and evaluation: the first five days are used for training and evaluation is computed on each following day. Here is the base command line with the hparams used.
The hparams that are changed across experiments are --mf_constrained_embeddings (enables weight tying if provided, i.e., reusing the item id embedding table as the output layer), --sampled_softmax (enables sampled softmax if provided) and --sampled_softmax_max_n_samples (number of negative samples).

cd /transformers4rec/examples/t4rec_paper_experiments/t4r_paper_repro
CUDA_VISIBLE_DEVICES=0 python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train --do_eval --validate_every 10 --logging_steps 20 --save_steps 0 --data_path $DATA_PATH --features_schema_path "../datasets_configs/ecom_rees46/rees46_schema.pbtxt" --fp16 --data_loader_engine merlin --start_time_window_index 1 --final_time_window_index 6 --time_window_folder_pad_digits 4 --model_type albert --loss_type cross_entropy --per_device_eval_batch_size 128 --similarity_type concat_mlp --tf_out_activation tanh --inp_merge mlp --learning_rate_warmup_steps 0 --learning_rate_schedule linear_with_warmup --hidden_act gelu --num_train_epochs 5 --dataloader_drop_last --compute_metrics_each_n_steps 1 --session_seq_length_max 20 --eval_on_last_item_seq_only  --layer_norm_featurewise --mlm --num_hidden_groups -1 --inner_group_num 1 --per_device_train_batch_size 512 --learning_rate 0.0004904752786458524 --dropout 0.0 --input_dropout 0.1 --weight_decay 9.565968888623912e-05 --d_model 320 --item_embedding_dim 320 --n_layer 2 --n_head 8  --stochastic_shared_embeddings_replacement_prob 0.06 --item_id_embeddings_init_std 0.11 --other_embeddings_init_std 0.025 --mlm_probability 0.6000000000000001 --eval_on_test_set --seed 100 --report_to none --label_smoothing 0.2 --mf_constrained_embeddings --sampled_softmax --sampled_softmax_max_n_samples 1000

Results

The results can be seen in the following table. Steps/sec represents the throughput, and Recall and NDCG are top-k accuracy metrics.

[Results table image: bechmark_sampled_softmax]

The gist is that sampled softmax can provide both better training throughput and a gain in accuracy.

Some notes from these results:

  • Throughput (steps/sec) always increases with sampled softmax, and increases further with a smaller number of samples, as expected.
  • The best accuracies were obtained with sampled softmax for both weight_tying=False and True, but the best overall accuracy was obtained with weight_tying=True, sampled softmax and logQ correction.
  • Specific rows in the results table report results without the logQ correction proposed for sampled softmax. Without it, sampled softmax noticeably underperforms in terms of accuracy, as it over-penalizes popular items, which are sampled more often as negatives (a sketch of the correction follows this list).
  • A side note is that weight tying typically provides better accuracy when enabled, as we previously reported in our RecSys competition papers and in the T4Rec paper.
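For context, the logQ correction subtracts the log of each candidate's sampling probability from its logit before the softmax, which removes the bias against popular items that the sampler draws more often as negatives. A minimal sketch of the idea (illustrative names and values, not the PR's implementation):

import torch

def logq_correction(logits, sampling_probs, eps=1e-9):
    # Subtract log(Q) so that frequently sampled (popular) candidates
    # are not over-penalized by the sampled softmax loss.
    return logits - torch.log(sampling_probs + eps)

# Batch of 2 examples: column 0 is the positive, columns 1..3 are sampled negatives
logits = torch.randn(2, 4)
sampling_probs = torch.tensor([0.001, 0.2, 0.05, 0.01]).expand(2, 4)
targets = torch.zeros(2, dtype=torch.long)  # positive is always at index 0
loss = torch.nn.functional.cross_entropy(logq_correction(logits, sampling_probs), targets)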

Disclaimer: These experiments were not hypertuned for every configuration. Furthermore, accuracy results might vary considerably across runs, in particular when a smaller number of samples (e.g. 1k) is used.

@gabrielspmoreira gabrielspmoreira self-assigned this Apr 7, 2023
@gabrielspmoreira gabrielspmoreira added the enhancement New feature or request label Apr 7, 2023
@gabrielspmoreira gabrielspmoreira added this to the Merlin 23.04 milestone Apr 7, 2023

Contributor

@sararb sararb left a comment

The PR looks good to me. I just left some remarks/questions to understand the code base.

else:
    logits = self.output_layer(inputs)
    logits = inputs @ output_weights
Contributor

We might need to keep the bias parameter self.output_layer_bias: logits = inputs @ output_weights + self.output_layer_bias?

Member Author

@gabrielspmoreira gabrielspmoreira Apr 14, 2023

I removed the bias because I think it would not be available if ANN is used later for serving. Does that make sense?
I can run some benchmarks later to see whether the bias helps to improve accuracy.

Contributor

You're right, that makes sense! Otherwise, we'll need to save the output_bias vector in addition to the pre-trained candidate embeddings.
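In other words, the output scores reduce to a plain dot product between the sequence representation and the (tied) item embeddings, which is exactly what an ANN index over those embeddings can reproduce at serving time. A rough sketch (illustrative shapes, not the PR's code):

import torch

hidden = torch.randn(8, 320)               # sequence representations (batch, d_model)
item_embeddings = torch.randn(50000, 320)  # tied item id embedding table (n_items, d_model)

# Bias-free output layer: scores are pure dot products, so a pre-built ANN index
# over item_embeddings yields the same ranking at serving time.
logits = hidden @ item_embeddings.T        # (batch, n_items)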


return predictions
logits = torch.cat([positive_scores, negative_scores], axis=1)
new_targets = torch.zeros(logits.shape[0], dtype=torch.int64)
Contributor

The first element of each row should be 1 instead of 0 to account for the positive target, right?

Member Author

The targets are the sparse id representation (class indices), not the one-hot representation. Does that make sense?

Contributor

Oh I see, so new_targets is a 1-D vector that contains the index of the positive item in the logits tensor (which always corresponds to index 0).
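This matches how torch.nn.CrossEntropyLoss consumes class indices rather than one-hot vectors; a tiny illustration with made-up scores:

import torch

positive_scores = torch.tensor([[2.0], [1.5]])             # (batch, 1)
negative_scores = torch.tensor([[0.3, -0.1], [0.2, 0.9]])  # (batch, n_samples)

logits = torch.cat([positive_scores, negative_scores], dim=1)
# Class-index targets: the positive item always sits in column 0
new_targets = torch.zeros(logits.shape[0], dtype=torch.int64)
loss = torch.nn.functional.cross_entropy(logits, new_targets)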

def forward(self, inputs):
    return self.module(inputs)
def forward(self, inputs, **kwargs):
    return self.module(inputs, **kwargs)
Contributor

Can you explain why we need the extra **kwargs here?

" [`sum`, `none`, `mean`]"
)
return loss
return torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction=reduction, **kwargs)
Contributor

It's great to see that label_smoothing was added in the latest version of CrossEntropyLoss!
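For reference, label_smoothing has been a built-in argument of torch.nn.CrossEntropyLoss since PyTorch 1.10, so the alias can simply forward to it. A minimal example:

import torch

# Built-in label smoothing in recent PyTorch releases
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.2, reduction="mean")

logits = torch.randn(4, 10)           # (batch, n_classes)
targets = torch.randint(0, 10, (4,))  # class indices
loss = criterion(logits, targets)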

y = labels_all
x, y = self.pre(x, targets=y, training=training, testing=testing)  # type: ignore

loss = self.loss(x, y)
return {
    "loss": loss,
    "labels": labels_all,
Contributor

I understand that self.pre invokes the next-item task using the sampled softmax option, which returns logits x related to the list of [positive_item, sampled negatives]. So I wonder how these logits are connected to labels_all (which is a tensor of positive item ids) for metrics calculation.

    dist = self.unique_sampling_dist
else:
    dist = self.dist
dist = dist.to(device)
Contributor

Would it be possible to move the definition of dist to the class constructor, to avoid copying the tensor to the GPU/CPU device multiple times?

Member Author

The challenge is how to get the device in the constructor. Any ideas?

Contributor

@sararb sararb Apr 14, 2023

You can use register_buffer to register the variable dist. Then, the method model.to(device) will ensure that the buffer is copied to the right device. It is something like:

  • in the constructor, you set: self.register_buffer('dist', dist)
  • in the sampling method, you can just call the registered buffer self.dist (see the sketch below)
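A rough sketch of that pattern (a hypothetical module, not the PR's LogUniformSampler):

import torch

class FrequencySampler(torch.nn.Module):
    def __init__(self, dist: torch.Tensor):
        super().__init__()
        # Buffers follow model.to(device) and are saved in the state_dict,
        # unlike plain tensor attributes.
        self.register_buffer("dist", dist)

    def sample(self, n_samples: int) -> torch.Tensor:
        # self.dist is already on the right device here
        return torch.multinomial(self.dist, n_samples, replacement=True)

sampler = FrequencySampler(torch.ones(1000) / 1000)
sampler = sampler.to("cuda" if torch.cuda.is_available() else "cpu")
samples = sampler.sample(10)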

so we use `torch.multinomial(..., replacement=True).unique()` which doesn't guarantee
the same number of unique sampled items. You can try to increase
n_samples_multiplier_before_unique to increase the chances to have more
unique samples in that case.
Contributor

+1 !! Thank you for creating this class. It was very helpful for learning how to approximate item frequency distributions for both sampling with and without repetition!
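The pattern that docstring describes can be sketched roughly as follows (n_samples_multiplier_before_unique is the parameter quoted above; the helper itself is illustrative):

import torch

def sample_unique(dist: torch.Tensor, n_samples: int,
                  n_samples_multiplier_before_unique: int = 2) -> torch.Tensor:
    # Over-sample with replacement, then deduplicate. The number of unique
    # samples is not guaranteed; a larger multiplier increases the chances
    # of ending up with at least n_samples unique items.
    candidates = torch.multinomial(
        dist, n_samples * n_samples_multiplier_before_unique, replacement=True
    )
    return candidates.unique()[:n_samples]

dist = torch.tensor([0.5, 0.2, 0.1, 0.1, 0.05, 0.05])
print(sample_unique(dist, n_samples=3))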

… adjusted min const value used to fix sampling accidental hits to work properly with fp16. Ensures targets are torch.long, otherwise losses raise an error. Turning metrics top_ks as lists rather than tensors
…stributions as a buffer, so that they are automatically assigned to the right device and also serialized correctly
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Task] Add a softmax sampling
2 participants