Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RhuiDih committed Jun 28, 2024
1 parent db040f3 commit 2868eab
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,13 @@ def pad_collate_fn(batch):

for idx in range(len(batch)):
input_ids += batch[idx]["input_ids"]
if has_label:
labels += [-100] + batch[idx]["labels"][1:]
else:
labels += [-100] + batch[idx]["input_ids"][1:]
_label = batch[idx]["labels"] if has_label else batch[idx]["input_ids"].clone()
_label[0] = -100
labels += _label
position_ids += list(range(len(batch[idx]["input_ids"])))

input_ids = torch.cat([x['input_ids'] for x in batch]).unsqueeze(0)
labels = torch.cat([x['labels'] for x in batch]).unsqueeze(0)
input_ids = torch.cat(input_ids).unsqueeze(0)
labels = torch.cat(labels).unsqueeze(0)
position_ids = torch.tensor(position_ids, dtype=torch.long).unsqueeze(0)

num_loss_counted_tokens = (labels != -100).sum()
Expand Down

0 comments on commit 2868eab

Please sign in to comment.