Add the pretrained charlm to the lemmatizer. (Currently just the forward one.)

The model uses the charlm as an additional embedding in the underlying seq2seq model.

This connects the model to the pipeline and the training process. The training script and the resources generation need to be updated as well.
AngledLuffa committed Aug 8, 2023
1 parent 88a0499 commit 58a29e8
Showing 6 changed files with 113 additions and 30 deletions.
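To make the commit message concrete: "additional embedding" means the frozen charlm states are concatenated, character by character, onto the encoder's character embeddings, so the encoder LSTM sees a wider input vector. A toy shape sketch (all sizes are illustrative, not taken from this commit):

import torch

batch_size, word_len = 2, 7        # illustrative sizes only
emb_dim, charlm_dim = 100, 1024    # character embedding size / charlm hidden size

char_embs = torch.zeros(batch_size, word_len, emb_dim)         # trained character embeddings
charlm_states = torch.zeros(batch_size, word_len, charlm_dim)  # frozen charlm outputs, one per character

# the "additional embedding": concatenate along the feature dimension,
# so the encoder now receives emb_dim + charlm_dim features per character
enc_inputs = torch.cat([char_embs, charlm_states], dim=2)
assert enc_inputs.shape == (batch_size, word_len, emb_dim + charlm_dim)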
43 changes: 40 additions & 3 deletions stanza/models/common/char_model.py
@@ -115,6 +115,9 @@ def build_charlm_vocab(path, cutoff=0):
vocab = CharVocab(data) # skip cutoff argument because this has been dealt with
return vocab

CHARLM_START = "\n"
CHARLM_END = " "

class CharacterLanguageModel(nn.Module):

def __init__(self, args, vocab, pad=False, is_forward_lm=True):
@@ -162,13 +165,25 @@ def get_representation(self, chars, charoffsets, charlens, char_orig_idx):
res = pad_packed_sequence(res, batch_first=True)[0]
return res

def per_char_representation(self, words):
device = next(self.parameters()).device
vocab = self.char_vocab()

all_data = [(vocab.map(word), len(word), idx) for idx, word in enumerate(words)]
all_data.sort(key=itemgetter(1), reverse=True)
chars = [x[0] for x in all_data]
char_lens = [x[1] for x in all_data]
char_tensor = get_long_tensor(chars, len(chars), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)
with torch.no_grad():
output, _, _ = self.forward(char_tensor, char_lens)
output = [x[:y, :] for x, y in zip(output, char_lens)]
output = unsort(output, [x[2] for x in all_data])
return output

def build_char_representation(self, sentences):
"""
Return values from this charlm for a list of list of words
"""
CHARLM_START = "\n"
CHARLM_END = " "

forward = self.is_forward_lm
vocab = self.char_vocab()
device = next(self.parameters()).device
@@ -191,6 +206,7 @@ def build_char_representation(self, sentences):

all_data.sort(key=itemgetter(2), reverse=True)
chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
# TODO: can this be faster?
chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)

with torch.no_grad():
@@ -250,6 +266,27 @@ def load(cls, filename, finetune=False):
return cls.from_full_state(state, finetune)
return cls.from_full_state(state['model'], finetune)

class CharacterLanguageModelWordAdapter(nn.Module):
"""
Adapts a character model to return embeddings for each character in a word
TODO: multiple charlms, eg, forward & back
"""
def __init__(self, charlm):
super().__init__()
self.charlm = charlm

def forward(self, words):
words = [CHARLM_START + x + CHARLM_END for x in words]
rep = self.charlm.per_char_representation(words)
padded_rep = torch.zeros(len(rep), max(x.shape[0] for x in rep), rep[0].shape[1], dtype=rep[0].dtype, device=rep[0].device)
for idx, row in enumerate(rep):
padded_rep[idx, :row.shape[0], :] = row
return padded_rep

def hidden_dim(self):
return self.charlm.hidden_dim()

class CharacterLanguageModelTrainer():
def __init__(self, model, params, optimizer, criterion, scheduler, epoch=1, global_step=0):
self.model = model
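A hedged usage sketch of the new per-word charlm interface added above, assuming a pretrained forward charlm is available on disk (the path is a placeholder). per_char_representation returns one (word_length, hidden) tensor per word, while CharacterLanguageModelWordAdapter wraps each word in CHARLM_START/CHARLM_END and zero-pads the results into a single batch tensor for the seq2seq encoder.

from stanza.models.common.foundation_cache import load_charlm
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter

charlm = load_charlm("saved_models/charlm/en_forward.pt")  # placeholder path

words = ["unbelievable", "cats"]

# raw per-character representations: a list of (len(word), hidden_dim) tensors
reps = charlm.per_char_representation(words)
print([r.shape for r in reps])

# the adapter wraps each word in CHARLM_START/CHARLM_END, runs the charlm,
# and zero-pads everything into one (num_words, max_len + 2, hidden_dim) tensor
adapter = CharacterLanguageModelWordAdapter(charlm)
padded = adapter(words)
print(padded.shape, adapter.hidden_dim())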
30 changes: 20 additions & 10 deletions stanza/models/common/seq2seq_model.py
@@ -19,7 +19,7 @@ class Seq2SeqModel(nn.Module):
"""
A complete encoder-decoder model, with optional attention.
"""
def __init__(self, args, emb_matrix=None):
def __init__(self, args, emb_matrix=None, contextual_embedding=None):
super().__init__()
self.vocab_size = args['vocab_size']
self.emb_dim = args['emb_dim']
@@ -32,6 +32,7 @@ def __init__(self, args, emb_matrix=None):
self.top = args.get('top', 1e10)
self.args = args
self.emb_matrix = emb_matrix
self.contextual_embedding = contextual_embedding

logger.debug("Building an attentional Seq2Seq model...")
logger.debug("Using a Bi-LSTM encoder")
@@ -50,7 +51,10 @@ def __init__(self, args, emb_matrix=None):
self.emb_drop = nn.Dropout(self.emb_dropout)
self.drop = nn.Dropout(self.dropout)
self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
self.encoder = nn.LSTM(self.emb_dim, self.enc_hidden_dim, self.nlayers, \
self.input_dim = self.emb_dim
if self.contextual_embedding is not None:
self.input_dim += self.contextual_embedding.hidden_dim()
self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)
self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \
batch_first=True, attn_type=self.args['attn_type'])
@@ -158,7 +162,7 @@ def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None):

return log_probs, dec_hidden

def embed(self, src, src_mask, pos):
def embed(self, src, src_mask, pos, raw):
enc_inputs = self.emb_drop(self.embedding(src))
batch_size = enc_inputs.size(0)
if self.use_pos:
@@ -167,12 +171,18 @@ def embed(self, src, src_mask, pos):
enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
pos_src_mask = src_mask.new_zeros([batch_size, 1])
src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
if raw is not None and self.contextual_embedding is not None:
raw_inputs = self.contextual_embedding(raw)
if self.use_pos:
raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2]))
raw_inputs = torch.cat([raw_zeros, raw_inputs], dim=1)
enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)
src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
return enc_inputs, batch_size, src_lens, src_mask

def forward(self, src, src_mask, tgt_in, pos=None):
def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
# prepare for encoder/decoder
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos)
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

# encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
@@ -194,9 +204,9 @@ def get_log_prob(self, logits):
return log_probs
return log_probs.view(logits.size(0), logits.size(1), logits.size(2))

def predict_greedy(self, src, src_mask, pos=None):
def predict_greedy(self, src, src_mask, pos=None, raw=None):
""" Predict with greedy decoding. """
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos)
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

# encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
@@ -231,12 +241,12 @@ def predict_greedy(self, src, src_mask, pos=None):
output_seqs[i].append(token)
return output_seqs, edit_logits

def predict(self, src, src_mask, pos=None, beam_size=5):
def predict(self, src, src_mask, pos=None, beam_size=5, raw=None):
""" Predict with beam search. """
if beam_size == 1:
return self.predict_greedy(src, src_mask, pos=pos)
return self.predict_greedy(src, src_mask, pos, raw)

enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos)
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

# (1) encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
9 changes: 5 additions & 4 deletions stanza/models/lemma/data.py
@@ -77,7 +77,7 @@ def preprocess(self, data, char_vocab, pos_vocab, args):
tgt = list(d[2])
tgt_in = char_vocab.map([constant.SOS] + tgt)
tgt_out = char_vocab.map(tgt + [constant.EOS])
processed += [[src, tgt_in, tgt_out, pos, edit_type]]
processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]]
return processed

def __len__(self):
@@ -92,7 +92,7 @@ def __getitem__(self, key):
batch = self.data[key]
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 5
assert len(batch) == 6

# sort all fields by lens for easy RNN operations
lens = [len(x) for x in batch[0]]
@@ -106,8 +106,9 @@ def __getitem__(self, key):
tgt_out = get_long_tensor(batch[2], batch_size)
pos = torch.LongTensor(batch[3])
edits = torch.LongTensor(batch[4])
text = batch[5]
assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match."
return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx
return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text

def __iter__(self):
for i in range(self.__len__()):
@@ -124,4 +125,4 @@ def resolve_none(self, data):
for feat_idx in range(len(data[tok_idx])):
if data[tok_idx][feat_idx] is None:
data[tok_idx][feat_idx] = '_'
return data
return data
52 changes: 41 additions & 11 deletions stanza/models/lemma/trainer.py
@@ -12,7 +12,9 @@
import torch.nn.init as init

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.foundation_cache import load_charlm
from stanza.models.common.seq2seq_model import Seq2SeqModel
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
from stanza.models.common import utils, loss
from stanza.models.lemma import edit
from stanza.models.lemma.vocab import MultiVocab
@@ -23,18 +25,24 @@ def unpack_batch(batch, device):
""" Unpack a batch from the data loader. """
inputs = [b.to(device) if b is not None else None for b in batch[:6]]
orig_idx = batch[6]
return inputs, orig_idx
text = batch[7]
return inputs, orig_idx, text

class Trainer(object):
""" A trainer for training models. """
def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None):
def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None):
self.unsaved_modules = []
if model_file is not None:
# load everything from file
self.load(model_file)
self.load(model_file, args, foundation_cache)
else:
# build model from scratch
self.args = args
self.model = None if args['dict_only'] else Seq2SeqModel(args, emb_matrix=emb_matrix)
if args['dict_only']:
self.model = None
else:
self.model, charmodel = self.build_seq2seq(args, emb_matrix, foundation_cache)
self.add_unsaved_module("charmodel", charmodel)
self.vocab = vocab
# dict-based components
self.word_dict = dict()
@@ -48,17 +56,29 @@ def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, devi
self.crit = loss.SequenceLoss(self.vocab['char'].size).to(device)
self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])

def build_seq2seq(self, args, emb_matrix, foundation_cache):
charmodel = None
if args is not None and args.get('charlm_forward_file', None):
charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)
charmodel = CharacterLanguageModelWordAdapter(charmodel_forward)
model = Seq2SeqModel(args, emb_matrix=emb_matrix, contextual_embedding=charmodel)
return model, charmodel

def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)

def update(self, batch, eval=False):
device = next(self.model.parameters()).device
inputs, orig_idx = unpack_batch(batch, device)
inputs, orig_idx, text = unpack_batch(batch, device)
src, src_mask, tgt_in, tgt_out, pos, edits = inputs

if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos)
log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos, raw=text)
if self.args.get('edit', False):
assert edit_logits is not None
loss = self.crit(log_probs.view(-1, self.vocab['char'].size), tgt_out.view(-1), \
@@ -76,12 +96,12 @@ def update(self, batch, eval=False):

def predict(self, batch, beam_size=1):
device = next(self.model.parameters()).device
inputs, orig_idx = unpack_batch(batch, device)
inputs, orig_idx, text = unpack_batch(batch, device)
src, src_mask, tgt, tgt_mask, pos, edits = inputs

self.model.eval()
batch_size = src.size(0)
preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size)
preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size, raw=text)
pred_seqs = [self.vocab['char'].unmap(ids) for ids in preds] # unmap to tokens
pred_seqs = utils.prune_decoded_seqs(pred_seqs)
pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens
@@ -182,7 +202,13 @@ def ensemble(self, pairs, other_preds):
lemmas.append(lemma)
return lemmas

def save(self, filename):
def save(self, filename, skip_modules=True):
model_state = self.model.state_dict()
# skip saving modules like the pretrained charlm
if skip_modules:
skipped = [k for k in model_state.keys() if k.split('.')[0] in self.unsaved_modules]
for k in skipped:
del model_state[k]
params = {
'model': model_state if self.model is not None else None,
'dicts': (self.word_dict, self.composite_dict),
@@ -193,16 +219,20 @@ def save(self, filename):
torch.save(params, filename, _use_new_zipfile_serialization=False)
logger.info("Model saved to {}".format(filename))

def load(self, filename):
def load(self, filename, args, foundation_cache):
try:
checkpoint = torch.load(filename, lambda storage, loc: storage)
except BaseException:
logger.error("Cannot load model from {}".format(filename))
raise
self.args = checkpoint['config']
if args is not None:
self.args['charlm_forward_file'] = args['charlm_forward_file']
self.args['charlm_backward_file'] = args['charlm_backward_file']
self.word_dict, self.composite_dict = checkpoint['dicts']
if not self.args['dict_only']:
self.model = Seq2SeqModel(self.args)
self.model, charmodel = self.build_seq2seq(self.args, None, foundation_cache)
self.add_unsaved_module("charmodel", charmodel)
# could remove strict=False after rebuilding all models,
# or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
self.model.load_state_dict(checkpoint['model'], strict=False)
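The effect of the new skip_modules flag above is that the large, frozen charlm parameters never land in the lemmatizer checkpoint; they are re-attached from the charlm file at load time, which is also why load_state_dict is called with strict=False. A toy illustration of the key filtering (the state-dict keys below are made up):

# made-up keys, showing how anything under an "unsaved" module prefix is dropped
model_state = {
    "charmodel.charlm.charlstm.weight_ih_l0": "large frozen tensor",
    "embedding.weight": "trained tensor",
    "encoder.weight_ih_l0": "trained tensor",
}
unsaved_modules = ["charmodel"]

skipped = [k for k in model_state if k.split('.')[0] in unsaved_modules]
for k in skipped:
    del model_state[k]

print(sorted(model_state))  # ['embedding.weight', 'encoder.weight_ih_l0']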
5 changes: 4 additions & 1 deletion stanza/models/lemmatizer.py
@@ -61,6 +61,9 @@ def build_argparse():
parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.')
parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. By default copy mechanism is used to improve generalization.')

parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")

parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
@@ -236,7 +239,7 @@ def evaluate(args):
model_file = build_model_filename(args)

# load model
trainer = Trainer(model_file=model_file, device=args['device'])
trainer = Trainer(model_file=model_file, device=args['device'], args=args)
loaded_args, vocab = trainer.args, trainer.vocab

for k in args:
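A quick sanity check that the new flags are exposed by the training script's parser, assuming the script's other flags all have defaults (the charlm path is a placeholder):

from stanza.models.lemmatizer import build_argparse

parser = build_argparse()
args = parser.parse_args(["--charlm_forward_file", "saved_models/charlm/en_forward.pt"])
print(args.charlm_forward_file)   # saved_models/charlm/en_forward.pt
print(args.charlm_backward_file)  # None -- the backward charlm is not wired in yet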
4 changes: 3 additions & 1 deletion stanza/pipeline/lemma_processor.py
@@ -50,7 +50,9 @@ def _set_up_model(self, config, pipeline, device):
# we make this an option, not the default
self.store_results = config.get('store_results', False)
self._use_identity = False
self._trainer = Trainer(model_file=config['model_path'], device=device)
args = {'charlm_forward_file': config.get('forward_charlm_path', None),
'charlm_backward_file': config.get('backward_charlm_path', None)}
self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)

def _set_up_requires(self):
self._pretagged = self._config.get('pretagged', None)
