
Torch2 (#177) #178

Merged
merged 7 commits into from
May 19, 2023

Conversation

@vchiley (Contributor) commented May 19, 2023

Moves #149 into the main repo (from a fork).

Uses #147 as a springboard to update torch.

In an interactive instance, I installed the torch2 requirements and everything works fine.

A 125M model was getting good (and identical) MFU from the exact same config in both torch1.13 and torch2.
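For reference, MFU can be approximated from parameter count and throughput; a minimal sketch (the ~6 FLOPs/param/token rule of thumb, the 100k tokens/s throughput, and the A100 peak are illustrative assumptions, not numbers from this PR):

```python
def approx_mfu(n_params: float, tokens_per_sec: float, peak_flops_per_sec: float) -> float:
    """Rough model-FLOPs-utilization: ~6 FLOPs per parameter per token
    covers a decoder-only forward+backward pass (rule of thumb)."""
    model_flops_per_sec = 6 * n_params * tokens_per_sec
    return model_flops_per_sec / peak_flops_per_sec

# e.g. a 125M-param model at 100k tokens/s on an A100 (~312 TFLOP/s bf16 peak)
print(f"{approx_mfu(125e6, 100_000, 312e12):.1%}")  # → 24.0%
```

Comparing this number across torch versions (as above) is meaningful because the config, and hence the numerator, is identical; only achieved throughput changes.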

Note: with torch2, pip list shows both triton versions:

torch                  2.0.1+cu118

triton                 2.0.0
triton-pre-mlir        2.0.0

This doesn't seem to matter.

Note: this does not use torch.compile() (but there is no reason it couldn't).
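If someone does want to try torch.compile() on top of this, a minimal sketch (the toy model is illustrative; backend="eager" is a debug backend chosen here so the example runs without the inductor/C++ toolchain):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))

# backend="eager" exercises the torch.compile capture path without
# requiring a compiler toolchain; drop the argument to use inductor.
compiled = torch.compile(model, backend="eager")

x = torch.randn(4, 16)
assert compiled(x).shape == (4, 16)
```

With the eager backend the compiled module should produce the same outputs as the original; inductor may introduce small numeric differences.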

Note: flash-attn is still installed, and xentropy-cuda-lib is also still installed; since I'm not setting loss_fn, mpt defaults to fused_crossentropy in both settings.
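The loss_fn default described above amounts to a fused-kernel-with-fallback pattern; a hedged sketch (the flash_attn import path is an assumption about how xentropy-cuda-lib is typically exposed, not quoted from this repo):

```python
import torch

try:
    # fused cross-entropy kernel from xentropy-cuda-lib,
    # commonly exposed via the flash-attn package
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
    loss_fn = CrossEntropyLoss()
except ImportError:
    # plain PyTorch fallback when the fused kernel is unavailable
    loss_fn = torch.nn.CrossEntropyLoss()

logits = torch.randn(8, 50257)           # (batch, vocab)
targets = torch.randint(0, 50257, (8,))  # class indices
print(loss_fn(logits, targets))          # scalar loss tensor
```

Because both branches expose the same call signature, the rest of the training loop is unchanged regardless of which loss lands in loss_fn.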

Biggest low-probability risk: this old version of triton may not compile or work on H100s... 👀
Risk: triton_pre_mlir is unsupported and will never be updated.

Still need to test at scale / for convergence;
see torch2 vs torch1.13 producing the same results here.

cc @sashaDoubov (enables torch2 for muP dev)
cc @dskhudia (enables torch2 and torch.compile() with the triton attn impl)

Old PR commits:

  • make triton attn req mlri tagged triton

  • add comment

  • updt err

  • clean up req / install

  • exclude HazyR flash attn from pyright

  • lint

  • exclude flash_attn_triton.py from pyright

  • updt torch version & install instructions

  • add extra install instructions for installing CMake

  • lint

  • adding torch1.13 and torch2 testing matrix

@mvpatel2000 (Collaborator) commented: LGTM

@mvpatel2000 mvpatel2000 merged commit bb7f8bb into main May 19, 2023
7 checks passed
@mvpatel2000 mvpatel2000 deleted the vitaliy/torch2 branch May 19, 2023 22:21
@vchiley vchiley mentioned this pull request May 19, 2023
dakinggg added a commit to dakinggg/llm-foundry that referenced this pull request May 20, 2023
dakinggg added a commit that referenced this pull request May 20, 2023
vchiley added a commit to vchiley/llm-foundry that referenced this pull request May 22, 2023
@vchiley vchiley restored the vitaliy/torch2 branch May 23, 2023 16:49
vchiley added a commit that referenced this pull request May 24, 2023
* fix and test

* Revert "Revert "Torch2 (#177) (#178)" (#181)"

This reverts commit 89f56d2.

* updt import try except

* updt hf model

* updt imports

* lint

* add mpt hf model init / gen test

* updt for temp testing

* lint

* rerun tests

* Update .github/workflows/release.yaml

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* Update tests/test_hf_mpt_gen.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* add cpu test

* updt tests / cpu img

* updt cpu test install

* rerun tests

* fix hf import structure

* fix test

* pull_request -> pull_request_target

* make onnx test smaller

---------

Co-authored-by: Daniel King <daniel@mosaicml.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
bmosaicml pushed a commit that referenced this pull request Jun 6, 2023
bmosaicml pushed a commit that referenced this pull request Jun 8, 2023
bmosaicml pushed a commit that referenced this pull request Jun 8, 2023