Skip to content

Commit

Permalink
Sorting by length makes the parser about 10% faster, with most of the…
Browse files: browse the repository at this point in the history
… speedup coming from better batching of the input encoding. Will need to see if we can better batch the later operations as well.
  • Loading branch information
AngledLuffa committed Aug 23, 2024
1 parent d4b5447 commit d14a7fa
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from stanza.models.common import utils
from stanza.models.common.foundation_cache import FoundationCache, NoTransformerFoundationCache
from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss
from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.constituency import parse_transitions
from stanza.models.constituency import transition_sequence
from stanza.models.constituency import tree_reader
Expand Down Expand Up @@ -680,8 +681,10 @@ def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
num_generate = args.get('num_generate', 0)
keep_scores = num_generate > 0

tree_iterator = iter(tqdm(retagged_trees))
sorted_trees, original_indices = sort_with_indices(retagged_trees, key=len, reverse=True)
tree_iterator = iter(tqdm(sorted_trees))
treebank = model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.predict, keep_scores=keep_scores)
treebank = unsort(treebank, original_indices)
full_results = treebank

if num_generate > 0:
Expand Down
3 changes: 3 additions & 0 deletions stanza/pipeline/constituency_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from stanza.models.constituency.trainer import Trainer

from stanza.models.common import doc
from stanza.models.common.utils import sort_with_indices, unsort
from stanza.utils.get_tqdm import get_tqdm
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
Expand Down Expand Up @@ -54,10 +55,12 @@ def process(self, document):
words = [[(w.text, w.xpos) for w in s.words] for s in sentences]
else:
words = [[(w.text, w.upos) for w in s.words] for s in sentences]
words, original_indices = sort_with_indices(words, key=len, reverse=True)
if self._tqdm:
words = tqdm(words)

trees = self._model.parse_tagged_words(words, self._batch_size)
trees = unsort(trees, original_indices)
document.set(CONSTITUENCY, trees, to_sentence=True)
return document

Expand Down

0 comments on commit d14a7fa

Please sign in to comment.