Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variable-length sequences for mamba block with position indices #434

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

ptxu78
Copy link

@ptxu78 ptxu78 commented Jul 1, 2024

Enable the mamba block to support variable-length sequence inputs using positional encoding. Passing Positional Indices results in negligible performance loss for the mamba block. For common variable-length sequence distributions, performance can be improved by 4-6x.

For example, a packaged sequence of length 16 consists of four independent sub-sequences with lengths of 3, 5, 6, 2. Then corresponding:
Cumulative sequence: [0, 3, 8, 14, 16]
Position encoding: [0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1]

In the Mamba module, there are two steps that are not sequence-wise: conv1d and selective_scan. Sub-sequences within the same sequence can affect each other, and we have modified causal-conv1d and selective_scan. These two CUDA operators are implemented using position encoding to eliminate the mutual influence between sub-sequences.

  • How to use:
  1. Setup two cuda_operators:
git clone --branch feat/pack_with_position_indices https://github.com/ptxu78/pack_mamba.git
python setup .../pack_mamba/setup_onlyCUDA.py install
git clone --branch feat/pack_with_position_indices https://github.com/ptxu78/causal-conv1d-pack.git
python setup .../causal-conv1d-pack/setup_onlyCUDA.py install
  1. Pack the variable-length sequences and input them, along with the corresponding position encoding.

…ions on packed data without interference between token sequences.
…-end pack experiments with the mamba block. Added support for position_indices in conv1d within mamba_inner_fn. The conv1d code can be found at https://github.com/ptxu78/causal-conv1d-pack/tree/feat/pack_with_position_indices.
@ptxu78 ptxu78 closed this Jul 1, 2024
@ptxu78 ptxu78 reopened this Jul 4, 2024
@ScottHoang
Copy link

This PR attempts to resolve the issue derived from training with packed data?

@ptxu78
Copy link
Author

ptxu78 commented Jul 25, 2024

This PR attempts to resolve the issue derived from training with packed data?

Yes, this PR allow Mamba to handle packed data more effectively: it significantly increases throughput while ensuring the mathematical equivalence of the training results.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants