Attribute error on _NotYetLoadedTensor loading checkpoint into quantized model #20119

Closed
rasbt opened this issue Jul 23, 2024 · 1 comment · Fixed by #20121
Labels: bug (Something isn't working), checkpointing (Related to checkpointing), precision: bnb (Bitsandbytes quantization), ver: 2.2.x

Comments

@rasbt
Collaborator

rasbt commented Jul 23, 2024

Bug description

After upgrading lightning from 2.3.0.dev20240428 to 2.3.3, loading a checkpoint into a bnb.nf4-quantized model fails with AttributeError: '_NotYetLoadedTensor' object has no attribute 'data'.

What version are you seeing the problem on?

master

How to reproduce the bug

litgpt generate --quantize bnb.nf4 checkpoints/microsoft/phi-2

Error messages and logs

⚡ main ~/litgpt2 litgpt generate --quantize bnb.nf4 checkpoints/microsoft/phi-2 
{'checkpoint_dir': PosixPath('checkpoints/microsoft/phi-2'),
 'compile': False,
 'max_new_tokens': 50,
 'num_samples': 1,
 'precision': None,
 'prompt': 'What food do llamas eat?',
 'quantize': 'bnb.nf4',
 'temperature': 0.8,
 'top_k': 50,
 'top_p': 1.0}
Loading model 'checkpoints/microsoft/phi-2/lit_model.pth' with {'name': 'phi-2', 'hf_config': {'name': 'phi-2', 'org': 'microsoft'}, 'scale_embeddings': False, 'block_size': 2048, 'vocab_size': 50257, 'padding_multiple': 512, 'padded_vocab_size': 51200, 'n_layer': 32, 'n_head': 32, 'head_size': 80, 'n_embd': 2560, 'rotary_percentage': 0.4, 'parallel_residual': True, 'bias': True, 'lm_head_bias': True, 'n_query_groups': 32, 'shared_attention_norm': True, 'norm_class_name': 'LayerNorm', 'norm_eps': 1e-05, 'mlp_class_name': 'GptNeoxMLP', 'gelu_approximate': 'tanh', 'intermediate_size': 10240, 'rope_condense_ratio': 1, 'rope_base': 10000, 'n_expert': 0, 'n_expert_per_token': 0, 'rope_n_elem': 32}
Time to instantiate model: 0.25 seconds.
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt2/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/litgpt2/litgpt/generate/base.py", line 255, in main
    load_checkpoint(fabric, model, checkpoint_path)
  File "/teamspace/studios/this_studio/litgpt2/litgpt/utils.py", line 362, in load_checkpoint
    model.load_state_dict(state_dict, strict=strict)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 168, in load_state_dict
    return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2139, in load_state_dict
    load(self, state_dict)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2127, in load
    load(child, child_state_dict, child_prefix)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2121, in load
    module._load_from_state_dict(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1991, in _load_from_state_dict
    hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 72, in __call__
    return self.hook(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 166, in _quantize_on_load_hook
    quantize_fn(weight)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 320, in quantize_
    if weight.data.dtype == torch.uint8:
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/utilities/load.py", line 166, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: '_NotYetLoadedTensor' object has no attribute 'data'
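
The last frames of the traceback suggest the failure mode: the quantization hook receives a lazily loaded placeholder rather than a real tensor, and the placeholder's `__getattr__` rejects `data`. Below is a minimal, torch-free sketch of that pattern. It is not Lightning's implementation: `FakeTensor`, `LazyTensor`, `materialize`, `is_prequantized` and `is_prequantized_safe` are hypothetical names used only for illustration, and dtypes are plain strings.

```python
class FakeTensor:
    """Hypothetical stand-in for a real tensor; dtypes are plain strings here."""

    def __init__(self, dtype):
        self.dtype = dtype

    @property
    def data(self):
        # Real tensors expose `.data`; the stand-in simply returns itself.
        return self


class LazyTensor:
    """Hypothetical placeholder that defers loading until materialized."""

    # Only these metadata attributes are answered without loading.
    _FORWARDED = {"dtype"}

    def __init__(self, loader, dtype):
        self._loader = loader
        self._meta = {"dtype": dtype}

    def materialize(self):
        # Run the deferred load and return the real object.
        return self._loader()

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        if name in LazyTensor._FORWARDED:
            return self._meta[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


def is_prequantized(weight):
    # Same shape as the failing check: it assumes `.data` exists.
    return weight.data.dtype == "uint8"


def is_prequantized_safe(weight):
    # One possible workaround (an assumption, not the actual fix):
    # materialize placeholders before inspecting them.
    if isinstance(weight, LazyTensor):
        weight = weight.materialize()
    return weight.data.dtype == "uint8"


lazy = LazyTensor(lambda: FakeTensor("float16"), dtype="float16")
try:
    is_prequantized(lazy)
except AttributeError as err:
    print(err)  # 'LazyTensor' object has no attribute 'data'

print(is_prequantized_safe(lazy))  # False
```

Materializing first is only one option; forwarding `data` through the placeholder would be another. The change that actually closed this issue (#20121) may take a different route.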

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA A10G
        - available:         True
        - version:           12.1
* Lightning:
        - lightning:         2.3.3
        - lightning-cloud:   0.5.70
        - lightning-sdk:     0.1.10
        - lightning-utilities: 0.11.3.post0
        - pytorch-lightning: 2.3.3
        - torch:             2.2.1+cu121
        - torchmetrics:      1.3.1
        - torchvision:       0.17.1+cu121
* Packages:
        - absl-py:           2.1.0
        - accelerate:        0.32.1
        - aiohttp:           3.9.5
        - aiosignal:         1.3.1
        - annotated-types:   0.7.0
        - anyio:             4.4.0
        - argon2-cffi:       23.1.0
        - argon2-cffi-bindings: 21.2.0
        - arrow:             1.3.0
        - asttokens:         2.4.1
        - async-lru:         2.0.4
        - async-timeout:     4.0.3
        - attrs:             23.2.0
        - babel:             2.15.0
        - backoff:           2.2.1
        - beautifulsoup4:    4.12.3
        - bitsandbytes:      0.42.0
        - bleach:            6.1.0
        - boto3:             1.34.142
        - botocore:          1.34.142
        - cachetools:        5.3.3
        - certifi:           2024.7.4
        - cffi:              1.16.0
        - chardet:           5.2.0
        - charset-normalizer: 3.3.2
        - click:             8.1.7
        - colorama:          0.4.6
        - comm:              0.2.2
        - contourpy:         1.2.1
        - cycler:            0.12.1
        - dataproperty:      1.0.1
        - datasets:          2.20.0
        - debugpy:           1.8.2
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.8
        - dnspython:         2.6.1
        - docstring-parser:  0.16
        - email-validator:   2.2.0
        - evaluate:          0.4.2
        - exceptiongroup:    1.2.1
        - executing:         2.0.1
        - fastapi:           0.111.0
        - fastapi-cli:       0.0.4
        - fastjsonschema:    2.20.0
        - filelock:          3.15.4
        - fire:              0.6.0
        - fonttools:         4.53.1
        - fqdn:              1.5.1
        - frozenlist:        1.4.1
        - fsspec:            2024.5.0
        - google-auth:       2.32.0
        - google-auth-oauthlib: 1.2.1
        - grpcio:            1.64.1
        - h11:               0.14.0
        - hf-transfer:       0.1.6
        - httpcore:          1.0.5
        - httptools:         0.6.1
        - httpx:             0.27.0
        - huggingface-hub:   0.24.0
        - idna:              3.7
        - importlib-resources: 6.4.0
        - ipykernel:         6.26.0
        - ipython:           8.17.2
        - ipywidgets:        8.1.1
        - isoduration:       20.11.0
        - jedi:              0.19.1
        - jinja2:            3.1.4
        - jmespath:          1.0.1
        - joblib:            1.4.2
        - json5:             0.9.25
        - jsonargparse:      4.31.0
        - jsonlines:         4.0.0
        - jsonpointer:       3.0.0
        - jsonschema:        4.23.0
        - jsonschema-specifications: 2023.12.1
        - jupyter-client:    8.6.2
        - jupyter-core:      5.7.2
        - jupyter-events:    0.10.0
        - jupyter-lsp:       2.2.5
        - jupyter-server:    2.14.1
        - jupyter-server-terminals: 0.5.3
        - jupyterlab:        4.2.0
        - jupyterlab-pygments: 0.3.0
        - jupyterlab-server: 2.27.2
        - jupyterlab-widgets: 3.0.11
        - kiwisolver:        1.4.5
        - lightning:         2.3.3
        - lightning-cloud:   0.5.70
        - lightning-sdk:     0.1.10
        - lightning-utilities: 0.11.3.post0
        - litdata:           0.2.17
        - litgpt:            0.4.5
        - litserve:          0.1.3
        - lm-eval:           0.4.3
        - lxml:              5.2.2
        - markdown:          3.6
        - markdown-it-py:    3.0.0
        - markupsafe:        2.1.5
        - matplotlib:        3.8.2
        - matplotlib-inline: 0.1.7
        - mbstrdecoder:      1.1.3
        - mdurl:             0.1.2
        - mistune:           3.0.2
        - more-itertools:    10.3.0
        - mpmath:            1.3.0
        - multidict:         6.0.5
        - multiprocess:      0.70.16
        - nbclient:          0.10.0
        - nbconvert:         7.16.4
        - nbformat:          5.10.4
        - nest-asyncio:      1.6.0
        - networkx:          3.3
        - nltk:              3.8.1
        - notebook-shim:     0.2.4
        - numexpr:           2.10.1
        - numpy:             1.26.4
        - nvidia-cublas-cu12: 12.1.3.1
        - nvidia-cuda-cupti-cu12: 12.1.105
        - nvidia-cuda-nvrtc-cu12: 12.1.105
        - nvidia-cuda-runtime-cu12: 12.1.105
        - nvidia-cudnn-cu12: 8.9.2.26
        - nvidia-cufft-cu12: 11.0.2.54
        - nvidia-curand-cu12: 10.3.2.106
        - nvidia-cusolver-cu12: 11.4.5.107
        - nvidia-cusparse-cu12: 12.1.0.106
        - nvidia-nccl-cu12:  2.19.3
        - nvidia-nvjitlink-cu12: 12.5.82
        - nvidia-nvtx-cu12:  12.1.105
        - oauthlib:          3.2.2
        - orjson:            3.10.6
        - overrides:         7.7.0
        - packaging:         24.1
        - pandas:            2.1.4
        - pandocfilters:     1.5.1
        - parso:             0.8.4
        - pathvalidate:      3.2.0
        - peft:              0.11.1
        - pexpect:           4.9.0
        - pillow:            10.4.0
        - pip:               24.1.2
        - platformdirs:      4.2.2
        - portalocker:       2.10.1
        - prometheus-client: 0.20.0
        - prompt-toolkit:    3.0.47
        - protobuf:          4.23.4
        - psutil:            6.0.0
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           17.0.0
        - pyarrow-hotfix:    0.6
        - pyasn1:            0.6.0
        - pyasn1-modules:    0.4.0
        - pybind11:          2.13.1
        - pycparser:         2.22
        - pydantic:          2.8.2
        - pydantic-core:     2.20.1
        - pygments:          2.18.0
        - pyjwt:             2.8.0
        - pyparsing:         3.1.2
        - pytablewriter:     1.2.0
        - python-dateutil:   2.9.0.post0
        - python-dotenv:     1.0.1
        - python-json-logger: 2.0.7
        - python-multipart:  0.0.9
        - pytorch-lightning: 2.3.3
        - pytz:              2024.1
        - pyyaml:            6.0.1
        - pyzmq:             26.0.3
        - referencing:       0.35.1
        - regex:             2024.5.15
        - requests:          2.32.3
        - requests-oauthlib: 2.0.0
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rich:              13.7.1
        - rouge-score:       0.1.2
        - rpds-py:           0.19.0
        - rsa:               4.9
        - s3transfer:        0.10.2
        - sacrebleu:         2.4.2
        - safetensors:       0.4.3
        - scikit-learn:      1.3.2
        - scipy:             1.11.4
        - send2trash:        1.8.3
        - sentencepiece:     0.2.0
        - setuptools:        69.5.1
        - shellingham:       1.5.4
        - simple-term-menu:  1.6.4
        - six:               1.16.0
        - sniffio:           1.3.1
        - soupsieve:         2.5
        - sqlitedict:        2.1.0
        - stack-data:        0.6.3
        - starlette:         0.37.2
        - sympy:             1.13.0
        - tabledata:         1.3.3
        - tabulate:          0.9.0
        - tcolorpy:          0.1.6
        - tensorboard:       2.15.1
        - tensorboard-data-server: 0.7.2
        - termcolor:         2.4.0
        - terminado:         0.18.1
        - threadpoolctl:     3.5.0
        - tinycss2:          1.3.0
        - tokenizers:        0.19.1
        - tomli:             2.0.1
        - torch:             2.2.1+cu121
        - torchmetrics:      1.3.1
        - torchvision:       0.17.1+cu121
        - tornado:           6.4.1
        - tqdm:              4.66.4
        - tqdm-multiprocess: 0.0.11
        - traitlets:         5.14.3
        - transformers:      4.42.4
        - triton:            2.2.0
        - typepy:            1.3.2
        - typer:             0.12.3
        - types-python-dateutil: 2.9.0.20240316
        - typeshed-client:   2.7.0
        - typing-extensions: 4.12.2
        - tzdata:            2024.1
        - ujson:             5.10.0
        - uri-template:      1.3.0
        - urllib3:           2.2.2
        - uvicorn:           0.30.1
        - uvloop:            0.19.0
        - watchfiles:        0.22.0
        - wcwidth:           0.2.13
        - webcolors:         24.6.0
        - webencodings:      0.5.1
        - websocket-client:  1.8.0
        - websockets:        12.0
        - werkzeug:          3.0.3
        - wheel:             0.43.0
        - widgetsnbextension: 4.0.11
        - word2number:       1.1
        - xxhash:            3.4.1
        - yarl:              1.9.4
        - zstandard:         0.23.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - 
        - processor:         x86_64
        - python:            3.10.10
        - release:           5.15.0-1064-aws
        - version:           #70~20.04.1-Ubuntu SMP Fri Jun 14 15:42:13 UTC 2024

More info

Note that we currently pin bitsandbytes==0.42.0, because 0.43 fails with the following error:

/teamspace/studios/this_studio/litgpt2/litgpt/generate/base.py:207: UserWarning: LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.
  warnings.warn(
Loading model 'checkpoints/microsoft/phi-2/lit_model.pth' with {'name': 'phi-2', 'hf_config': {'name': 'phi-2', 'org': 'microsoft'}, 'scale_embeddings': False, 'block_size': 2048, 'vocab_size': 50257, 'padding_multiple': 512, 'padded_vocab_size': 51200, 'n_layer': 32, 'n_head': 32, 'head_size': 80, 'n_embd': 2560, 'rotary_percentage': 0.4, 'parallel_residual': True, 'bias': True, 'lm_head_bias': True, 'n_query_groups': 32, 'shared_attention_norm': True, 'norm_class_name': 'LayerNorm', 'norm_eps': 1e-05, 'mlp_class_name': 'GptNeoxMLP', 'gelu_approximate': 'tanh', 'intermediate_size': 10240, 'rope_condense_ratio': 1, 'rope_base': 10000, 'n_expert': 0, 'n_expert_per_token': 0, 'rope_n_elem': 32}
Time to instantiate model: 0.24 seconds.
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt2/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/litgpt2/litgpt/generate/base.py", line 252, in main
    model = fabric.setup_module(model)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 308, in setup_module
    module = self._move_model_to_device(model=module, optimizers=[])
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 976, in _move_model_to_device
    model = self.to_device(model)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 526, in to_device
    self._strategy.module_to_device(obj)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/strategies/single_device.py", line 59, in module_to_device
    module.to(self.root_device)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1152, in to
    return self._apply(convert)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 825, in _apply
    param_applied = fn(param)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1150, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 324, in to
    return self._quantize(device)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 289, in _quantize
    w_4bit, quant_state = bnb.functional.quantize_4bit(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1234, in quantize_4bit
    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8
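
The error message indicates that a tensor already stored as uint8 reached the 4-bit quantizer. This report does not establish why that happens, so the following is only a torch-free sketch of the kind of dtype guard that avoids quantizing such a tensor a second time. `quantize_4bit_stub` and `quantize_once` are hypothetical helpers, not bitsandbytes or Lightning functions, and dtypes are plain strings.

```python
FLOAT_DTYPES = {"float16", "bfloat16", "float32"}


def quantize_4bit_stub(dtype):
    """Hypothetical stand-in for a blockwise quantizer; it only checks the dtype."""
    if dtype not in FLOAT_DTYPES:
        raise ValueError(
            f"Blockwise quantization only supports 16/32-bit floats, but got {dtype}"
        )
    return "uint8"  # packed 4-bit values are stored in uint8


def quantize_once(dtype):
    """Skip weights that already look packed instead of quantizing them again."""
    if dtype == "uint8":
        return dtype
    return quantize_4bit_stub(dtype)


print(quantize_once("float16"))  # uint8
print(quantize_once("uint8"))    # uint8, returned unchanged
```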

I wonder whether these issues can, or need to, be addressed one at a time: first support lightning 2.3.3 with bitsandbytes==0.42.0 to restore litgpt quantization as is, and then look at how to upgrade to the most recent bitsandbytes version.
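
While the pin is in place, a small runtime check can make a mismatch visible early, similar in spirit to the UserWarning in the log above. This is a sketch under assumptions: `bnb_pin_ok` and `installed_bnb_version` are hypothetical helpers, not part of LitGPT or Lightning, and the comparison is a plain string equality against the pinned version.

```python
import warnings
from importlib.metadata import PackageNotFoundError, version

SUPPORTED_BNB = "0.42.0"  # the pin mentioned in this report


def bnb_pin_ok(installed, supported=SUPPORTED_BNB):
    """Return True if the installed version matches the pin; warn otherwise."""
    if installed == supported:
        return True
    warnings.warn(f"Expected bitsandbytes=={supported}, found {installed}.")
    return False


def installed_bnb_version():
    """Look up the installed bitsandbytes version, or None if it is absent."""
    try:
        return version("bitsandbytes")
    except PackageNotFoundError:
        return None


found = installed_bnb_version()
if found is not None:
    bnb_pin_ok(found)
```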

Any thoughts?

cc @awaelchli @carmocca

@rasbt rasbt added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jul 23, 2024
@rasbt
Collaborator Author

rasbt commented Jul 23, 2024

github-actions bot added the ver: 2.2.x label

It's actually v2.3.3

@awaelchli awaelchli added checkpointing Related to checkpointing precision: bnb Bitsandbytes quantization and removed needs triage Waiting to be triaged by maintainers labels Jul 23, 2024
@awaelchli awaelchli self-assigned this Jul 23, 2024
@awaelchli awaelchli changed the title LitGPT QLoRA / bnb.nf4 quantization causes issues in recent PyTorch Lightning/Fabric versions Attribute error on _NotYetLoadedTensor loading checkpoint into quantized model Jul 23, 2024
@awaelchli awaelchli added this to the 2.3.x milestone Jul 23, 2024
2 participants