Skip to content

Commit

Permalink
Add a filter for just the models that have transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Aug 26, 2023
1 parent 3c8e44d commit 186011c
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions stanza/utils/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def build_argparse(sub_argparse=None):
parser.add_argument('--save_name', type=str, default=None, help="Base name for saving models. If set, will override the model's default.")

parser.add_argument('--charlm_only', action='store_true', default=False, help='When asking for ud_all, filter the ones which have charlms')
parser.add_argument('--transformer_only', action='store_true', default=False, help='When asking for ud_all, filter the ones for languages where we have transformers')

parser.add_argument('--force', dest='force', action='store_true', default=False, help='Retrain existing models')
return parser
Expand Down Expand Up @@ -120,6 +121,9 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
logger.info("Filtering ud_all treebanks to only those which can use charlm for this model")
ud_treebanks = [x for x in ud_treebanks
if choose_charlm_method(*treebank_to_short_name(x).split("_", 1), 'default') is not None]
if command_args.transformer_only:
logger.info("Filtering ud_all treebanks to only those which can use a transformer for this model")
ud_treebanks = [x for x in ud_treebanks if treebank_to_short_name(x).split("_")[0] in TRANSFORMERS]
logger.info("Expanding %s to %s", treebank, " ".join(ud_treebanks))
treebanks.extend(ud_treebanks)
else:
Expand Down

0 comments on commit 186011c

Please sign in to comment.