Jamba & Deepspeed zero-3 #4300

Open
1 task done
lwang2070 opened this issue Jun 15, 2024 · 5 comments
Labels
pending This problem is yet to be addressed

Comments

@lwang2070

Reminder

  • I have read the README and searched the existing issues.

System Info

  • llamafactory version: 0.8.2.dev0
  • Platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
  • Python version: 3.11.8
  • PyTorch version: 2.3.1+cu121 (GPU)
  • Transformers version: 4.41.2
  • Datasets version: 2.18.0
  • Accelerate version: 0.31.0
  • PEFT version: 0.11.1
  • TRL version: 0.9.4
  • GPU type: NVIDIA H800
  • DeepSpeed version: 0.14.2

Reproduction

bash llamafactory-cli train jamba_full_sft.yaml

The YAML config is as follows:

### model
model_name_or_path: /juicefs-algorithm/models/nlp/huggingface/jamba/Jamba-v0.1

### method
stage: sft
do_train: true
finetuning_type: full

### ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: identity,alpaca_en_demo
template: atom
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: logs/jamba
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
gradient_checkpointing: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
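
The deepspeed entry above points at LLaMA-Factory's stock ZeRO-3 config. As a minimal sanity check (a sketch assuming the repository layout used in the reproduction command; the keys follow the standard DeepSpeed JSON schema), one can confirm the referenced file really requests stage 3 before launching:

# Sketch: verify the DeepSpeed config referenced by the YAML above.
# Assumes it is run from the LLaMA-Factory repository root.
import json

with open("examples/deepspeed/ds_z3_config.json") as f:
    ds_cfg = json.load(f)

# In the standard DeepSpeed schema, ZeRO settings live under "zero_optimization".
print(ds_cfg["zero_optimization"]["stage"])  # expect 3 for ZeRO-3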

Expected behavior

Run the Jamba model with DeepSpeed ZeRO-3.

Others

The error is as follows:

[rank5]: RuntimeError: tracing error at step 470: 
[rank5]: module id: 489, training: True
[rank5]: expected the next 1 parameters in the parameter fetch queue to be ({'id': 'name=model.layers.31.mamba.out_proj.weight id=1192', 'status': 'AVAILABLE', 'numel': 33554432, 'ds_numel': 33554432, 'shape': (4096, 8192), 'ds_shape': (4096, 8192), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {489}, 'ds_tensor.shape': torch.Size([4194304])},) 
[rank5]: but got 
[rank5]:  ({'id': 'name=model.layers.31.mamba.dt_proj.bias id=1191', 'status': 'AVAILABLE', 'numel': 8192, 'ds_numel': 8192, 'shape': (8192,), 'ds_shape': (8192,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {483}, 'ds_tensor.shape': torch.Size([1024])},).
@github-actions github-actions bot added the pending This problem is yet to be addressed label Jun 15, 2024
@hiyouga
Owner

hiyouga commented Jun 15, 2024

Try using the hf-compatible version of jamba: TechxGenus/Jamba-v0.1-hf
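
An illustrative way to check that this checkpoint resolves with the installed Transformers 4.41.2 (a sketch; it downloads only the config, not the weights):

from transformers import AutoConfig

# Load only the configuration of the HF-compatible checkpoint named above.
cfg = AutoConfig.from_pretrained("TechxGenus/Jamba-v0.1-hf")
print(cfg.model_type, cfg.num_hidden_layers, cfg.use_mamba_kernels)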

@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Jun 15, 2024
@hiyouga hiyouga closed this as completed Jun 15, 2024
@lwang2070
Author

Nope, I get the exact same error:

[rank0]: expected the next 1 parameters in the parameter fetch queue to be ({'id': 'name=model.layers.31.mamba.out_proj.weight id=1192', 'status': 'AVAILABLE', 'numel': 33554432, 'ds_numel': 33554432, 'shape': (4096, 8192), 'ds_shape': (4096, 8192), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {489}, 'ds_tensor.shape': torch.Size([4194304])},) 
[rank0]: but got 
[rank0]:  ({'id': 'name=model.layers.31.mamba.dt_proj.bias id=1191', 'status': 'AVAILABLE', 'numel': 8192, 'ds_numel': 8192, 'shape': (8192,), 'ds_shape': (8192,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {483}, 'ds_tensor.shape': torch.Size([1024])},).

To confirm, this is the new config:

{
  "architectures": [
    "JambaForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attn_layer_offset": 4,
  "attn_layer_period": 8,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "expert_layer_offset": 1,
  "expert_layer_period": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "mamba_conv_bias": true,
  "mamba_d_conv": 4,
  "mamba_d_state": 16,
  "mamba_dt_rank": 256,
  "mamba_expand": 2,
  "mamba_proj_bias": false,
  "max_position_embeddings": 262144,
  "model_type": "jamba",
  "num_attention_heads": 32,
  "num_experts": 16,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_logits_to_keep": 1,
  "output_router_logits": false,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "router_aux_loss_coef": 0.001,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0",
  "use_cache": true,
  "use_mamba_kernels": true,
  "vocab_size": 65536
}
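
This config leaves use_mamba_kernels set to true, so the Mamba blocks take the fused-kernel path. As a diagnostic sketch only (not a confirmed fix, and the override would still have to be threaded through LLaMA-Factory's own model loader), the flag can be turned off when loading the model directly, which forces the slower pure-PyTorch mamba forward:

from transformers import AutoModelForCausalLM

# Sketch: override the config flag shown above so the model uses the
# pure-PyTorch mamba implementation instead of the fused CUDA kernels.
model = AutoModelForCausalLM.from_pretrained(
    "TechxGenus/Jamba-v0.1-hf",
    use_mamba_kernels=False,
    torch_dtype="auto",
)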

@hiyouga
Owner

hiyouga commented Jun 15, 2024

The full traceback is needed for debugging

@lwang2070
Author

Here is the full traceback:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/juicefs-algorithm/workspace/nlp/li_wang/llama/src/llamafactory/launcher.py", line 9, in <module>
[rank7]:     launch()
[rank7]:   File "/juicefs-algorithm/workspace/nlp/li_wang/llama/src/llamafactory/launcher.py", line 5, in launch
[rank7]:     run_exp()
[rank7]:   File "/juicefs-algorithm/workspace/nlp/li_wang/llama/src/llamafactory/train/tuner.py", line 40, in run_exp
[rank7]:     run_sft(
[rank7]:   File "/juicefs-algorithm/workspace/nlp/li_wang/llama/src/llamafactory/train/sft/workflow.py", line 98, in run_sft
[rank7]:     train_result = trainer.train(
[rank7]:                    ^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py", line 1885, in train
[rank7]:     return inner_training_loop(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
[rank7]:     tr_loss_step = self.training_step(model, inputs)
[rank7]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py", line 3250, in training_step
[rank7]:     self.accelerator.backward(loss)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/accelerate/accelerator.py", line 2126, in backward
[rank7]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
[rank7]:     self.engine.backward(loss, **kwargs)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
[rank7]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/zero/stage3.py", line 2213, in backward
[rank7]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank7]:     scaled_loss.backward(retain_graph=retain_graph)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/_tensor.py", line 525, in backward
[rank7]:     torch.autograd.backward(
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank7]:     _engine_run_backward(
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank7]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/autograd/function.py", line 301, in apply
[rank7]:     return user_fn(self, *args)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 303, in backward
[rank7]:     outputs = ctx.run_function(*detached_inputs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/transformers/models/jamba/modeling_jamba.py", line 1202, in forward
[rank7]:     hidden_states = self.mamba(
[rank7]:                     ^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/transformers/models/jamba/modeling_jamba.py", line 991, in forward
[rank7]:     return self.cuda_kernels_forward(hidden_states, cache_params)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/transformers/models/jamba/modeling_jamba.py", line 901, in cuda_kernels_forward
[rank7]:     contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
[rank7]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1571, in _call_impl
[rank7]:     args_result = hook(self, args)
[rank7]:                   ^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 278, in _pre_forward_module_hook
[rank7]:     self.pre_sub_module_forward_function(module)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 452, in pre_sub_module_forward_function
[rank7]:     param_coordinator.fetch_sub_module(sub_module, forward=True)
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank7]:     return fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/li_wang/local/conda/envs/llama/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 338, in fetch_sub_module
[rank7]:     raise RuntimeError(
[rank7]: RuntimeError: tracing error at step 470: 
[rank7]: module id: 489, training: True
[rank7]: expected the next 1 parameters in the parameter fetch queue to be ({'id': 'name=model.layers.31.mamba.out_proj.weight id=1192', 'status': 'AVAILABLE', 'numel': 33554432, 'ds_numel': 33554432, 'shape': (4096, 8192), 'ds_shape': (4096, 8192), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {489}, 'ds_tensor.shape': torch.Size([4194304])},) 
[rank7]: but got 
[rank7]:  ({'id': 'name=model.layers.31.mamba.dt_proj.bias id=1191', 'status': 'AVAILABLE', 'numel': 8192, 'ds_numel': 8192, 'shape': (8192,), 'ds_shape': (8192,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {483}, 'ds_tensor.shape': torch.Size([1024])},).
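
The error is raised by ZeRO-3's prefetch tracing in partitioned_param_coordinator.fetch_sub_module, which assumes submodules execute in a fixed order every step; with Jamba's mix of routed MoE experts and Mamba mixers that assumption can break. A workaround DeepSpeed offers for conditionally-executed blocks is to mark them as ZeRO-3 "leaf" modules so their parameters are gathered as a unit rather than traced hook by hook. The sketch below is only a hypothesis for Jamba: set_z3_leaf_modules exists in DeepSpeed 0.13+ (0.14.2 is installed here), but the JambaSparseMoeBlock and JambaMambaMixer class names from transformers.models.jamba.modeling_jamba are assumptions, and this is not verified to fix the dt_proj/out_proj ordering mismatch shown above.

from deepspeed.utils import set_z3_leaf_modules
from transformers.models.jamba.modeling_jamba import (
    JambaMambaMixer,      # assumed class name for the Mamba mixer block
    JambaSparseMoeBlock,  # assumed class name for the sparse MoE block
)


def mark_jamba_leaf_modules(model):
    # Must be called before DeepSpeed/Trainer wraps the model, so that ZeRO-3
    # fetches each marked module's parameters as a whole and skips
    # per-submodule order tracing inside it.
    leaves = set_z3_leaf_modules(model, [JambaSparseMoeBlock, JambaMambaMixer])
    print(f"marked {len(leaves)} modules as ZeRO-3 leaf modules")
    return model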

@hiyouga hiyouga reopened this Jun 17, 2024
@hiyouga hiyouga added pending This problem is yet to be addressed and removed solved This problem has been already solved labels Jun 17, 2024
@coranholmes

Have you solved this problem yet?
