Prototypical networks in Flair #2627

Merged 26 commits on Feb 9, 2022

alanakbik
Collaborator

Thanks to the work of @plonerma, this PR adds a prototype decoder to Flair.

Prototype networks learn one prototype vector for each target class. For each data point to be classified, the network predicts a vector in the same space as the class prototypes, which is then compared to all class prototypes. The prediction is the class whose prototype is closest. See the paper "Prototypical Networks for Few-shot Learning" for more info.
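
To make the nearest-prototype idea concrete, here is a minimal pure-Python sketch. It is not the code from this PR: the function names, the toy 2-dimensional prototypes, and the frame labels are hypothetical, and turning negative distances into probabilities with a softmax follows the paper rather than anything confirmed about the Flair decoder.

```python
import math


def euclidean_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def classify(embedding, prototypes):
    """Return the label of the closest prototype and a softmax over negative distances.

    `prototypes` maps a class label to its prototype vector (hypothetical toy data here;
    in a trained model these vectors would be learned).
    """
    distances = {label: euclidean_distance(embedding, proto)
                 for label, proto in prototypes.items()}

    # A smaller distance means a higher score, so use negative distances as logits.
    logits = {label: -d for label, d in distances.items()}

    # Numerically stable softmax over the logits.
    max_logit = max(logits.values())
    exps = {label: math.exp(value - max_logit) for label, value in logits.items()}
    total = sum(exps.values())
    probs = {label: e / total for label, e in exps.items()}

    predicted = min(distances, key=distances.get)
    return predicted, probs


# Toy example: three class prototypes in a 2-dimensional space.
prototypes = {
    "run.01": [1.0, 0.0],
    "run.02": [0.0, 1.0],
    "eat.01": [-1.0, -1.0],
}

label, probs = classify([0.9, 0.2], prototypes)
print(label, probs)
```

In this sketch the predicted label is always the one with the highest probability, because the softmax preserves the ordering of the negative distances.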

@plonerma implemented a custom decoder that can be added to any Flair model that inherits from DefaultClassifier (i.e. nearly all Flair models). For instance, the following script trains a frame tagger that uses the prototype decoder:

from flair.data import Corpus
from flair.datasets import UP_ENGLISH
from flair.embeddings import TransformerWordEmbeddings
from flair.models import WordTagger
from flair.nn import PrototypicalDecoder
from flair.trainers import ModelTrainer

# what tag do we want to predict?
tag_type = 'frame'

# get a corpus
corpus: Corpus = UP_ENGLISH().downsample(0.1)

# make the tag dictionary from the corpus
tag_dictionary = corpus.make_label_dictionary(label_type=tag_type)

# initialize transformer word embeddings (fine-tuned during training)
embeddings = TransformerWordEmbeddings(model="distilbert-base-uncased",
                                       fine_tune=True,
                                       layers='-1')

# initialize prototype decoder
decoder = PrototypicalDecoder(num_prototypes=len(tag_dictionary),
                              embeddings_size=embeddings.embedding_length,
                              distance_function='euclidean',
                              normal_distributed_initial_prototypes=True,
                              )

# initialize the WordTagger, but pass the prototype decoder
tagger = WordTagger(embeddings,
                    tag_dictionary,
                    tag_type,
                    decoder=decoder)

# initialize trainer
trainer = ModelTrainer(tagger, corpus)

# run training
trainer.fine_tune('resources/taggers/prototypical_decoder')

Note that this feature is still in beta and will likely change considerably before the next Flair release.

@alanakbik alanakbik merged commit 6b6b2f0 into master Feb 9, 2022
@alanakbik alanakbik deleted the prototype_decoder_refactor branch February 9, 2022 20:38