Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TE HF checkpoint saving #1280

Merged
merged 13 commits into from
Jun 18, 2024
Merged

Fix TE HF checkpoint saving #1280

merged 13 commits into from
Jun 18, 2024

Conversation

j316chuck
Copy link
Contributor

@j316chuck j316chuck commented Jun 13, 2024

Description

Fixes HF Checkpoint callback for TransformerEngine FP8 saving. This PR ensures we serialize the io.BytesIO extra_state tensors as regular tensors insave_pretrained so the code does not error.

Tests

  • Added unit test, skipped on A100 GPU ✔️
  • Added unit test, manually ran on H100 GPU ✅
tests/a_scripts/inference/test_convert_composer_to_hf.py::test_huggingface_conversion_callback[1ba-1ba-1ba-1-1-amp_fp8-full-mpt-True-None]
  /usr/lib/python3/dist-packages/transformer_engine/pytorch/module/base.py:394: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:1524.)
    state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================================================================== 25 passed, 11 skipped, 1621 deselected, 266 warnings in 92.95s (0:01:32) ======================================================================================
Waiting up to 30 seconds for all training processes to terminate. Press Ctrl-C to exit immediately.
  • Before: failed-hf-checkpointer-fp8-llama3-8b-metamath-4ep-KOTaOP 🔴
  • After: success-hf-checkpointer-fp8-llama3-8b-metamath-4ep-yxNFTK

Issues

Closes https://databricks.atlassian.net/browse/RGENAI-255

@j316chuck j316chuck requested a review from a team as a code owner June 13, 2024 21:12
@j316chuck j316chuck changed the title Add fix for TE HF Ckpt Fix TE HF checkpoint saving Jun 13, 2024
Copy link
Collaborator

@mvpatel2000 mvpatel2000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can i load a TE ckpt into a non TEd model?

LGTM but also @dakinggg wdyt

@j316chuck
Copy link
Contributor Author

j316chuck commented Jun 14, 2024

@mvpatel2000 loading from fp8 and training with bf16 seens to work with test run example here: torch-231-bf16-load-from-fp8-bR8NzC.

Curious what the use case is in which you would do that though?

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will review fully once CI passes

Copy link
Collaborator

@mvpatel2000 mvpatel2000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but same comment on waiting for CI/CD topass

@j316chuck j316chuck requested a review from dakinggg June 18, 2024 03:45
@j316chuck j316chuck merged commit c23be4a into main Jun 18, 2024
10 of 11 checks passed
@dakinggg dakinggg deleted the chuck/te_hf_ckpt branch August 6, 2024 18:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants