This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Always re-initialize the unknown token vector for npz distributed on S3 (#228)
leezu authored Jul 24, 2018
1 parent 369d6f7 commit 55cee6a
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions gluonnlp/embedding/token_embedding.py
@@ -341,19 +341,13 @@ def _load_embedding_serialized(self, pretrained_file_path):
  # is the same now as it was when the .npz was generated. Under this
  # assumption we can safely overwrite the respective token and
  # vector from the npz.
- if deserialized_embedding.unknown_token == self.unknown_token:
-     # If the unknown_token is the same, we will find it below and a
-     # new unknown token wont be inserted.
-     idx_to_token = deserialized_embedding.idx_to_token
-     idx_to_vec = deserialized_embedding.idx_to_vec
- elif self.unknown_token:
-     # If they are different, we need to manually replace it so that
-     # it is found below and no new unknown token would be inserted.
+ if deserialized_embedding.unknown_token:
      idx_to_token = deserialized_embedding.idx_to_token
      idx_to_vec = deserialized_embedding.idx_to_vec
      idx_to_token[C.UNK_IDX] = self.unknown_token
-     vec_len = idx_to_vec.shape[1]
-     idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
+     if self._init_unknown_vec:
+         vec_len = idx_to_vec.shape[1]
+         idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
  else:
      # If the TokenEmbedding shall not have an unknown token, we
      # just delete the one in the npz.
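
The new branch can be illustrated with a small standalone sketch. It is not the gluonnlp implementation: plain Python lists stand in for the token list and the vector array, and `UNK_IDX`, `zeros`, and `overwrite_unknown` are hypothetical names introduced only for this example. It shows the two effects of the change as read from the diff: the stored unknown token is always renamed, and its vector is re-initialized only when an initializer is supplied.

```python
# Hypothetical stand-in for C.UNK_IDX; assumes the unknown token sits at index 0.
UNK_IDX = 0


def zeros(shape):
    """Toy initializer: return a zero vector of the requested length."""
    return [0.0] * shape


def overwrite_unknown(idx_to_token, idx_to_vec, unknown_token, init_unknown_vec=None):
    """Rename the stored unknown token and optionally re-initialize its vector."""
    # Copy the inputs so the caller's data is left untouched.
    idx_to_token = list(idx_to_token)
    idx_to_vec = [list(vec) for vec in idx_to_vec]

    # The token at UNK_IDX is always replaced by the requested unknown token.
    idx_to_token[UNK_IDX] = unknown_token

    # The vector is re-initialized only if an initializer was provided.
    if init_unknown_vec:
        vec_len = len(idx_to_vec[UNK_IDX])
        idx_to_vec[UNK_IDX] = init_unknown_vec(shape=vec_len)
    return idx_to_token, idx_to_vec


stored_tokens = ["<unk>", "hello", "world"]
stored_vecs = [[0.5, 0.5], [1.0, 2.0], [3.0, 4.0]]

# With an initializer, the stored unknown vector is discarded.
tokens, vecs = overwrite_unknown(stored_tokens, stored_vecs, "<UNK>", zeros)

# Without one, the stored unknown vector is kept as-is.
kept_tokens, kept_vecs = overwrite_unknown(stored_tokens, stored_vecs, "<UNK>")
```

In this sketch the other rows are never modified, which mirrors the assumption stated in the code comment that the unknown token's index is the same now as when the .npz was generated.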