embeddings: fix attention mask for special Transformer architectures #2485

Merged: 1 commit merged into master on Oct 29, 2021

Conversation

stefan-it (Member)

Hi,

Special architectures such as the recently introduced FNet do not provide an attention_mask for the model forward pass.

This causes the following error, which can be reproduced with the official NER example script:

$ cd examples/ner
$ python3 run_ner.py --dataset_name WNUT_17 --model_name_or_path google/fnet-base --output_dir wnut-fnet-base-baseline --num_epochs 1 

Outputs:

Traceback (most recent call last):
  File "run_ner.py", line 186, in <module>
    main()
  File "run_ner.py", line 169, in main
    trainer.fine_tune(data_args.output_dir,
  File "/mnt/flair-fnet-fix/flair/trainers/trainer.py", line 808, in fine_tune
    return self.train(
  File "/mnt/flair-fnet-fix/flair/trainers/trainer.py", line 467, in train
    loss = self.model.forward_loss(batch_step)
  File "/mnt/flair-fnet-fix/flair/models/sequence_tagger_model.py", line 394, in forward_loss
    features = self.forward(data_points)
  File "/mnt/flair-fnet-fix/flair/models/sequence_tagger_model.py", line 399, in forward
    self.embeddings.embed(sentences)
  File "/mnt/flair-fnet-fix/flair/embeddings/base.py", line 60, in embed
    self._add_embeddings_internal(sentences)
  File "/mnt/flair-fnet-fix/flair/embeddings/token.py", line 1015, in _add_embeddings_internal
    attention_mask = batch_encoding['attention_mask'].to(flair.device)
  File "/opt/conda/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 237, in __getitem__
    return self.data[item]
KeyError: 'attention_mask'
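
For context, the KeyError can be reproduced at the tokenizer level alone. Below is a minimal sketch, assuming (as with FNet) that the tokenizer's model_input_names do not include attention_mask, so the key is simply never added to the batch encoding; the printed values are illustrative:

from transformers import AutoTokenizer

# FNet does not use attention, so its tokenizer omits the attention mask.
fnet_tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
batch_encoding = fnet_tokenizer(["Flair is a NLP library"], return_tensors="pt")

print(fnet_tokenizer.model_input_names)    # e.g. ['input_ids', 'token_type_ids']
print("attention_mask" in batch_encoding)  # False

# Accessing the missing key reproduces the failure from the traceback above.
batch_encoding["attention_mask"]           # raises KeyError: 'attention_mask'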

This PR fixes the issue, so that FNet (and other potential architectures that do not have an attention_mask) can be used in Flair.
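
The core idea of the change can be sketched as follows; this is a minimal illustration rather than the exact diff, and build_model_inputs is a hypothetical helper name. The guard only forwards an attention_mask to the model when the tokenizer actually produced one:

import flair  # provides flair.device

def build_model_inputs(batch_encoding):
    # Hypothetical helper: BERT-style tokenizers emit an attention_mask,
    # FNet's tokenizer does not, so only pass the mask if it is present.
    inputs = {"input_ids": batch_encoding["input_ids"].to(flair.device)}
    if "attention_mask" in batch_encoding:
        inputs["attention_mask"] = batch_encoding["attention_mask"].to(flair.device)
    return inputs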

Note: I also compared the fine-tuning results before and after this PR with a DistilBERT model, to make sure that this does not introduce any regression. Tested with:

$ python3 run_ner.py --dataset_name WNUT_17 --model_name_or_path google/fnet-base --output_dir wnut-fnet-base-baseline --num_epochs 1

@alanakbik (Collaborator)

@stefan-it thanks for fixing this! Any good results with FNET? ;)

@alanakbik alanakbik merged commit a1bee91 into master Oct 29, 2021
@alanakbik alanakbik deleted the fnet-attention-mask-fix branch October 29, 2021 20:20