
Threaded MultipackDistributedDataloader with prefetched samples #759

Merged: 23 commits, Oct 26, 2023

Conversation

casper-hansen (Collaborator) commented Oct 21, 2023

Summary: We achieve a ~10-20% speed boost in multi-GPU training. We prefetch samples and put them into a queue with a max size of 1000. The DataLoader then yields from the queue.
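A minimal sketch of this pattern (illustrative only, not the actual axolotl code: the class and argument names are made up, and the eager thread start in `__init__` mirrors an early revision of this PR rather than the final one):

```python
import queue
import threading


class PrefetchingLoader:
    """Toy producer/consumer loader: a background thread feeds ready batches
    into a bounded queue, and iteration drains the queue."""

    def __init__(self, batch_source, max_prefetch=1000):
        self.batch_source = batch_source               # any iterable of ready batches
        self.queue = queue.Queue(maxsize=max_prefetch)
        self._sentinel = object()                      # marks the end of the data
        # Eager start, as in an early revision of this PR; a later comment in
        # this thread explains why the final version defers this to __iter__.
        self.worker = threading.Thread(target=self._fill_queue, daemon=True)
        self.worker.start()

    def _fill_queue(self):
        for batch in self.batch_source:
            self.queue.put(batch)                      # blocks while the queue is full
        self.queue.put(self._sentinel)

    def __iter__(self):
        while True:
            batch = self.queue.get()
            if batch is self._sentinel:
                return
            yield batch
```

The bounded queue is what keeps memory in check: `put()` blocks once `max_prefetch` batches are waiting, so the producer can never run arbitrarily far ahead of the training loop.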

This is what normal training with zero1 looks like now:

[screenshot: GPU utilization with zero1 on this PR]

TODO:

  • This PR needs to be checked for correctness inside and out multiple times.
  • Probably remove all multi-threaded aspects and leave it for the future. num_threads > 1 makes performance much worse.

Speed testing - TinyLlama 1B

You must first run `accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml --prepare_ds_only` to prepare the dataset in order to get proper benchmark values.

Main:

  • 1x H100: 163.571 samples/s
  • 2x H100: 389.07 samples/s (zero1)
  • 2x H100: 237.495 samples/s (zero2)
  • 2x H100: 424.269 samples/s (zero2, without offload_optimizer)
  • 2x A100: 199.862 samples/s (zero1)
  • 2x A100: 205.916 samples/s (zero2, without offload_optimizer)

PR:

  • 1x H100: 165.389 samples/s
  • 2x H100: 456.391 samples/s (zero1)
  • 2x H100: 245.831 samples/s (zero2)
  • 2x H100: 459.956 samples/s (zero2, without offload_optimizer)
  • 2x A100: 207.528 samples/s (zero1)
  • 2x A100: 169.178 samples/s (zero2)
  • 2x A100: 211.521 samples/s (zero2, without offload_optimizer)
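
For reference, the largest multi-GPU gain is the 2x H100 zero1 setup, which goes from 389.07 to 456.391 samples/s (roughly +17%), while single-GPU throughput is essentially unchanged (163.571 to 165.389 samples/s).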
Config (TinyLlama 1.1B, used for the speed tests above)

```yaml
base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: false
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: false

warmup_steps: 100
eval_steps: 0.2
eval_table_size:
save_steps:
debug:
deepspeed: deepspeed/zero1.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```

Config (Mistral 7B QLoRA, 32k sequence length)

```yaml
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: casperhansen/longalpaca_1k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 32768
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:

wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed: deepspeed/zero1.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```

The problem

GPU utilization on main looks like the image below: there are large dips in utilization, and they become even more pronounced when you move to multi-GPU.

[screenshot: GPU utilization on main, showing periodic dips]

Zero1 vs Zero2

The main problem with zero2 is the large downward spikes in utilization, most likely caused by communication overhead. As the image at the start of this PR shows, zero1 is much smoother than zero2.

[screenshot: GPU utilization with zero2]

Main: zero2 without offloading the optimizer to CPU - the same dips we see on 1x GPU.

[screenshot: GPU utilization on main, zero2 without optimizer offload]

PR: zero2 without offloading the optimizer to CPU - sustained GPU utilization at 100%.

[screenshot: GPU utilization on this PR, zero2 without optimizer offload]

@casper-hansen casper-hansen changed the title Multithreading implementation Threaded MultipackDistributedDataloader with prefetched samples Oct 22, 2023
@casper-hansen casper-hansen marked this pull request as ready for review October 23, 2023 20:55
casper-hansen (Collaborator, Author) commented Oct 25, 2023

@winglian Please have a look when you find some time. I have updated how the threads are started so that the E2E tests no longer hang, and startup is now faster than before. The problem was that the CPU was being maxed out by the DataLoader while the model, DeepSpeed, and other components were loading, causing a massive slowdown.

[screenshot]
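
In terms of the hypothetical `PrefetchingLoader` sketch from the PR summary, the fix described here roughly amounts to deferring the worker start from `__init__` to the first call of `__iter__` and guarding it with `is_alive()`, so the loading thread neither competes with model/DeepSpeed initialization nor gets spawned a second time on later epochs (again an illustrative sketch, not the actual diff):

```python
    # __init__ no longer starts the thread; it just keeps `self.worker = None`.

    def __iter__(self):
        # Deferred start: the producer begins only when training first asks for
        # batches, i.e. after the model and DeepSpeed have finished loading.
        if self.worker is None or not self.worker.is_alive():
            self.worker = threading.Thread(target=self._fill_queue, daemon=True)
            self.worker.start()
        while True:
            batch = self.queue.get()
            if batch is self._sentinel:
                return
            yield batch
```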

@casper-hansen casper-hansen merged commit 05bd6f1 into axolotl-ai-cloud:main Oct 26, 2023
4 checks passed
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request on Dec 15, 2023: Threaded MultipackDistributedDataloader with prefetched samples (axolotl-ai-cloud#759)

* Multithreading implementation [WIP]

* Added benchmarking

* 35% increased throughput

* Memory pinning

* Start threads in init

* Correct print of samples

* Sleep if queue is full

* Remove pin_memory (worse)

* Simplify logic to one thread

* Remove benchmark

* Use deque for constant speed

* Formatting

* Formatting

* Formatting

* Formatting

* Rollback to use queue

* Fix multi-epoch training

* Add num epochs arg

* Start thread in __iter__

* Formatting

* Use is_alive correctly

* Simplify loading thread