Pytorch preset request: add access to c10::cuda::CUDACachingAllocator #1422

Closed
hmf opened this issue Oct 11, 2023 · 7 comments

@hmf

hmf commented Oct 11, 2023

In the Scala Storch framework, we are trying to determine memory usage to diagnose some issues. I would like to have an equivalent of PyTorch's torch.cuda.memory_stats. It seems LibTorch has an equivalent of torch.cuda.memory_reserved().

Could these classes (c10::cuda::CUDACachingAllocator) be included in the JavaCPP PyTorch preset so that we can diagnose memory issues?

TIA
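For reference, PyTorch's torch.cuda.memory_stats builds a nested dictionary and then flattens it into a flat dictionary whose keys are dotted paths such as allocation.all.current, sorted by key. A minimal Java sketch of that flattening step, assuming the nested statistics have already been gathered into plain java.util.Map objects (FlattenStats and flatten are hypothetical names, not part of the JavaCPP preset or Storch):

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of the JavaCPP PyTorch preset or Storch.
final class FlattenStats {
    // Flattens nested maps into dotted keys such as "allocation.all.current".
    // A TreeMap keeps the keys sorted, as torch.cuda.memory_stats does.
    static Map<String, Long> flatten(Map<String, ?> nested) {
        Map<String, Long> out = new TreeMap<>();
        addTo(out, "", nested);
        return out;
    }

    private static void addTo(Map<String, Long> out, String prefix, Object value) {
        if (value instanceof Map) {
            // Only add a separator once we are below the top level.
            String p = prefix.isEmpty() ? "" : prefix + ".";
            for (Map.Entry<?, ?> e : ((Map<?, ?>) value).entrySet()) {
                addTo(out, p + e.getKey(), e.getValue());
            }
        } else {
            // Leaf values are counters (int64 in c10), stored here as long.
            out.put(prefix, ((Number) value).longValue());
        }
    }

    public static void main(String[] args) {
        // Illustrative values only, not real allocator output.
        Map<String, Long> all = new LinkedHashMap<>();
        all.put("current", 2L);
        all.put("peak", 5L);
        Map<String, Object> allocation = new LinkedHashMap<>();
        allocation.put("all", all);
        Map<String, Object> stats = new LinkedHashMap<>();
        stats.put("allocation", allocation);
        stats.put("num_ooms", 0L);
        System.out.println(flatten(stats));
    }
}
```

Keeping the nested form as the primary structure and flattening only for display mirrors the split PyTorch makes between memory_stats_as_nested_dict and memory_stats.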

@HGuillemet
Collaborator

CUDACachingAllocator has been added as part of PR #1426.
Could you give it a try?

@hmf
Author

hmf commented Oct 18, 2023

@HGuillemet Thanks for adding that so quickly. Do we have a snapshot that I can use?
I have looked here and found version 2.0.1-1.5.10-SNAPSHOT/, which is the previous version.

I also have an additional question: I see that the version above is dated Wed Oct 18 03:01:29 UTC 2023 in the snapshot repository. Does this mean it is rebuilt every day? Is this an update, or is it the same as prior "releases"? Apologies if this question does not make sense.

EDIT: I see that the preset actions failed for PyTorch. Maybe this is why I did not find a snapshot.

TIA

@HGuillemet
Collaborator

I don't think you can have a snapshot before the PR is merged. You'll need to clone and compile it yourself, or wait for the merge.

I believe that snapshot builds are triggered after each commit to the main repository (or manually by the repo maintainer).

@saudet Can you confirm ?

@hmf
Author

hmf commented Oct 18, 2023

@HGuillemet OK, I will wait for the snapshot then. I need to ensure Storch compiles with the new version.

@hmf
Author

hmf commented Dec 2, 2023

@HGuillemet I am using the 2.1 snapshot and can now see the CUDACachingAllocator classes. However, I am having a little trouble getting this to work. Essentially, I think I need to (re)implement THCPModule_memoryStats, as used by the Python code. There are two issues, both of which stem from my not understanding how Pointer is used.

The first is the call to:

 const DeviceStats stats =
      c10::cuda::CUDACachingAllocator::getDeviceStats(device);

This looks like a static call. However, in Java I seem to need something like:

    val devS = CUDAAllocator(Pointer()).getDeviceStats(1)

The problem is that this fails with:

Exception in thread "main" java.lang.NullPointerException: This pointer address is NULL.
        at org.bytedeco.pytorch.cuda.CUDAAllocator.getDeviceStats(Native Method)

So, what pointer must I use?

The other issue I have is when I try to implement:

  result["allocation"] = statArrayToDict(stats.allocation);

with:

    val devAllocation: BoolPointer = devS.allocation()

I was expecting an array of Stat. How do we access this array of Stat? No StatPointer seems to exist.

If you think it best, I can open another issue for this or use the discussion forum.

TIA

@HGuillemet
Collaborator

HGuillemet commented Dec 2, 2023

There is a static method in torch_cuda, getAllocator, that returns a CUDAAllocator on which you can call getDeviceStats.

The mapping of std::array<Stat, 3> is missing. I'll add it. In the meantime, you should be able to cast the BoolPointer to a Stat using:

Stat statArray = new Stat(devS.allocation());

You should think of any instance of a Pointer subclass as a C pointer. It can point to a single object or to an array of objects, as in C. In this case you get an array of 3 objects, which you can access individually with position():

long c = statArray.position(2).current();
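Combining these answers, the statToDict and statArrayToDict lambdas from THCPModule_memoryStats can be mirrored in Java. This is a minimal sketch, not part of the preset: StatArrays, StatValues and statArrayToMap are hypothetical names, and the plain StatValues holder lets the code compile and run without PyTorch. The commented-out wiring in main shows how it might be fed from getAllocator(), getDeviceStats and the Stat cast above; that part is untested and assumes the 2.1 preset exposes CUDAAllocator, DeviceStats and Stat (with current(), peak(), allocated() and freed() getters) in org.bytedeco.pytorch.cuda. The key order follows c10's StatType enum (AGGREGATE, SMALL_POOL, LARGE_POOL), which PyTorch names all, small_pool and large_pool.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.IntFunction;

// Hypothetical JVM-side mirror of statToDict/statArrayToDict in PyTorch's THCPModule_memoryStats.
final class StatArrays {
    // Plain holder for the four int64 counters of one c10::cuda::CUDACachingAllocator::Stat.
    static final class StatValues {
        final long current, peak, allocated, freed;

        StatValues(long current, long peak, long allocated, long freed) {
            this.current = current;
            this.peak = peak;
            this.allocated = allocated;
            this.freed = freed;
        }
    }

    // Index i corresponds to StatType: 0 = AGGREGATE, 1 = SMALL_POOL, 2 = LARGE_POOL.
    static final String[] STAT_TYPE_NAMES = {"all", "small_pool", "large_pool"};

    static Map<String, Long> statToMap(StatValues s) {
        Map<String, Long> m = new LinkedHashMap<>();
        m.put("current", s.current);
        m.put("peak", s.peak);
        m.put("allocated", s.allocated);
        m.put("freed", s.freed);
        return m;
    }

    // statAt.apply(i) must return element i of the 3-element StatArray.
    static Map<String, Map<String, Long>> statArrayToMap(IntFunction<StatValues> statAt) {
        Map<String, Map<String, Long>> m = new LinkedHashMap<>();
        for (int i = 0; i < STAT_TYPE_NAMES.length; i++) {
            m.put(STAT_TYPE_NAMES[i], statToMap(statAt.apply(i)));
        }
        return m;
    }

    public static void main(String[] args) {
        // Untested wiring against the 2.1 preset (method names from this thread, packages assumed):
        //   CUDAAllocator allocator = org.bytedeco.pytorch.global.torch_cuda.getAllocator();
        //   DeviceStats stats = allocator.getDeviceStats(0);
        //   Stat array = new Stat(stats.allocation());  // cast the BoolPointer to Stat
        //   Map<String, Map<String, Long>> allocation = statArrayToMap(i -> {
        //       Stat s = array.position(i);  // moves array itself to element i
        //       return new StatValues(s.current(), s.peak(), s.allocated(), s.freed());
        //   });

        // Illustrative values only: {current, peak, allocated, freed} per StatType.
        long[][] fake = {{3, 7, 10, 7}, {1, 2, 4, 3}, {2, 5, 6, 4}};
        System.out.println(statArrayToMap(i -> new StatValues(fake[i][0], fake[i][1], fake[i][2], fake[i][3])));
    }
}
```

Because position() moves the pointer it is called on, the shared Stat object in the wiring is left pointing at the last element after the loop; calling position(0) on it before reusing it avoids reading the wrong entry.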

@hmf
Author

hmf commented Dec 4, 2023

@HGuillemet Thank you very much. This is working.
