[contrib] Improve FusedAdamSWA interface and add unit tests #1759
Conversation
@crcrpar, would you please review? Thanks!
```python
kBF16 = 1
kFP32 = 2
kFP64 = 3
```
```python
@unique
```
Didn't know this decorator existed.
Yeah, this is a neat decorator ensuring unique enum values ;)
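For reference, a minimal sketch of how `enum.unique` rejects aliased values (standard-library behavior; the member names besides `PyTorchAdam` are illustrative, not taken from this PR):

```python
from enum import Enum, unique

@unique
class AdamMathType(Enum):
    PyTorchAdam = 0
    ApexAdam = 1
    # Adding a duplicate value, e.g. `ApexAdamAlias = 1`, would raise at
    # class-creation time:
    # ValueError: duplicate values found in <enum 'AdamMathType'>
```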
```python
params: List[nn.Parameter],
compute_params: List[nn.Parameter],
swa_params: List[nn.Parameter],
swa_decay_rate: float,
lr: float = 1e-3,
bias_correction: bool = True,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
adam_math_mode: AdamMathType = AdamMathType.PyTorchAdam,
weight_decay: float = 0.0,
amsgrad: bool = False,
set_grad_none: bool = True,
capturable: bool = False,
master_weights: bool = False,
```
Nice
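For context, a hedged usage sketch based on the signature above. The import path and the way `compute_params`/`swa_params` are populated are assumptions for illustration, not confirmed by this PR:

```python
import torch
import torch.nn as nn
# Import path is an assumption; adjust to wherever FusedAdamSWA lives in apex.contrib.
from apex.contrib.optimizers.fused_adam_swa import FusedAdamSWA, AdamMathType

model = nn.Linear(1024, 1024).cuda()
params = list(model.parameters())
# compute_params / swa_params would normally be separate copies of the
# weights (e.g. low-precision compute weights and SWA-averaged weights);
# here they are cloned only to illustrate the interface.
compute_params = [nn.Parameter(p.detach().clone()) for p in params]
swa_params = [nn.Parameter(p.detach().clone()) for p in params]

optimizer = FusedAdamSWA(
    params,
    compute_params=compute_params,
    swa_params=swa_params,
    swa_decay_rate=0.99,
    lr=1e-3,
    adam_math_mode=AdamMathType.PyTorchAdam,
)
```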
```python
        chain(moments_gt, velocities_gt),
    )
):
    assert torch.allclose(m, m_gt, rtol=rtol, atol=atol)
```
```diff
- assert torch.allclose(m, m_gt, rtol=rtol, atol=atol)
+ torch.testing.assert_close(m, m_gt, rtol=rtol, atol=atol)
```
Good catch! Fixed.
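For anyone landing here later, the advantage of `torch.testing.assert_close` is its failure diagnostics. A minimal sketch:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.5, 3.0])

# Fails with a bare "AssertionError" and no detail:
# assert torch.allclose(a, b, rtol=1e-5, atol=1e-8)

# Fails with a report of how many elements mismatch and the greatest
# absolute/relative difference, which makes debugging tolerances easier:
torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-8)
```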
```python
        chain(state_params_gt, compute_params_gt, swa_params_gt),
    )
):
    assert torch.allclose(p_test, p_gt, rtol=rtol, atol=atol)
```
```diff
- assert torch.allclose(p_test, p_gt, rtol=rtol, atol=atol)
+ torch.testing.assert_close(p_test, p_gt, rtol=rtol, atol=atol)
```
Fixed.
Force-pushed from d7101b4 to e7afced
Ideally I want this to be compatible with the Python standard unittest module. Some files in this repository use PyTorch's TestCase class and runner, as in https://github.com/NVIDIA/apex/blob/37d83fce4dcbb59897dfd951906493a6fe7fae37/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
Thanks for the pointer. I've refactored the tests to use unittest.TestCase.
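A minimal sketch of the test structure being discussed (class and method names here are illustrative, not the actual test file):

```python
import unittest

import torch


class TestFusedAdamSWA(unittest.TestCase):
    def test_matches_reference_step(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA required")
        # ... run one optimizer step with FusedAdamSWA and a reference
        # PyTorch Adam implementation, then compare the resulting states:
        actual = torch.zeros(4, device="cuda")
        expected = torch.zeros(4, device="cuda")
        torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    unittest.main()
```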
thanks
Why?
- The FusedAdamSWA interface was loosely typed and error-prone.
- The training critical path of FusedAdamSWA (i.e., its step function) could contain an unnecessary GPU-host sync when grad_clip_scale is set to a non-CUDA-tensor variable.
- FusedAdamSWA didn't have any unit tests.

What?
- Encapsulated the FusedAdamSWA math types and internal numerical types into Python enumerations to improve type robustness and readability.
- Accept grad_clip_scale as either a tensor or a number; in the latter case it is moved to the GPU in a non-blocking manner to eliminate a GPU-host sync.
- Added unit tests to guarantee numerical correctness and demonstrate usage.
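A sketch of the non-blocking transfer pattern described above; the helper function is hypothetical, only the pattern matters:

```python
from numbers import Number

import torch


def to_device_scale(grad_clip_scale, device):
    # Hypothetical helper illustrating the pattern: a Python number is
    # wrapped in a pinned-memory tensor and copied asynchronously, so the
    # step function never blocks on a GPU-host sync. (Host-to-device copies
    # are only truly asynchronous from pinned memory.)
    if isinstance(grad_clip_scale, Number):
        scale = torch.tensor(grad_clip_scale, dtype=torch.float32).pin_memory()
        return scale.to(device, non_blocking=True)
    return grad_clip_scale  # already a tensor, assumed on the right device
```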
Force-pushed from e7afced to 499bfdc
thanks