Skip to content

Commit

Permalink
[Fix] label_list not being set for NLP token classification training …
Browse files Browse the repository at this point in the history
…if distillation teacher and student labels do not match (#1414) (#1415)

* [Fix] Fix label_list not being set for NLP token classification training if distillation teacher and student labels do not match

* Added two fixes: skipping the label/index matching between student and teacher if the teacher is a string; prioritizing teacher labels over student labels if the teacher's labels are strings and the student's are ints

* revert the previous int label patch - allow int labels so that the given dataset is the source of truth

* only override label_list when the teacher and student label sets are equal

---------

Co-authored-by: Mark Kurtz <mark.kurtz@neuralmagic.com>
Co-authored-by: Damian <damian@neuralmagic.com>
  • Loading branch information
3 people committed Mar 4, 2023
1 parent 2d0a6a9 commit a2839ce
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/sparseml/transformers/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,11 @@ def main(**kwargs):
},
)

if teacher:
if teacher and not isinstance(teacher, str):
# check whether teacher and student have the corresponding outputs
label_to_id, label_list = _check_teacher_student_outputs(teacher, label_to_id)
label_to_id, label_list = _check_teacher_student_outputs(
teacher, label_to_id, label_list
)

tokenizer_src = (
model_args.tokenizer_name
Expand Down Expand Up @@ -580,7 +582,7 @@ def compute_metrics(p):


def _check_teacher_student_outputs(
teacher: Module, label_to_id: Dict[str, int]
teacher: Module, label_to_id: Dict[str, int], label_list: List[str]
) -> Tuple[Dict[str, int], List[str]]:
# Check that the teacher and student have the same labels and if they do,
# check that the mapping between labels and ids is the same.
Expand Down Expand Up @@ -765,7 +767,9 @@ def _get_tokenized_dataset(
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
for idx, label in enumerate(label_list):
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
if isinstance(label, str) and (
label.startswith("B-") and label.replace("B-", "I-") in label_list
):
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
else:
b_to_i_label.append(idx)
Expand Down

0 comments on commit a2839ce

Please sign in to comment.