From dc9d7681343ea690f9e54d1f20c9d53f116b481b Mon Sep 17 00:00:00 2001 From: Rhui Dih Lee Date: Fri, 28 Jun 2024 16:19:11 +0800 Subject: [PATCH] fix --- src/instructlab/training/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 1e452c28..32939170 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -148,14 +148,13 @@ def pad_collate_fn(batch): for idx in range(len(batch)): input_ids += batch[idx]["input_ids"] - if has_label: - labels += [-100] + batch[idx]["labels"][1:] - else: - labels += [-100] + batch[idx]["input_ids"][1:] + _label = batch[idx]["labels"].clone() if has_label else batch[idx]["labels"] + _label[0] = -100 + labels += _label position_ids += list(range(len(batch[idx]["input_ids"]))) - input_ids = torch.cat([x['input_ids'] for x in batch]).unsqueeze(0) - labels = torch.cat([x['labels'] for x in batch]).unsqueeze(0) + input_ids = torch.cat(input_ids).unsqueeze(0) + labels = torch.cat(labels).unsqueeze(0) position_ids = torch.tensor(position_ids, dtype=torch.long).unsqueeze(0) num_loss_counted_tokens = (labels != -100).sum()