In GPU mode generated image is all black with NaN tensor values (no problems in CPU mode) #31

Open
illtellyoulater opened this issue Mar 23, 2022 · 8 comments

Comments

@illtellyoulater

illtellyoulater commented Mar 23, 2022

Hello,
For both "text2im.ipynb" and "clip_guided.ipynb" I'm seeing that the generated image is all black.
This only happens in GPU mode (Nvidia GTX 1660 TI, 6 GB), while in CPU mode the image is generated correctly.
I'm on Windows 10 using Python 3.8 and

torch-1.11.0+cu115 pypi_0 pypi
torchvision-0.12.0+cu115 pypi_0 pypi

and this environment works fine for all other ML projects I'm running.

In "text2im.ipynb" I saw that tensor values become NaN in the model_fn function, when model() is called:

# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)    
#-----
    # Values of 'combined' are not NaN
    model_out = model(combined, ts, **kwargs)
    # Values of 'model_out' are NaN
#-----        
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)

As I tried to track down the problem a bit further, I found that the values start going wrong in the forward function of "text2im_model.py":

def forward(self, x, timesteps, tokens=None, mask=None):
    hs = []
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
    if self.xf_width:
        text_outputs = self.get_text_emb(tokens, mask)
        xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
        emb = emb + xf_proj.to(emb)
    else:
        xf_out = None
    h = x.type(self.dtype)
    for module in self.input_blocks:
        h = module(h, emb, xf_out)
        hs.append(h)
    h = self.middle_block(h, emb, xf_out)
    for module in self.output_blocks:
        h = th.cat([h, hs.pop()], dim=1)
        h = module(h, emb, xf_out)
    h = h.type(x.dtype)
    h = self.out(h)
    return h

specifically at line 133, where module is called:

for module in self.input_blocks:
    h = module(h, emb, xf_out)
    hs.append(h)

Here, at iteration # 2 some values become NaN and at iteration # 6 all values become NaN.

Please take a look:

----------------- INSIDE FOR LOOP, iteration #:  1 
----------------- INSIDE FOR LOOP, value of 'h' before module call: 

 tensor([[[[ 0.9609,  0.4629, -0.9834,  ...,  1.6162, -0.5767, -0.4253],
          [ 0.5947, -0.8301,  1.7686,  ..., -2.5215,  0.2920, -0.2183],
          [ 1.9561, -0.8403,  0.4053,  ...,  0.4990, -2.0176, -0.2935],
          ...,
          [ 1.8125, -0.4285,  0.1121,  ..., -1.1416, -2.6562, -1.1348],
          [ 0.9204, -0.4434, -0.1824,  ...,  0.2864,  1.7188, -0.8999],
          [ 1.8369,  0.2583,  0.4895,  ...,  1.4004,  1.5371,  2.8203]],

         [[ 1.7607,  0.4749,  1.9160,  ..., -0.6079, -0.5513, -3.0527],
          [ 0.9780,  1.3984,  1.7266,  ...,  0.2903, -0.7969, -1.4316],
          [-0.5293, -2.6465, -1.6699,  ..., -0.2900, -1.6738,  0.6704],
          ...,
          [ 0.0657, -0.7827,  1.1904,  ..., -0.3643,  0.7754, -0.8740],
          [ 1.0801, -1.1260, -0.1700,  ...,  1.4443, -0.3196, -0.1392],
          [-1.0645,  1.0898, -0.3838,  ...,  0.3491,  0.4077, -1.4492]],

         [[ 0.1176,  0.6514,  0.8452,  ...,  1.3486, -2.3496, -0.1377],
          [-1.6523, -0.1711, -0.1355,  ...,  1.2236,  1.0068,  1.9863],
          [ 0.7456,  1.1943,  0.1819,  ..., -2.1719,  1.7148,  0.0917],
          ...,
          [ 0.4253, -1.0078,  0.7847,  ...,  1.1348,  0.8101,  0.7744],
          [-1.1299, -0.0173, -0.5522,  ...,  0.3960,  1.0762,  0.1404],
          [-0.0644, -0.0656,  1.1670,  ..., -0.1234,  0.6870, -0.5278]]],
...
device='cuda:0', dtype=torch.float16)

----------------- INSIDE FOR LOOP, module function that will now be called is: 

 TimestepEmbedSequential(
  (0): Conv2d(3, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
----------------- INSIDE FOR LOOP, value of 'h' after module call: 

 tensor([[[[-0.3325, -0.4204, -1.3887,  ...,  0.0850, -0.1570, -0.6255],
          [ 0.5010, -0.4548,  0.2632,  ..., -1.8027, -0.2144, -1.4512],
          [ 0.1343, -1.0498,  0.4097,  ..., -0.0427, -2.1836, -0.3203],
          ...,
          [-0.2983, -0.2622, -1.0098,  ..., -1.7773, -1.7871, -1.3760],
          [ 0.1865, -0.8691, -0.1841,  ..., -0.5342, -0.8232, -1.7949],
          [ 0.4858, -0.7051, -0.7515,  ...,  0.7300,  0.0771,  0.6509]],

         [[-0.5107, -0.1924,  0.4790,  ..., -1.6797,  1.5586, -1.1074],
          [-0.8438, -1.3945, -0.8652,  ..., -0.1021, -1.9297, -1.8242],
          [-1.6289,  0.6030, -1.5410,  ...,  1.0488, -0.4473,  0.7524],
          ...,
          [-2.0586,  0.6978, -1.9316,  ..., -1.4785,  1.0742,  0.2190],
          [-1.0010, -0.6309,  0.3979,  ...,  0.3286, -0.3005,  0.8218],
          [-1.4961, -1.0723, -1.5293,  ...,  1.8125, -0.7954, -0.2915]],
...
device='cuda:0', dtype=torch.float16)

----------------- INSIDE FOR LOOP, iteration #:  2 
----------------- INSIDE FOR LOOP, value of 'h' before module call: 

 tensor([[[[-0.3325, -0.4204, -1.3887,  ...,  0.0850, -0.1570, -0.6255],
          [ 0.5010, -0.4548,  0.2632,  ..., -1.8027, -0.2144, -1.4512],
          [ 0.1343, -1.0498,  0.4097,  ..., -0.0427, -2.1836, -0.3203],
          ...,
          [-0.2983, -0.2622, -1.0098,  ..., -1.7773, -1.7871, -1.3760],
          [ 0.1865, -0.8691, -0.1841,  ..., -0.5342, -0.8232, -1.7949],
          [ 0.4858, -0.7051, -0.7515,  ...,  0.7300,  0.0771,  0.6509]],

         [[-0.5107, -0.1924,  0.4790,  ..., -1.6797,  1.5586, -1.1074],
          [-0.8438, -1.3945, -0.8652,  ..., -0.1021, -1.9297, -1.8242],
          [-1.6289,  0.6030, -1.5410,  ...,  1.0488, -0.4473,  0.7524],
          ...,
          [-2.0586,  0.6978, -1.9316,  ..., -1.4785,  1.0742,  0.2190],
          [-1.0010, -0.6309,  0.3979,  ...,  0.3286, -0.3005,  0.8218],
          [-1.4961, -1.0723, -1.5293,  ...,  1.8125, -0.7954, -0.2915]],
...
device='cuda:0', dtype=torch.float16)

----------------- INSIDE FOR LOOP, module function that will now be called is: 
 TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
      (1): Identity()
      (2): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=768, out_features=384, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0.1, inplace=False)
      (3): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)

----------------- INSIDE FOR LOOP, value of 'h' after module call: 

 tensor([[[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          ...,
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
          [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

         [[-0.6113, -0.2927,  0.3787,  ..., -1.7803,  1.4580, -1.2080],
          [-0.9443, -1.4951, -0.9658,  ..., -0.2024, -2.0293, -1.9248],
          [-1.7295,  0.5024, -1.6416,  ...,  0.9482, -0.5479,  0.6519],
          ...,
          [-2.1582,  0.5972, -2.0312,  ..., -1.5791,  0.9736,  0.1186],
          [-1.1016, -0.7314,  0.2976,  ...,  0.2283, -0.4009,  0.7212],
          [-1.5967, -1.1729, -1.6299,  ...,  1.7119, -0.8960, -0.3918]],
...
device='cuda:0', dtype=torch.float16)

As you can see at this point only some values have become NaN.
This remains the case until iteration # 6, where, after the module call, ALL values become NaN:

----------------- INSIDE FOR LOOP, iteration #:  6 
----------------- INSIDE FOR LOOP, value of 'h' before module call: 

 tensor([[[[        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          ...,
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan],
          [        nan,         nan,         nan,  ...,         nan,
                   nan,         nan]],

         [[-9.6973e-01,  3.6084e-01, -8.0078e-01,  ..., -6.1328e-01,
           -1.1406e+00, -1.0596e+00],
          [-4.0210e-01, -1.0947e+00, -2.0898e-01,  ..., -7.3730e-01,
           -6.4258e-01, -3.1860e-01],
          [-4.3530e-01, -4.1577e-01, -4.6655e-01,  ...,  5.1880e-02,
            1.5601e-01, -4.0283e-02],
          ...,
          [-5.6934e-01,  2.7954e-01, -1.4346e+00,  ..., -4.4751e-01,
           -1.3428e-02, -2.9565e-01],
          [-5.2148e-01, -6.8652e-01, -8.8770e-01,  ..., -2.4341e-01,
           -1.3213e+00,  2.9517e-01],
          [-1.2842e+00, -6.5234e-01, -1.9214e-01,  ..., -1.8779e+00,
           -3.9526e-01, -3.7500e-01]],
...
device='cuda:0', dtype=torch.float16)

----------------- INSIDE FOR LOOP, module function that will now be called is: 

 TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
      (1): Identity()
      (2): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=768, out_features=768, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 384, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0.1, inplace=False)
      (3): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))
  )
  (1): AttentionBlock(
    (norm): GroupNorm32(32, 384, eps=1e-05, affine=True)
    (qkv): Conv1d(384, 1152, kernel_size=(1,), stride=(1,))
    (attention): QKVAttention()
    (encoder_kv): Conv1d(512, 768, kernel_size=(1,), stride=(1,))
    (proj_out): Conv1d(384, 384, kernel_size=(1,), stride=(1,))
  )
)
----------------- INSIDE FOR LOOP, value of 'h' after module call: 

 tensor([[[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],
...
device='cuda:0', dtype=torch.float16)
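
The manual print-per-iteration check above could be automated along these lines. This is only a minimal sketch, not code from the repository: `contains_nan`, `first_nan_stage` and the toy `stages` list are hypothetical names, and the stages here are plain Python callables standing in for the entries of `self.input_blocks`.

```python
import math
from typing import Any, Callable, Iterable, Optional


def contains_nan(value: Any) -> bool:
    """Return True if a float, or a nested list/tuple of floats, holds a NaN."""
    if isinstance(value, (list, tuple)):
        return any(contains_nan(v) for v in value)
    return isinstance(value, float) and math.isnan(value)


def first_nan_stage(
    stages: Iterable[Callable[[Any], Any]],
    x: Any,
    has_nan: Callable[[Any], bool] = contains_nan,
) -> Optional[int]:
    """Run x through the stages in order.

    Returns 0 if the input already contains NaN, the 1-based index of the
    first stage whose output contains NaN, or None if no NaN ever appears.
    """
    if has_nan(x):
        return 0
    for index, stage in enumerate(stages, start=1):
        x = stage(x)
        if has_nan(x):
            return index
    return None


# Toy stages (hypothetical): the second one overflows to inf,
# the third one turns inf into NaN via inf - inf.
stages = [
    lambda xs: [v * 2.0 for v in xs],
    lambda xs: [v * 1e308 * 10.0 for v in xs],
    lambda xs: [v - v for v in xs],
]

print(first_nan_stage(stages, [1.0, 2.0]))
```

With real tensors one would presumably pass a predicate such as `lambda t: bool(torch.isnan(t).any())` as `has_nan`, and wrap each block in a small lambda that supplies `emb` and `xf_out`.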

With my limited knowledge of this field, this is all I could find.
Please let me know if there is any other info I can provide.

@illtellyoulater
Author

illtellyoulater commented Mar 23, 2022

Seems to be related to this: pytorch/pytorch#58123 and specifically to this: pytorch/pytorch#58123 (comment)

However, I don't understand how this is possible, given that the bug documented at the link above affects cuDNN releases before v8.2.2, while I'm using cuDNN v8.3.0.2, which I verified by running:

import torch
torch.backends.cudnn.version()
8302
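
For comparing the reported number against the 8.2.2 threshold, a small helper along these lines may help. It is a sketch that assumes the cuDNN 7.x/8.x encoding (major * 1000 + minor * 100 + patch), under which 8302 reads as 8.3.2; `decode_cudnn_version` and `is_at_least` are hypothetical names, not part of torch. It only compares version numbers and says nothing about whether the NaN bug is actually absent.

```python
from typing import Tuple


def decode_cudnn_version(code: int) -> Tuple[int, int, int]:
    """Split an integer such as 8302 into (major, minor, patch).

    Assumes the cuDNN 7.x/8.x scheme: major * 1000 + minor * 100 + patch.
    """
    major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def is_at_least(code: int, baseline: Tuple[int, int, int] = (8, 2, 2)) -> bool:
    """Return True if the decoded version is the baseline or newer."""
    return decode_cudnn_version(code) >= baseline


for reported in (8302, 8200):
    print(reported, decode_cudnn_version(reported), is_at_least(reported))
```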

Any ideas?

@vanga

vanga commented Mar 24, 2022

I don't have much technical insight to add, but I have occasionally seen such samples (everything black) being generated while doing transfer learning on Colab using the implementation below.
https://github.com/afiaka87/glide-finetune

@illtellyoulater
Author

illtellyoulater commented Mar 29, 2022

Ok I could finally make it work by installing a version of torch and torchvision coming with CUDA Toolkit v10.2.

Specifically, I downloaded torch-1.8.0-cp38-cp38-win_amd64.whl and torchvision-0.9.0-cp38-cp38-win_amd64.whl (which include CUDA v10.2 despite not having a "cu###" suffix) from https://download.pytorch.org/whl/torch_stable.html and then installed them with pip install filename.whl.

In theory, newer versions of Torch should work too, provided they come with CUDA 10.2, e.g.:
pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 -f https://download.pytorch.org/whl/torch_stable.html

at least this was a requirement in my case...

@monsieurpooh

monsieurpooh commented May 11, 2022

I hope this workaround works for my user who is suffering from the same issue, but it is very cumbersome: the PyTorch download page says CUDA 10.2 is no longer supported, and CUDA 10.2 won't work on newer GPUs such as the 2070 and above, meaning those users have to download a separate build.

At least one other source suggests that downgrading should work: https://discuss.pytorch.org/t/half-precision-convolution-cause-nan-in-forward-pass/117358/3

I'm wondering if anyone knows of a better workaround that doesn't involve downgrading the CUDA version.

(edit: Wait a sec, I just realized the cuDNN version reported by my torch is only 8200. Will update with more comments)

@illtellyoulater
Author

illtellyoulater commented May 11, 2022

@monsieurpooh I observed that in some other cases the following newer version also works for me:

pip3 install torch==1.11.0+cu115 torchvision==0.12.0+cu115 torchaudio===0.11.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html

does it work in your case?

@monsieurpooh

Thanks for your input!

Updating cuDNN to a version above 8.2.2 was sufficient to fix my issue, even with CUDA toolkit 11.3. It was not necessary to downgrade the CUDA toolkit.

I have not tried installing with CUDA toolkit 11.5, but presumably that may also work; maybe the torch build that comes with CUDA 11.5 also ships a cuDNN newer than 8.2.2? I noticed the default torch 1.11 only had cuDNN 8.2 or so.
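
Since the question is which cuDNN a given torch build actually ships, one way to compare setups is to print the relevant versions from the running interpreter. This is a minimal sketch: `collect_versions` is a hypothetical helper, and it returns empty values if torch is not installed.

```python
from typing import Any, Dict


def collect_versions() -> Dict[str, Any]:
    """Gather the torch, CUDA and cuDNN versions seen by this interpreter."""
    info: Dict[str, Any] = {
        "torch": None,
        "cuda_build": None,
        "cuda_available": None,
        "cudnn": None,
    }
    try:
        import torch
    except ImportError:
        # torch is not installed; report nothing rather than fail.
        return info
    info["torch"] = torch.__version__
    # CUDA version the wheel was built against (None on CPU-only builds).
    info["cuda_build"] = torch.version.cuda
    info["cuda_available"] = torch.cuda.is_available()
    # Integer such as 8302, or None if cuDNN is unavailable.
    info["cudnn"] = torch.backends.cudnn.version()
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```

Posting this output alongside a report would make it easier to tell whether two people are really running the same cuDNN.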

@YipKo

YipKo commented May 20, 2022

Same problem here (same hardware, too).
@monsieurpooh
@illtellyoulater
Unlike you, I have tried PyTorch with CUDA 11.5 (whose cuDNN version, 8.3.0, is newer than 8.2.2), and I also tried downloading cuDNN from NVIDIA and copying the DLL files into the relevant folder in torch/lib, but neither solved the problem.

@bscout9956

The same happens when trying Stable Diffusion with autocast/fp16.
The output is black no matter what I do. This is with PyTorch 1.12.1, CUDA 11.6 and cuDNN 8.0 on conda.
