Skip to content

Commit

Permalink
[Bug-fix] Text classifiation config (#1545)
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin committed May 5, 2023
1 parent d823867 commit 641c830
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/sparseml/transformers/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def compute_metrics(p: EvalPrediction):
label_list=label_list,
model=model,
num_labels=num_labels,
config=config,
)
id_to_label = {id_: label for label, id_ in label_to_id.items()}

Expand Down Expand Up @@ -754,7 +755,7 @@ def _get_tokenized_and_preprocessed_raw_datasets(
# Some models have set the order of the labels to use, so let's make sure
# we do use it
label_to_id = _get_label_to_id(
data_args, is_regression, label_list, model, num_labels
data_args, is_regression, label_list, model, num_labels, config=config
)

if label_to_id is not None:
Expand Down Expand Up @@ -842,15 +843,16 @@ def preprocess_function(examples):
return tokenized_datasets, raw_datasets


def _get_label_to_id(data_args, is_regression, label_list, model, num_labels):
def _get_label_to_id(data_args, is_regression, label_list, model, num_labels, config):
label_to_id = None
config = model.config if model else config
if (
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
config.label2id != PretrainedConfig(num_labels=num_labels).label2id
and data_args.task_name is not None
and not is_regression
):
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
label_name_to_id = {k.lower(): v for k, v in config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
label_to_id = {
i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)
Expand Down

0 comments on commit 641c830

Please sign in to comment.