From 5d3fa4a118ed0d883d218417a7b916628cc4358c Mon Sep 17 00:00:00 2001 From: Rhui Dih Lee Date: Fri, 28 Jun 2024 16:19:11 +0800 Subject: [PATCH] fix --- src/instructlab/training/utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 1e452c28..d2a87438 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -147,15 +147,14 @@ def pad_collate_fn(batch): has_label = "labels" in batch[0] for idx in range(len(batch)): - input_ids += batch[idx]["input_ids"] - if has_label: - labels += [-100] + batch[idx]["labels"][1:] - else: - labels += [-100] + batch[idx]["input_ids"][1:] + input_ids.append(batch[idx]["input_ids"]) + _label = batch[idx]["labels"] if has_label else batch[idx]["input_ids"].clone() + _label[0] = -100 + labels.append(_label) position_ids += list(range(len(batch[idx]["input_ids"]))) - input_ids = torch.cat([x['input_ids'] for x in batch]).unsqueeze(0) - labels = torch.cat([x['labels'] for x in batch]).unsqueeze(0) + input_ids = torch.cat(input_ids).unsqueeze(0) + labels = torch.cat(labels).unsqueeze(0) position_ids = torch.tensor(position_ids, dtype=torch.long).unsqueeze(0) num_loss_counted_tokens = (labels != -100).sum()