From 686e75459ba3c81332ae9f7969aefe63ca4209a8 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sat, 26 Aug 2023 11:58:07 -0700 Subject: [PATCH] Refactor the pretrain dependency calculation. Also, throw away the charlm / nocharlm / nopretrain particles - the mimic, craft, and genia models were not getting the right pretrain --- stanza/resources/prepare_resources.py | 52 +++++++++++++++------------ 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 3c6d69bb67..320f03ffe3 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -163,6 +163,23 @@ def get_con_dependencies(lang, package): return dependencies +def get_pretrain_package(lang, package, model_pretrains, default_pretrains): + pieces = package.split("_", 1) + if len(pieces) > 1: + if pieces[1] == 'nopretrain': + return None + package = pieces[0] + + if lang in no_pretrain_languages: + return None + elif lang in model_pretrains and package in model_pretrains[lang]: + return model_pretrains[lang][package] + elif lang in default_pretrains: + return default_pretrains[lang] + + raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package)) + + def get_pos_charlm_package(lang, package): pieces = package.split("_", 1) if len(pieces) > 1: @@ -176,17 +193,13 @@ def get_pos_charlm_package(lang, package): return default_charlms.get(lang, None) def get_pos_dependencies(lang, package): - if lang in no_pretrain_languages: - dependencies = [] - elif lang in pos_pretrains and package in pos_pretrains[lang]: - dependencies = [{'model': 'pretrain', 'package': pos_pretrains[lang][package]}] - elif lang in default_pretrains: - dependencies = [{'model': 'pretrain', 'package': default_pretrains[lang]}] - else: - raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package)) + dependencies = [] - charlm_package = get_pos_charlm_package(lang, package) + pretrain_package = get_pretrain_package(lang, package, pos_pretrains, default_pretrains) + if pretrain_package is not None: + dependencies.append({'model': 'pretrain', 'package': pretrain_package}) + charlm_package = get_pos_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) @@ -232,17 +245,13 @@ def get_depparse_charlm_package(lang, package): return default_charlms.get(lang, None) def get_depparse_dependencies(lang, package): - if lang in no_pretrain_languages: - dependencies = [] - elif lang in depparse_pretrains and package in depparse_pretrains[lang]: - dependencies = [{'model': 'pretrain', 'package': depparse_pretrains[lang][package]}] - elif lang in default_pretrains: - dependencies = [{'model': 'pretrain', 'package': default_pretrains[lang]}] - else: - raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package)) + dependencies = [] - charlm_package = get_depparse_charlm_package(lang, package) + pretrain_package = get_pretrain_package(lang, package, depparse_pretrains, default_pretrains) + if pretrain_package is not None: + dependencies.append({'model': 'pretrain', 'package': pretrain_package}) + charlm_package = get_depparse_charlm_package(lang, package) if charlm_package is not None: dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) dependencies.append({'model': 'backward_charlm', 'package': charlm_package}) @@ -252,12 +261,9 @@ def get_depparse_dependencies(lang, package): def get_ner_dependencies(lang, package): dependencies = [] - if lang in ner_pretrains and package in ner_pretrains[lang]: - pretrain_package = ner_pretrains[lang][package] - else: - pretrain_package = default_pretrains[lang] + pretrain_package = get_pretrain_package(lang, package, ner_pretrains, default_pretrains) if pretrain_package is not None: - dependencies = [{'model': 'pretrain', 'package': pretrain_package}] + dependencies.append({'model': 'pretrain', 'package': pretrain_package}) if lang not in ner_charlms or package not in ner_charlms[lang]: charlm_package = default_charlms.get(lang, None)