Skip to content

Commit

Permalink
Fix varlen generation by passing seq_idx to causal_conv1d
Browse files (browse the repository at this point in the history)
  • Loading branch information
tridao committed Jul 3, 2024
1 parent ddce0c1 commit 8ffd905
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mamba_ssm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-__version__ = "2.2.1"
+__version__ = "2.2.2"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
Expand Down
2 changes: 2 additions & 0 deletions mamba_ssm/modules/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
conv_state.copy_(conv_varlen_states)
assert self.activation in ["silu", "swish"]
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
+assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
xBC = self.act(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
Expand All @@ -235,6 +236,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
+seq_idx=seq_idx,
).transpose(1, 2)
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
y = mamba_chunk_scan_combined(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ def test_generation_varlen():
sequences.append(sampled_tokens)
out_varlen = torch.cat(scores, dim=1)
print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
-assert (out_varlen - out_ref).abs().max() < 5 * (out_loop - out_ref).abs().max()
+assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()

0 comments on commit 8ffd905

Please sign in to comment.