Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow MPT models to return attention weights #599

Merged
merged 15 commits into from
Sep 21, 2023
Merged

Conversation

lorabit110
Copy link
Contributor

Previously, output_attentions is not propagated to the attention layer. Even when it's set to True, no attention weights were returned.

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a simple unit test for this?

@vchiley
Copy link
Contributor

vchiley commented Sep 15, 2023

This error should be noted if you plan to use it.

@lorabit110
Copy link
Contributor Author

Could you please add a simple unit test for this?

Updated an existing unit test to check attention.

@lorabit110
Copy link
Contributor Author

This error should be noted if you plan to use it.

Yeah. I am aware of that flash attention won't return attention weights.

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One comment to make sure we're returning the right shape of stuff, but otherwise lgtm! Thanks for the PR!

tests/test_model.py Outdated Show resolved Hide resolved
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
@dakinggg dakinggg enabled auto-merge (squash) September 16, 2023 00:20
tests/test_model.py Outdated Show resolved Hide resolved
@dakinggg dakinggg enabled auto-merge (squash) September 16, 2023 05:37
@dakinggg
Copy link
Collaborator

@lorabit110 could you run precommit run --all-files locally? Thanks!

tests/test_model.py Outdated Show resolved Hide resolved
@dakinggg dakinggg merged commit 0be2ca8 into mosaicml:main Sep 21, 2023
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants