Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
RhuiDih committed Jun 28, 2024
1 parent db040f3 commit 5d3fa4a
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/instructlab/training/utils.py
Original file line number | Diff line number | Diff line change
@@ -147,15 +147,14 @@ def pad_collate_fn(batch):
     has_label = "labels" in batch[0]
 
     for idx in range(len(batch)):
-        input_ids += batch[idx]["input_ids"]
-        if has_label:
-            labels += [-100] + batch[idx]["labels"][1:]
-        else:
-            labels += [-100] + batch[idx]["input_ids"][1:]
+        input_ids.append(batch[idx]["input_ids"])
+        _label = batch[idx]["labels"] if has_label else batch[idx]["input_ids"].clone()
+        _label[0] = -100
+        labels.append(_label)
         position_ids += list(range(len(batch[idx]["input_ids"])))
 
-    input_ids = torch.cat([x['input_ids'] for x in batch]).unsqueeze(0)
-    labels = torch.cat([x['labels'] for x in batch]).unsqueeze(0)
+    input_ids = torch.cat(input_ids).unsqueeze(0)
+    labels = torch.cat(labels).unsqueeze(0)
     position_ids = torch.tensor(position_ids, dtype=torch.long).unsqueeze(0)
 
     num_loss_counted_tokens = (labels != -100).sum()
Expand Down

0 comments on commit 5d3fa4a

Please sign in to comment.