diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index b352cc55e..d71fbc6bc 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -55,6 +55,7 @@ def get_cu_seqlens(attn_mask): return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) +@torch.jit.script def get_cu_seqlens_from_pos_ids(position_ids): """generate a cumulative sequence length mask for flash attention using pos ids""" if len(position_ids.shape) == 1: @@ -81,7 +82,7 @@ def get_cu_seqlens_from_pos_ids(position_ids): # Get the indices where the sequence starts start_indices = torch.cat( [ - (seq_starts).nonzero(as_tuple=True)[0], + torch.nonzero(seq_starts).unbind(dim=1)[0], torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device), ] )