From 4128b9e54f622c1ea14a8109b167d64427b3c6ba Mon Sep 17 00:00:00 2001 From: Konstantin Date: Thu, 4 May 2023 10:43:29 +0000 Subject: [PATCH] [Bug-fix] Text classification config --- src/sparseml/transformers/text_classification.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sparseml/transformers/text_classification.py b/src/sparseml/transformers/text_classification.py index 9246d2cb9fc..f45a31f42c9 100644 --- a/src/sparseml/transformers/text_classification.py +++ b/src/sparseml/transformers/text_classification.py @@ -497,6 +497,7 @@ def compute_metrics(p: EvalPrediction): label_list=label_list, model=model, num_labels=num_labels, + config=config, ) id_to_label = {id_: label for label, id_ in label_to_id.items()} @@ -754,7 +755,7 @@ def _get_tokenized_and_preprocessed_raw_datasets( # Some models have set the order of the labels to use, so let's make sure # we do use it label_to_id = _get_label_to_id( - data_args, is_regression, label_list, model, num_labels + data_args, is_regression, label_list, model, num_labels, config=config ) if label_to_id is not None: @@ -842,15 +843,16 @@ def preprocess_function(examples): return tokenized_datasets, raw_datasets -def _get_label_to_id(data_args, is_regression, label_list, model, num_labels): +def _get_label_to_id(data_args, is_regression, label_list, model, num_labels, config): label_to_id = None + config = model.config if model else config if ( - model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + config.label2id != PretrainedConfig(num_labels=num_labels).label2id and data_args.task_name is not None and not is_regression ): # Some have all caps in their config, some don't. 
- label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + label_name_to_id = {k.lower(): v for k, v in config.label2id.items()} if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): label_to_id = { i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)