Skip to content

Commit

Permalink
Add a condition for nested_detach
Browse files Browse the repository at this point in the history
  • Loading branch information
haikuoxin authored Jul 9, 2024
1 parent 0abf5e8 commit 2b66034
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def nested_detach(tensors):
return type(tensors)(nested_detach(t) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
return tensors.detach()
return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors


def nested_xla_mesh_reduce(tensors, name):
Expand Down

0 comments on commit 2b66034

Please sign in to comment.