Fixes bug in SelectiveScanFn.forward for when B is not variable and last_state is returned #371

Open
wants to merge 1 commit into
base: main

vidavakil

test_selective_scan() fails when is_variable_B is False.

It turns out that selective_scan_fwd_kernel includes an optimization: it does not multiply the state by B when B is not variable. This does not affect MambaInnerFn, because MambaInnerFn never returns the state. SelectiveScanFn, however, may need to return last_state. This change fixes the problem by multiplying last_state by B before returning it when B is not variable.
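To illustrate why a post-hoc multiplication by B recovers the correct state, here is a minimal pure-Python sketch of a single scalar state channel. It is a simplified model under stated assumptions, not the kernel or the patched code: the function names (`reference_last_state`, `unscaled_last_state`, `fixed_last_state`) and the scalar recurrence are hypothetical and chosen only for illustration. The idea is that when B is constant over time, linearity lets it be factored out of the recurrence, so a state accumulated without B differs from the true state by exactly a factor of B.

```python
import math

def reference_last_state(decay, inputs, B):
    """Reference recurrence: h_t = decay_t * h_{t-1} + B * x_t."""
    h = 0.0
    for a, x in zip(decay, inputs):
        h = a * h + B * x
    return h

def unscaled_last_state(decay, inputs):
    """Same recurrence with the constant B left out: g_t = decay_t * g_{t-1} + x_t.

    This models (as an assumption) a scan that skips the multiplication
    by B when B does not vary over time.
    """
    g = 0.0
    for a, x in zip(decay, inputs):
        g = a * g + x
    return g

def fixed_last_state(decay, inputs, B):
    """Apply the missing factor afterwards: h_T = B * g_T."""
    return B * unscaled_last_state(decay, inputs)

decay = [0.9, 0.5, 0.8]
inputs = [1.0, 2.0, -1.0]
B = 3.0

# The rescaled state agrees with the reference; the unscaled one does not.
print(math.isclose(reference_last_state(decay, inputs, B),
                   fixed_last_state(decay, inputs, B)))
print(math.isclose(reference_last_state(decay, inputs, B),
                   unscaled_last_state(decay, inputs)))
```

In the real code the state has more dimensions, so the same correction would be a broadcast multiply rather than a scalar one.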

…function has to return

the last_state. # The CUDA kernel does a peculiar optimization of not multiplying the state
by B if B is not variable! This does not impact MambaInnerFn, because it never returns the
state. But SelectiveScanFn may need to return the last state! Hence the following is needed.