forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix NJT serialization (pytorch#137031)
Fixes pytorch#129366 Since NJT has custom serialization logic, we need an NJT-specific fix to clear out cached sizes / strides PyCapsules. Eventually, we should switch NJT to use the default serialization logic, but this depends on pytorch#125622 being addressed. This PR also makes serialization more complete by explicitly handling `lengths`, `ragged_idx`, and the `metadata_cache`, ensuring working operation for both contiguous and non-contiguous NJTs, Pull Request resolved: pytorch#137031 Approved by: https://github.com/soulitzer ghstack dependencies: pytorch#137030
- Loading branch information
1 parent
be423a8
commit e95b230
Showing
2 changed files
with
110 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters