Skip to content

Commit

Permalink
Threaded MultipackDistributedDataloader with prefetched samples (#759)
Browse files Browse the repository at this point in the history
* Multithreading implementation [WIP]

* Added benchmarking

* 35% increased throughput

* Memory pinning

* Start threads in init

* Correct print of samples

* Sleep if queue is full

* Remove pin_memory (worse)

* Simplify logic to one thread

* Remove benchmark

* Use deque for constant speed

* Formatting

* Formatting

* Formatting

* Formatting

* Rollback to use queue

* Fix multi-epoch training

* Add num epochs arg

* Start thread in __iter__

* Formatting

* Use is_alive correctly

* Simplify loading thread
  • Loading branch information
casper-hansen committed Oct 26, 2023
1 parent 20aa4b5 commit 05bd6f1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
6 changes: 5 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):

args = None # type: AxolotlTrainingArguments

def __init__(self, *args, bench_data_collator=None, **kwargs):
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -182,6 +183,7 @@ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataload
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
)
)
return super().get_train_dataloader()
Expand All @@ -205,6 +207,7 @@ def get_eval_dataloader(
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
)
)
return super().get_eval_dataloader(eval_dataset)
Expand Down Expand Up @@ -680,6 +683,7 @@ def build(self, total_num_steps):
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
Expand Down
50 changes: 45 additions & 5 deletions src/axolotl/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import itertools
import logging
import math
import time
from queue import Queue
from threading import Thread
from typing import Any, Callable, List, Union

import numba
Expand Down Expand Up @@ -149,6 +152,8 @@ def __init__(
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
prefetch_max: int = 1000,
num_epochs: int = 1,
):
# Dataset
self.dataset = dataset
Expand All @@ -167,6 +172,7 @@ def __init__(
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_epochs = num_epochs

self.num_replicas = 1
self.rank = 0
Expand All @@ -177,6 +183,44 @@ def __init__(
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count

# maxsize is maximum number of samples in queue
self.prefetch_max = prefetch_max
self.queue: Queue = Queue(maxsize=prefetch_max)
self.thread = None

def _worker(self):
LOG.info(
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
)
for epoch in range(self.num_epochs):
for sample in self._internal_batch_generator():
while True:
if self.queue.full():
time.sleep(1)
else:
break
self.queue.put(sample)

# stop the queue when epoch is done
self.queue.put(None)

def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")

if self.thread is None:
self.thread = Thread(target=self._worker, daemon=True)
self.thread.start()

while True:
item = self.queue.get()

if item is None:
break
yield item

def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
Expand Down Expand Up @@ -206,11 +250,7 @@ def generate_batches(self, set_stats=False):

return batches, totseqs

def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
def _internal_batch_generator(self):
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=cfg.num_epochs,
)
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
Expand Down

0 comments on commit 05bd6f1

Please sign in to comment.