Skip to content

Commit

Permalink
replace concatenates with cat (#1781)
Browse files Browse the repository at this point in the history
  • Loading branch information
Satrat committed Oct 19, 2023
1 parent cae298c commit 05a300f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/sparseml/modifiers/obcq/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def ppl_eval_general(

vocabulary_size = logits[0].shape[-1]
logits = [logit[:, :-1, :].view(-1, vocabulary_size) for logit in logits]
logits = torch.concatenate(logits, dim=0).contiguous().to(torch.float32)
logits = torch.cat(logits, dim=0).contiguous().to(torch.float32)

labels = [sample[:, 1:].view(-1) for sample in samples]
labels = torch.concatenate(labels, dim=0).to(dev)
labels = torch.cat(labels, dim=0).to(dev)
neg_log_likelihood += torch.nn.functional.cross_entropy(
logits,
labels,
Expand Down
2 changes: 1 addition & 1 deletion src/sparseml/transformers/data/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _add_end_token(self, tokenized_sample):
if len(tokenized_sample) == self._seqlen:
tokenized_sample[-1] = self.tokenizer.eos_token_id
else:
tokenized_sample = torch.concatenate(
tokenized_sample = torch.cat(
(
tokenized_sample,
torch.tensor((self.tokenizer.eos_token_id,)),
Expand Down

0 comments on commit 05a300f

Please sign in to comment.