Skip to content

Commit

Permalink
Added two fixes: omitting the labels/indices matching for student/tea…
Browse files Browse the repository at this point in the history
…cher if teacher is a string; prioritizing teacher labels over student labels if the teacher labels are strings and the student labels are ints
  • Loading branch information
dbogunowicz committed Mar 3, 2023
1 parent f7508fe commit 8981bdc
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/sparseml/transformers/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def main(**kwargs):
},
)

if teacher:
if teacher and not isinstance(teacher, str):
# check whether teacher and student have the corresponding outputs
label_to_id, label_list = _check_teacher_student_outputs(teacher, label_to_id)

Expand Down Expand Up @@ -597,7 +597,21 @@ def _check_teacher_student_outputs(
f"student labels {student_labels}. Ignore this warning "
"if this is expected behavior."
)
label_list = student_labels
is_teacher_label_str = isinstance(teacher_labels[0], str)
is_student_label_str = isinstance(student_labels[0], str)

if is_teacher_label_str and not is_student_label_str:
# If the teacher labels are strings and the student labels are ints,
# we will assume that the teacher labels are the correct labels.
label_list = teacher_labels
models_labels_to_use = "teacher"
else:
label_list = student_labels
models_labels_to_use = "student"
_LOGGER.warning(
"From this point forward, the "
f"{models_labels_to_use}'s labels will be used."
)
else:
if student_ids != teacher_ids:
_LOGGER.warning(
Expand Down

0 comments on commit 8981bdc

Please sign in to comment.