fix: fix QAdam gradient is not BaguaTensor during first stage
NOBLES5E authored Jul 2, 2021 · 1 parent f61305c · commit 1d4dc82
Showing 1 changed file with 3 additions and 1 deletion.
bagua/torch_api/algorithms/q_adam.py (4 changes: 3 additions & 1 deletion)
```diff
@@ -140,6 +140,7 @@ def init_tensors(self, bagua_module: BaguaModule):
                 registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
                     param._one_bit_name
                 )
+                param._one_bit_grad = registered_tensor
             else:
                 registered_tensor = exp_avgs.ensure_bagua_tensor(
                     param._one_bit_name
@@ -192,7 +193,8 @@ def hook_momentum(parameter_name, parameter):
             parameter._one_bit_momentum.bagua_mark_communication_ready()
 
         def hook_grad(parameter_name, parameter):
-            parameter.grad.bagua_mark_communication_ready()
+            assert parameter.grad.data_ptr() == parameter._one_bit_grad.data_ptr(), "gradient data_ptr should match _one_bit_grad data_ptr"
+            parameter._one_bit_grad.bagua_mark_communication_ready()
 
         return (
             hook_grad if self.optimizer.step_id < self.warmup_steps else hook_momentum
```
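For context, as the commit title indicates, during QAdam's first (warmup) stage `parameter.grad` itself is not a BaguaTensor; the registered tensor is a wrapper that shares the gradient's storage. The change caches that wrapper as `param._one_bit_grad` at registration time and has the gradient hook mark the cached wrapper instead of `parameter.grad`. Below is a minimal, hypothetical sketch of this control flow; `register_for_communication` and `mark_communication_ready` are illustrative stand-ins for Bagua's `ensure_bagua_tensor()` and `bagua_mark_communication_ready()`, not the library's actual API.

```python
import torch

def register_for_communication(tensor: torch.Tensor) -> torch.Tensor:
    # Stand-in for ensure_bagua_tensor(): in Bagua the registered tensor is a
    # wrapper sharing storage with `tensor`; here we simply return it.
    return tensor

def mark_communication_ready(tensor: torch.Tensor) -> None:
    # Stand-in for BaguaTensor.bagua_mark_communication_ready().
    print("communication ready:", tuple(tensor.shape))

def init_tensors(param: torch.nn.Parameter, in_warmup: bool) -> None:
    if in_warmup:
        # Warmup stage: the raw gradient is the communicated tensor, so register
        # it and cache the registered wrapper on the parameter (this mirrors
        # `param._one_bit_grad = registered_tensor` in the commit).
        if param.grad is None:
            param.grad = torch.zeros_like(param)
        param._one_bit_grad = register_for_communication(param.grad)
    else:
        # After warmup, QAdam registers the exp_avg (momentum) tensor instead.
        param._one_bit_momentum = register_for_communication(torch.zeros_like(param))

def hook_grad(param: torch.nn.Parameter) -> None:
    # Mark the cached *registered* tensor rather than param.grad directly; the
    # assert mirrors the commit's check that both views share the same storage.
    assert param.grad.data_ptr() == param._one_bit_grad.data_ptr()
    mark_communication_ready(param._one_bit_grad)

# Toy usage: register during warmup, then signal readiness from the hook.
p = torch.nn.Parameter(torch.randn(4))
init_tensors(p, in_warmup=True)
hook_grad(p)
```

The `data_ptr()` assertion is cheap and fails loudly if something replaces `parameter.grad` with a tensor backed by different storage, which would otherwise desynchronize the registered wrapper from the gradient actually produced by backward.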
