diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 1e452c28..32939170 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -148,14 +148,13 @@ def pad_collate_fn(batch): for idx in range(len(batch)): input_ids += batch[idx]["input_ids"] - if has_label: - labels += [-100] + batch[idx]["labels"][1:] - else: - labels += [-100] + batch[idx]["input_ids"][1:] + _label = batch[idx]["labels"].clone() if has_label else batch[idx]["labels"] + _label[0] = -100 + labels += _label position_ids += list(range(len(batch[idx]["input_ids"]))) - input_ids = torch.cat([x['input_ids'] for x in batch]).unsqueeze(0) - labels = torch.cat([x['labels'] for x in batch]).unsqueeze(0) + input_ids = torch.cat(input_ids).unsqueeze(0) + labels = torch.cat(labels).unsqueeze(0) position_ids = torch.tensor(position_ids, dtype=torch.long).unsqueeze(0) num_loss_counted_tokens = (labels != -100).sum()