Tensor's device mismatch #96

Open
hzcheney opened this issue Mar 15, 2022 · 3 comments

Comments

@hzcheney
Contributor

Hi! I found a bug while training the CASTER model. It is caused by the torch.eye calls, which do not specify a device. Even when CUDA is available, torch.eye creates its tensor on the CPU by default, while the rest of the model is on the GPU.

@cthoyt
Contributor

cthoyt commented Mar 15, 2022

Can you please give a code example that reproduces this error as well as copying the full stack trace?

@hzcheney
Contributor Author

To reproduce:

  • Run the caster_example.py file; it fails with the error below.
  0%|                                                    | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/hzcheney/DGL/chemicalx/examples/caster_example.py", line 30, in <module>
    main()
  File "/home/hzcheney/DGL/chemicalx/examples/caster_example.py", line 13, in main
    results = pipeline(
  File "/home/hzcheney/DGL/chemicalx/chemicalx/pipeline.py", line 165, in pipeline
    prediction = model(*model.unpack(batch))
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hzcheney/DGL/chemicalx/chemicalx/models/caster.py", line 124, in forward
    dictionary_features_latent = self.encoder(torch.eye(self.drug_channels))
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_addmm)

Process finished with exit code 1

Possible solution

I fixed this bug by passing the device explicitly to torch.eye. Change this

dict_feat_squared_inv = torch.inverse(dict_feat_squared + self.lambda3 * (torch.eye(self.drug_channels)))

dictionary_features_latent = self.encoder(torch.eye(self.drug_channels))

to

dict_feat_squared_inv = torch.inverse(dict_feat_squared + self.lambda3 * (torch.eye(self.drug_channels, device=drug_pair_features_latent.device)))
dictionary_features_latent = self.encoder(torch.eye(self.drug_channels, device=drug_pair_features.device))
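For anyone who wants to see the pattern in isolation, here is a minimal sketch. It is not the chemicalx code: the TinyEncoder class, its layer sizes, and the variable names are hypothetical and only mimic the self.encoder(torch.eye(...)) call from the traceback. The idea is to take the device from a tensor that is already on the right device, so the same code runs on CPU and GPU.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a model that feeds an identity matrix to an encoder."""

    def __init__(self, channels: int = 4, hidden: int = 3):
        super().__init__()
        self.channels = channels
        self.encoder = nn.Sequential(nn.Linear(channels, hidden), nn.ReLU())

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Create the identity matrix on the same device as the input,
        # instead of relying on torch.eye's default (CPU) placement.
        identity = torch.eye(self.channels, device=features.device)
        return self.encoder(identity)


# Falls back to CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyEncoder().to(device)
features = torch.rand(2, 4, device=device)
latent = model(features)
```

Without the device argument, the same forward pass would raise the "Expected all tensors to be on the same device" error as soon as the model is moved to a GPU.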

Another similar bug

To reproduce this one, just run the mhcaddi_example.py file and you will get this error below:

  0%|                                                    | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 57, in _wrapfunc
    return bound(*args, **kwds)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/_tensor.py", line 680, in __array__
    return self.numpy().astype(dtype, copy=False)
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hzcheney/DGL/chemicalx/examples/mhcaddi_example.py", line 26, in <module>
    main()
  File "/home/hzcheney/DGL/chemicalx/examples/mhcaddi_example.py", line 13, in main
    results = pipeline(
  File "/home/hzcheney/DGL/chemicalx/chemicalx/pipeline.py", line 165, in pipeline
    prediction = model(*model.unpack(batch))
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hzcheney/DGL/chemicalx/chemicalx/models/mhcaddi.py", line 398, in forward
    outer_segmentation_index_left, outer_index_left, atom_left, bond_left = self._get_molecule_features(
  File "/home/hzcheney/DGL/chemicalx/chemicalx/models/mhcaddi.py", line 374, in _get_molecule_features
    outer_segmentation_index, outer_index = self.generate_outer_segmentation(
  File "/home/hzcheney/DGL/chemicalx/chemicalx/models/mhcaddi.py", line 461, in generate_outer_segmentation
    outer_segmentation_index = [
  File "/home/hzcheney/DGL/chemicalx/chemicalx/models/mhcaddi.py", line 462, in <listcomp>
    np.repeat(np.array(range(0, left_graph_size)), right_graph_size)
  File "<__array_function__ internals>", line 5, in repeat
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 479, in repeat
    return _wrapfunc(a, 'repeat', repeats, axis=axis)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 66, in _wrapfunc
    return _wrapit(obj, method, *args, **kwds)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 43, in _wrapit
    result = getattr(asarray(obj), method)(*args, **kwds)
  File "/home/hzcheney/miniconda3/envs/chemicalx/lib/python3.8/site-packages/torch/_tensor.py", line 680, in __array__
    return self.numpy().astype(dtype, copy=False)
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

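I think it is caused by the NumPy call below. NumPy cannot operate on GPU tensors, and judging from the traceback the graph sizes are CUDA tensors at this point, so np.repeat fails when it tries to convert them to arrays.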

outer_segmentation_index = [
    np.repeat(np.array(range(0, left_graph_size)), right_graph_size)
    for left_graph_size, right_graph_size in zip(graph_sizes_left, graph_sizes_right)
]
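One possible workaround, sketched below, is to turn the size tensors into plain Python integers before they reach NumPy. This is only an assumption about what a fix could look like, not a patch taken from the repository: the function name build_outer_segmentation_index is hypothetical, and it assumes graph_sizes_left and graph_sizes_right are 1-D integer tensors, as the traceback suggests.

```python
import numpy as np
import torch


def build_outer_segmentation_index(graph_sizes_left, graph_sizes_right):
    """Hypothetical helper: same list comprehension, but NumPy only sees Python ints."""
    # Tensor.tolist() copies the values to host memory and returns Python ints,
    # so it works whether the tensors live on the CPU or on a GPU.
    left_sizes = graph_sizes_left.tolist()
    right_sizes = graph_sizes_right.tolist()
    return [
        np.repeat(np.arange(left_size), right_size)
        for left_size, right_size in zip(left_sizes, right_sizes)
    ]


# Example with CPU tensors; the call is the same for tensors on cuda:0.
sizes_left = torch.tensor([2, 1])
sizes_right = torch.tensor([3, 2])
index = build_outer_segmentation_index(sizes_left, sizes_right)
```

Whether this is enough depends on what the rest of generate_outer_segmentation does with the arrays; any tensor built from them later would still need to be moved to the model's device.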

@Zilu-Zhang

Zilu-Zhang commented Jul 8, 2022

Hi all, any updates about the numpy issue? @hzcheney @cthoyt @benedekrozemberczki
