Default to allow_tf32=True for GPU Devices #1275

Merged
merged 2 commits into mosaicml:dev from CO-328 on Jul 12, 2022

Conversation

ravi-mosaicml (Contributor) commented:

PyTorch 1.12 disabled the TF32 format by default for matmuls. This could lead to a significant performance regression.

To fix this, this PR adds a flag to the `DeviceGPU` class that controls whether to set PyTorch's `allow_tf32` flag. By default, Composer sets it to `True`, for consistent behavior across PyTorch versions.

Closes https://mosaicml.atlassian.net/browse/CO-328
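
For context, a minimal sketch of what such a flag could look like (illustrative only, not the actual Composer implementation; the real `DeviceGPU` signature may differ):

```python
import torch

class DeviceGPU:
    """Sketch of a GPU device that restores the pre-1.12 TF32 matmul default."""

    def __init__(self, allow_tf32: bool = True):
        # PyTorch 1.12 changed torch.backends.cuda.matmul.allow_tf32 from True
        # to False. Defaulting it back to True keeps the faster TF32 kernels on
        # Ampere+ GPUs and makes behavior consistent across PyTorch versions.
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
```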

@abhi-mosaic (Contributor) left a comment:

Looks good to me! For posterity, I want to note that we put this flag in `DeviceGPU` rather than in a Trainer precision enum, because it reflects how certain devices (NVIDIA Ampere+ GPUs) perform "single precision" GEMM ops, rather than an overall training-loop strategy for training with a particular precision.

At no point will any tensor, checkpoint, or activation ever be stored in TF32; the format exists only inside NVIDIA tensor cores right before the inner products of a GEMM are computed, and the accumulation and output tensor are true FP32. See this article.
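
To make that concrete, a quick check using PyTorch's standard `torch.backends.cuda.matmul.allow_tf32` knob (assumes a CUDA-capable device):

```python
import torch

# Enable TF32 for FP32 matmuls (the pre-1.12 default on Ampere+ GPUs).
torch.backends.cuda.matmul.allow_tf32 = True

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = a @ b

# TF32 exists only inside the tensor cores during the inner products;
# the inputs, accumulation, and output all remain true FP32 tensors.
assert a.dtype == b.dtype == c.dtype == torch.float32
```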

An analogy: if we implemented a `DeviceTPU`, we might want an `allow_bf16=True` flag that is forced to `True`, which is not the same as BF16 mixed-precision training. In their default behavior, TPUs cast FP32 tensors to BF16 right before the GEMM inner products and output FP32. The overall training strategy of BF16 mixed precision (with BF16 activations and gradients) is a separate concern. See this article.
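
A hypothetical sketch of that analogy (illustrative only; Composer has no such class in this PR):

```python
class DeviceTPU:
    """Hypothetical device wrapper from the analogy above, not real Composer code."""

    def __init__(self, allow_bf16: bool = True):
        if not allow_bf16:
            # The flag is forced to True: by default, TPUs compute FP32 GEMMs
            # via BF16 inner products on the MXU, so it cannot be disabled.
            raise ValueError("TPU FP32 GEMMs always use BF16 inner products")
        self.allow_bf16 = allow_bf16
```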

@linden-li (Contributor) left a comment:

LGTM!

@abhi-mosaic abhi-mosaic merged commit 924cf27 into mosaicml:dev Jul 12, 2022
@ravi-mosaicml ravi-mosaicml deleted the CO-328 branch July 13, 2022 01:33