Default to allow_tf32=True for GPU Devices #1275
Conversation
PyTorch 1.12 disabled the TF32 format by default for matmuls, which can cause a significant performance regression. To fix this, this PR adds a flag to the `DeviceGPU` class that controls whether PyTorch's `allow_tf32` flag is set. By default, Composer sets this to True, for consistent behavior across PyTorch versions. Closes https://mosaicml.atlassian.net/browse/CO-328
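A minimal sketch of the change described above (the constructor signature and class body here are illustrative, not Composer's exact implementation):

```python
import torch


class DeviceGPU:
    """Sketch: a GPU device wrapper that owns the TF32 setting.

    ``allow_tf32`` defaults to True to restore the pre-1.12 PyTorch
    behavior, so matmul performance stays consistent across versions.
    """

    def __init__(self, allow_tf32: bool = True):
        # These are global PyTorch flags; setting them on device
        # construction applies them for the rest of the process.
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
        torch.backends.cudnn.allow_tf32 = allow_tf32
```

Putting the flag on the device object (rather than threading it through the Trainer) keeps the setting next to the hardware it describes, since TF32 only affects Ampere+ GPUs.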
Looks good to me! For posterity, I want to note that we put this flag in the DeviceGPU rather than as a Trainer precision enum, because it reflects how certain devices (NVIDIA Ampere+ GPUs) perform "single precision" GEMM ops, rather than an overall training-loop strategy for training at a particular precision.
At no point is any tensor, checkpoint, or activation ever stored in TF32; the format exists only inside NVIDIA tensor cores right before the inner products of a GEMM are computed, and the accumulation and output tensor are true FP32. See this article
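To make the comment above concrete, these are the two TF32 switches PyTorch exposes; note that inputs and outputs stay `torch.float32` regardless, since TF32 only truncates the mantissa inside the tensor-core inner products:

```python
import torch

# PyTorch >= 1.12 turned this off by default for matmuls;
# re-enabling it allows TF32 tensor-core GEMMs on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True

# cuDNN convolutions have a separate switch (still True by default).
torch.backends.cudnn.allow_tf32 = True

# The dtype of tensors is unaffected: storage remains FP32 either way.
x = torch.randn(4, 4)
assert x.dtype == torch.float32
```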
An analogy: if we implemented DeviceTPU, we might want an allow_bf16=True flag that is forced to True, which is not the same as BF16 mixed-precision training. In their default behavior, TPUs cast FP32 tensors to BF16 right before the GEMM inner products and output FP32. The overall training strategy of BF16 mixed precision (with BF16 activations and gradients) is a separate step. See this article
LGTM!