From 099eed02cfe1c180041fdb098e94dbe66ffa8c67 Mon Sep 17 00:00:00 2001
From: Casper
Date: Sun, 22 Oct 2023 00:01:45 +0200
Subject: [PATCH 01/22] Multithreading implementation [WIP]

---
 src/axolotl/utils/dataloader.py | 48 ++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index d659c3d33..8deffd706 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -8,6 +8,8 @@
 import numba
 import numpy as np
 from torch.utils.data import DistributedSampler, Sampler
+from queue import Queue
+from threading import Thread, Lock
 
 LOG = logging.getLogger("axolotl.utils.dataloader")
 
@@ -149,6 +151,8 @@ def __init__(
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
+        num_threads: int = 4,
+        prefetch_max: int = 10,
     ):
         # Dataset
         self.dataset = dataset
@@ -177,6 +181,48 @@ def __init__(
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
+        self.lock = Lock()
+        self.queue = Queue(maxsize=prefetch_max)
+        self.batches_indexed = set()
+        self.done_count = 0
+        self.num_threads = num_threads
+        self.threads = []
+        for i in range(num_threads):
+            thread = Thread(target=self._worker, args=(i,))
+            thread.daemon = True
+            thread.start()
+            self.threads.append(thread)
+
+    def _worker(self, worker_id):
+        LOG.warning(f"[WORKER:{worker_id}] WORKER RUNNING!!!")
+        for index, batch in enumerate(self._internal_batch_generator()):
+            with self.lock:
+                # O(1) complexity
+                if index not in self.batches_indexed:
+                    self.batches_indexed.add(index)
+                    should_add = True
+                else:
+                    should_add = False
+
+            if should_add:
+                self.queue.put(batch)
+
+        # stop the queue when all workers are done
+        with self.lock:
+            self.done_count += 1
+            if self.done_count == len(self.threads):
+                self.queue.put(None)
+
+    def __iter__(self):
+        LOG.warning("STARTING WORKERS!!!")
+
+        while True:
+            next_batch = self.queue.get()
+            if next_batch is None:
+                LOG.warning("DATALOADER STOPPED!!!")
+                break
+            yield next_batch
+
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
         if self.sampler:
@@ -206,7 +252,7 @@
 
         return batches, totseqs
 
-    def __iter__(self):
+    def _internal_batch_generator(self):
         if hasattr(self.sampler, "set_epoch"):
             new_epoch = self.sampler.epoch + 1
             self.sampler.set_epoch(new_epoch)
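Note (patch 01): the structure introduced here is the classic producer-consumer prefetcher — worker threads run the packing generator and push finished batches into a bounded queue, while __iter__ drains the queue until it sees the None sentinel. A minimal, self-contained sketch of the same pattern, independent of axolotl (the make_batches generator is a hypothetical stand-in for _internal_batch_generator):

    import threading
    from queue import Queue

    def prefetch(make_batches, maxsize=10):
        # Run a batch generator on a background thread; yield from a bounded queue.
        q = Queue(maxsize=maxsize)  # bounded, so the producer cannot run ahead unchecked
        sentinel = None             # the patch uses None as its end-of-stream marker

        def producer():
            for batch in make_batches():
                q.put(batch)        # blocks whenever the queue is full
            q.put(sentinel)         # tell the consumer to stop

        threading.Thread(target=producer, daemon=True).start()
        while True:
            batch = q.get()
            if batch is sentinel:
                return
            yield batch

    # usage: consume batches while the next ones are being packed
    for batch in prefetch(lambda: iter(range(5))):
        print(batch)

One caveat with the multi-thread variant in the patch: every worker runs the full generator, and the lock-guarded batches_indexed set only deduplicates what gets queued — the packing work itself is still repeated num_threads times.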
"openaccess-ai-collective/tiny-mistral", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": "./out", + "eval_steps": 10, + } +) + +normalize_config(cfg) +cli_args = TrainerCliArgs() +dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + +sampler = RandomSampler(dataset_meta.train_dataset) +dataloader = MultipackDistributedDataloader( + dataset=dataset_meta.train_dataset, + collate_fn=default_data_collator, + seq_max_length=cfg["sequence_len"], + batch_size=1, + sampler=None, + packing_efficiency_estimate=1.0, + sample_packing_seq_len_multiplier=1, + device_count=1, +) + +# Let workers warmup +time.sleep(2) + +# Measure throughput +timing = [] +num_iterations = dataloader.len_w_stats() +iter_dataset = iter(dataloader) + +for i in tqdm(range(num_iterations)): + t_start = time.time() + batch = next(iter_dataset) + inputs_ids = batch["input_ids"] + for _ in range(1000): torch.matmul(inputs_ids, inputs_ids.mT) + timing.append(time.time() - t_start) + +# Calculate throughput +throughput = 1 / np.median(timing) + +print(f"Throughput: {throughput:.2f} batches/sec") diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 8deffd706..23582b4ca 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -151,7 +151,7 @@ def __init__( packing_efficiency_estimate: float = 1.0, sample_packing_seq_len_multiplier: int = 1, device_count: int = 1, - num_threads: int = 4, + num_threads: int = 1, prefetch_max: int = 10, ): # Dataset @@ -186,12 +186,6 @@ def __init__( self.batches_indexed = set() self.done_count = 0 self.num_threads = num_threads - self.threads = [] - for i in range(num_threads): - thread = Thread(target=self._worker, args=(i,)) - thread.daemon = True - thread.start() - self.threads.append(thread) def _worker(self, worker_id): LOG.warning(f"[WORKER:{worker_id}] WORKER RUNNING!!!") @@ -215,6 +209,17 @@ def _worker(self, worker_id): def __iter__(self): LOG.warning("STARTING WORKERS!!!") + if hasattr(self.sampler, "set_epoch"): + new_epoch = self.sampler.epoch + 1 + self.sampler.set_epoch(new_epoch) + LOG.info(f"calling sampler.set_epoch({new_epoch})") + + self.threads = [] + for i in range(self.num_threads): + thread = Thread(target=self._worker, args=(i,)) + thread.daemon = True + thread.start() + self.threads.append(thread) while True: next_batch = self.queue.get() @@ -253,10 +258,6 @@ def generate_batches(self, set_stats=False): return batches, totseqs def _internal_batch_generator(self): - if hasattr(self.sampler, "set_epoch"): - new_epoch = self.sampler.epoch + 1 - self.sampler.set_epoch(new_epoch) - LOG.info(f"calling sampler.set_epoch({new_epoch})") all_batches, _ = self.generate_batches(set_stats=True) features = self.dataset.features.keys() len_remaining = self._len_est() From 8cb60b64f8d0d88a02333c8534fd12ec1dffd840 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Sun, 22 Oct 2023 17:29:31 +0000 Subject: [PATCH 03/22] 35% increased throughput --- src/axolotl/utils/dataloader.py | 41 +++++++++++++++------------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 23582b4ca..a341f5e9c 100644 --- a/src/axolotl/utils/dataloader.py 
From 2454b6ac1ab92ee7a7d894fc29f805aad1b0d2c6 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Sun, 22 Oct 2023 18:15:32 +0000
Subject: [PATCH 04/22] Memory pinning

---
 src/axolotl/utils/dataloader.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index a341f5e9c..1653321e5 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -4,7 +4,7 @@
 import logging
 import math
 from typing import Any, Callable, List, Union
-
+import torch
 import numba
 import numpy as np
 from torch.utils.data import DistributedSampler, Sampler
@@ -153,6 +153,7 @@ def __init__(
         device_count: int = 1,
         num_threads: int = 4,
         prefetch_max: int = 1000,
+        pin_memory: bool = True,
     ):
         # Dataset
         self.dataset = dataset
@@ -185,6 +186,7 @@ def __init__(
         self.batches_indexed = set()
         self.done_count = 0
         self.num_threads = num_threads
+        self.pin_memory = pin_memory
 
         # thread 0 gets batch 0, thread 1 gets batch 1
         # thread 0 gets batch 2, thread 1 gets batch 3
@@ -194,9 +196,11 @@ def __init__(
 
     def _worker(self, worker_id):
         worker_indices = self.worker_indices[worker_id]
         LOG.warning(f"[WORKER:{worker_id}] RUNNING - {len(worker_indices)} batches")
-        for index, batch in enumerate(self._internal_batch_generator()):
+        for index, sample in enumerate(self._internal_batch_generator()):
             if index in worker_indices:
-                self.queue.put(batch)
+                if self.pin_memory:
+                    sample = {k: torch.as_tensor(v).pin_memory() for k, v in sample.items()}
+                self.queue.put(sample)
 
         # stop the queue when all workers are done
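Note (patch 04): pin_memory() copies a tensor into page-locked host memory, which is the prerequisite for asynchronous host-to-device transfers. The pattern generally pays off only when paired with a non_blocking=True copy onto the GPU; a hedged sketch of that pairing (assumes a CUDA device is available — not what the patch itself does):

    import torch

    def pin_and_transfer(batch, device):
        # pin_memory() page-locks the host copy; the .to() can then be issued
        # asynchronously and overlap with compute on the default stream.
        return {
            k: torch.as_tensor(v).pin_memory().to(device, non_blocking=True)
            for k, v in batch.items()
        }

    # batch = pin_and_transfer(batch, torch.device("cuda:0"))

Since the worker thread here only pins and never issues the device copy itself, the extra host-side copy is pure overhead unless the downstream consumer exploits it — consistent with patch 08 below, which removes pinning again after measuring.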
From 55fc5abb67194580c4b0d3600a8e5a129fc7664c Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Sun, 22 Oct 2023 19:19:19 +0000
Subject: [PATCH 05/22] Start threads in init

---
 src/axolotl/utils/dataloader.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 1653321e5..dc6e7f5b4 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -151,7 +151,7 @@ def __init__(
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
-        num_threads: int = 4,
+        num_threads: int = 1,
         prefetch_max: int = 1000,
         pin_memory: bool = True,
     ):
@@ -192,6 +192,13 @@ def __init__(
        # thread 0 gets batch 2, thread 1 gets batch 3
         # etc ...
         self.worker_indices = [set(range(i, self.len_w_stats(), self.num_threads)) for i in range(self.num_threads)]
+
+        self.threads = []
+        for i in range(self.num_threads):
+            thread = Thread(target=self._worker, args=(i,))
+            thread.daemon = True
+            thread.start()
+            self.threads.append(thread)
 
     def _worker(self, worker_id):
         worker_indices = self.worker_indices[worker_id]
@@ -208,18 +215,12 @@ def _worker(self, worker_id):
 
     def __iter__(self):
-        LOG.warning("STARTING WORKERS!!!")
         if hasattr(self.sampler, "set_epoch"):
             new_epoch = self.sampler.epoch + 1
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
 
-        self.threads = []
-        for i in range(self.num_threads):
-            thread = Thread(target=self._worker, args=(i,))
-            thread.daemon = True
-            thread.start()
-            self.threads.append(thread)
+        self.done_count = 0
 
         while True:
             item = self.queue.get()

From b9f2e124f34788979324fe2126c4af7aef7b37cf Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Sun, 22 Oct 2023 20:02:11 +0000
Subject: [PATCH 06/22] Correct print of samples

---
 src/axolotl/utils/dataloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index dc6e7f5b4..d2e5821dc 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -202,7 +202,7 @@ def __init__(
 
     def _worker(self, worker_id):
         worker_indices = self.worker_indices[worker_id]
-        LOG.warning(f"[WORKER:{worker_id}] RUNNING - {len(worker_indices)} batches")
+        LOG.warning(f"[WORKER:{worker_id}] RUNNING - {len(worker_indices)*self.batch_size} samples")
         for index, sample in enumerate(self._internal_batch_generator()):
             if index in worker_indices:

From 5d9959a966618176880bf00a573681f5e2740f40 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Sun, 22 Oct 2023 20:23:18 +0000
Subject: [PATCH 07/22] Sleep if queue is full

---
 src/axolotl/utils/dataloader.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index d2e5821dc..a76453f7f 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -10,7 +10,7 @@
 from torch.utils.data import DistributedSampler, Sampler
 from queue import Queue
 from threading import Thread
-
+import time
 LOG = logging.getLogger("axolotl.utils.dataloader")
 
@@ -182,7 +182,8 @@ def __init__(
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
-        self.queue = Queue(maxsize=prefetch_max)
+        # maxsize is maximum number of samples in queue
+        self.queue = Queue(maxsize=prefetch_max)
         self.batches_indexed = set()
         self.done_count = 0
         self.num_threads = num_threads
@@ -207,6 +208,12 @@ def _worker(self, worker_id):
         for index, sample in enumerate(self._internal_batch_generator()):
             if index in worker_indices:
                 if self.pin_memory:
                     sample = {k: torch.as_tensor(v).pin_memory() for k, v in sample.items()}
+
+                while True:
+                    if self.queue.full():
+                        time.sleep(1)
+                    else:
+                        break
                 self.queue.put(sample)
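Note (patch 07): queue.Queue.put already blocks while the queue is at maxsize, so the poll-and-sleep loop is effectively a coarser version of the built-in behaviour (it can wake up to a second late). A small sketch of the blocking semantics:

    import threading
    import time
    from queue import Queue

    q = Queue(maxsize=1)

    def producer():
        for i in range(3):
            q.put(i)  # parks here while the queue is full -- no sleep loop needed
            print(f"produced {i}")

    threading.Thread(target=producer, daemon=True).start()
    time.sleep(0.5)        # producer has queued 0 and is now blocked inside put()
    for _ in range(3):
        print(f"consumed {q.get()}")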
From 6c32064a5f42531ac6abe2ba7bad89820129eca1 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Sun, 22 Oct 2023 21:17:49 +0000
Subject: [PATCH 08/22] Remove pin_memory (worse)

---
 src/axolotl/utils/dataloader.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index a76453f7f..cf423241c 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -153,7 +153,6 @@ def __init__(
         device_count: int = 1,
         num_threads: int = 1,
         prefetch_max: int = 1000,
-        pin_memory: bool = True,
     ):
         # Dataset
         self.dataset = dataset
@@ -187,7 +186,6 @@ def __init__(
         self.batches_indexed = set()
         self.done_count = 0
         self.num_threads = num_threads
-        self.pin_memory = pin_memory
 
         # thread 0 gets batch 0, thread 1 gets batch 1
         # thread 0 gets batch 2, thread 1 gets batch 3
@@ -206,9 +204,6 @@ def __init__(
         LOG.warning(f"[WORKER:{worker_id}] RUNNING - {len(worker_indices)*self.batch_size} samples")
         for index, sample in enumerate(self._internal_batch_generator()):
             if index in worker_indices:
-                if self.pin_memory:
-                    sample = {k: torch.as_tensor(v).pin_memory() for k, v in sample.items()}
-
                 while True:
                     if self.queue.full():
                         time.sleep(1)
                     else:
                         break

From 121f9968c817f723c8549d6638e5db26be320c8d Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Mon, 23 Oct 2023 12:34:05 +0000
Subject: [PATCH 09/22] Simplify logic to one thread

---
 src/axolotl/utils/dataloader.py | 53 ++++++++++-----------------------
 1 file changed, 16 insertions(+), 37 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index cf423241c..79c13ea6d 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -4,7 +4,7 @@
 import logging
 import math
 from typing import Any, Callable, List, Union
-import torch
+
 import numba
 import numpy as np
 from torch.utils.data import DistributedSampler, Sampler
@@ -151,7 +151,6 @@ def __init__(
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
-        num_threads: int = 1,
         prefetch_max: int = 1000,
     ):
         # Dataset
@@ -182,47 +181,27 @@ def __init__(
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
         # maxsize is maximum number of samples in queue
-        self.queue = Queue(maxsize=prefetch_max)
-        self.batches_indexed = set()
-        self.done_count = 0
-        self.num_threads = num_threads
-
-        # thread 0 gets batch 0, thread 1 gets batch 1
-        # thread 0 gets batch 2, thread 1 gets batch 3
-        # etc ...
-        self.worker_indices = [set(range(i, self.len_w_stats(), self.num_threads)) for i in range(self.num_threads)]
-
-        self.threads = []
-        for i in range(self.num_threads):
-            thread = Thread(target=self._worker, args=(i,))
-            thread.daemon = True
-            thread.start()
-            self.threads.append(thread)
+        self.queue = Queue(maxsize=prefetch_max)
+        self.thread = Thread(target=self._worker, daemon=True)
+        self.thread.start()
 
-    def _worker(self, worker_id):
-        worker_indices = self.worker_indices[worker_id]
-        LOG.warning(f"[WORKER:{worker_id}] RUNNING - {len(worker_indices)*self.batch_size} samples")
-        for index, sample in enumerate(self._internal_batch_generator()):
-            if index in worker_indices:
-                while True:
-                    if self.queue.full():
-                        time.sleep(1)
-                    else:
-                        break
-                self.queue.put(sample)
-
-        # stop the queue when all workers are done
-        self.done_count += 1
-        if self.done_count == len(self.threads):
-            self.queue.put(None)
+    def _worker(self):
+        for sample in self._internal_batch_generator():
+            while True:
+                if self.queue.full():
+                    time.sleep(1)
+                else:
+                    break
+            self.queue.put(sample)
+
+        # stop the queue when worker is done
+        self.queue.put(None)
 
     def __iter__(self):
         if hasattr(self.sampler, "set_epoch"):
             new_epoch = self.sampler.epoch + 1
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
-
-        self.done_count = 0
 
         while True:
             item = self.queue.get()
@@ -230,7 +209,7 @@ def __iter__(self):
                 break
             yield item
 
-        LOG.warning("DATALOADER FINISHED!!!")
+        LOG.info("DATALOADER FINISHED")
From d6890bd203702ec28cd58cf3acd158358d51024b Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Mon, 23 Oct 2023 12:48:48 +0000
Subject: [PATCH 10/22] Remove benchmark

---
 benchmark/dataloader.py | 74 -----------------------------------------
 1 file changed, 74 deletions(-)
 delete mode 100644 benchmark/dataloader.py

diff --git a/benchmark/dataloader.py b/benchmark/dataloader.py
deleted file mode 100644
index da9e3abc9..000000000
--- a/benchmark/dataloader.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import time
-import torch
-import numpy as np
-from tqdm import tqdm
-from axolotl.cli import load_datasets
-from torch.utils.data import RandomSampler
-from axolotl.utils.dict import DictDefault
-from axolotl.common.cli import TrainerCliArgs
-from axolotl.utils.config import normalize_config
-from transformers.data import default_data_collator
-from axolotl.utils.dataloader import MultipackDistributedDataloader
-
-cfg = DictDefault(
-    {
-        "base_model": "openaccess-ai-collective/tiny-mistral",
-        "base_model_config": "openaccess-ai-collective/tiny-mistral",
-        "flash_attention": True,
-        "sample_packing": True,
-        "sequence_len": 1024,
-        "val_set_size": 0.1,
-        "special_tokens": {
-            "unk_token": "<unk>",
-            "bos_token": "<s>",
-            "eos_token": "</s>",
-        },
-        "datasets": [
-            {
-                "path": "mhenrichsen/alpaca_2k_test",
-                "type": "alpaca",
-            },
-        ],
-        "num_epochs": 2,
-        "micro_batch_size": 1,
-        "gradient_accumulation_steps": 1,
-        "output_dir": "./out",
-        "eval_steps": 10,
-    }
-)
-
-normalize_config(cfg)
-cli_args = TrainerCliArgs()
-dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-sampler = RandomSampler(dataset_meta.train_dataset)
-dataloader = MultipackDistributedDataloader(
-    dataset=dataset_meta.train_dataset,
-    collate_fn=default_data_collator,
-    seq_max_length=cfg["sequence_len"],
-    batch_size=1,
-    sampler=sampler,
-    packing_efficiency_estimate=1.0,
-    sample_packing_seq_len_multiplier=1,
-    device_count=1,
-)
-
-# Let workers warm up
-time.sleep(2)
-
-# Measure throughput
-timing = []
-num_iterations = dataloader.len_w_stats()
-iter_dataset = iter(dataloader)
-
-for i in tqdm(range(num_iterations)):
-    t_start = time.time()
-    batch = next(iter_dataset)
-    input_ids = batch["input_ids"]
-    for _ in range(1000): torch.matmul(input_ids, input_ids.mT)
-    timing.append(time.time() - t_start)
-
-# Calculate throughput
-throughput = 1 / np.median(timing)
-
-print(f"Throughput: {throughput:.2f} batches/sec")
From ac65ba9c85d574567c52a059961d08c1b01465d9 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Mon, 23 Oct 2023 20:45:40 +0000
Subject: [PATCH 11/22] Use deque for constant speed

---
 src/axolotl/utils/dataloader.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 79c13ea6d..0bb7af81d 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -8,7 +8,7 @@
 import numpy as np
 from torch.utils.data import DistributedSampler, Sampler
-from queue import Queue
+from collections import deque
 from threading import Thread
 import time
 LOG = logging.getLogger("axolotl.utils.dataloader")
@@ -181,21 +181,22 @@ def __init__(
         self.device_count = device_count
 
         # maxsize is maximum number of samples in queue
-        self.queue = Queue(maxsize=prefetch_max)
+        self.prefetch_max = prefetch_max
+        self.queue = deque(maxlen=prefetch_max)
         self.thread = Thread(target=self._worker, daemon=True)
         self.thread.start()
 
     def _worker(self):
         for sample in self._internal_batch_generator():
             while True:
-                if self.queue.full():
+                if len(self.queue) == self.prefetch_max:
                     time.sleep(1)
                 else:
                     break
-            self.queue.put(sample)
+            self.queue.append(sample)
 
         # stop the queue when worker is done
-        self.queue.put(None)
+        self.queue.append(None)
 
     def __iter__(self):
         if hasattr(self.sampler, "set_epoch"):
@@ -205,7 +206,7 @@ def __iter__(self):
 
         while True:
-            item = self.queue.get()
+            item = self.queue.popleft()
             if item is None:
                 break
             yield item
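Note (patch 11): collections.deque gives O(1) appends and pops and is a reasonable fit for a single producer, but unlike queue.Queue it has no blocking semantics — popleft() on an empty deque raises immediately instead of waiting for the producer. That difference is what patch 16 below rolls back for:

    from collections import deque
    from queue import Queue

    d = deque(maxlen=10)
    try:
        d.popleft()        # consumer overtakes producer -> hard failure
    except IndexError:
        print("empty deque raises; the consumer would need its own retry loop")

    q = Queue(maxsize=10)
    # q.get() would instead block here until the producer catches up.

Note also that deque(maxlen=...) silently discards the oldest element when a new one is appended to a full deque; the len(self.queue) == self.prefetch_max check above sleeps before appending, so no batches are lost in this version.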
From f3f1ba0ca0a479b0ce7db43bc8495ad84a6ad2a9 Mon Sep 17 00:00:00 2001
From: Casper
Date: Mon, 23 Oct 2023 22:56:41 +0200
Subject: [PATCH 12/22] Formatting

---
 src/axolotl/utils/dataloader.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 0bb7af81d..e1b34850a 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -8,9 +8,10 @@
 import numba
 import numpy as np
 from torch.utils.data import DistributedSampler, Sampler
+import time
 from collections import deque
 from threading import Thread
-import time
+
 LOG = logging.getLogger("axolotl.utils.dataloader")

From 2d6bd64ba22333abc776e92a76d1ae6b5d2c7419 Mon Sep 17 00:00:00 2001
From: Casper
Date: Mon, 23 Oct 2023 22:58:12 +0200
Subject: [PATCH 13/22] Formatting

---
 src/axolotl/utils/dataloader.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index e1b34850a..30c6d1097 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -186,7 +186,7 @@ def __init__(
         self.queue = deque(maxlen=prefetch_max)
         self.thread = Thread(target=self._worker, daemon=True)
         self.thread.start()
-    
+
     def _worker(self):
         for sample in self._internal_batch_generator():
             while True:
@@ -198,7 +198,7 @@ def _worker(self):
 
         # stop the queue when worker is done
         self.queue.append(None)
-    
+
     def __iter__(self):
         if hasattr(self.sampler, "set_epoch"):
             new_epoch = self.sampler.epoch + 1
@@ -210,7 +210,7 @@ def __iter__(self):
             if item is None:
                 break
             yield item
-    
+
         LOG.info("DATALOADER FINISHED")

From 4ff2903655bebf6d475b59e912631bedaa4ad72d Mon Sep 17 00:00:00 2001
From: Casper
Date: Mon, 23 Oct 2023 23:00:32 +0200
Subject: [PATCH 14/22] Formatting

---
 src/axolotl/utils/dataloader.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 30c6d1097..f33d55374 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -3,14 +3,14 @@
 import itertools
 import logging
 import math
+import time
+from collections import deque
+from threading import Thread
 from typing import Any, Callable, List, Union
 
 import numba
 import numpy as np
 from torch.utils.data import DistributedSampler, Sampler
-import time
-from collections import deque
-from threading import Thread
 
 LOG = logging.getLogger("axolotl.utils.dataloader")

From e026a54853d26e0bf14e51160ff9bf91e9e04c10 Mon Sep 17 00:00:00 2001
From: Casper
Date: Mon, 23 Oct 2023 23:03:58 +0200
Subject: [PATCH 15/22] Formatting

---
 src/axolotl/utils/dataloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index f33d55374..3f22a7d69 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -183,7 +183,7 @@ def __init__(
 
         # maxsize is maximum number of samples in queue
         self.prefetch_max = prefetch_max
-        self.queue = deque(maxlen=prefetch_max)
+        self.queue: deque = deque(maxlen=prefetch_max)
         self.thread = Thread(target=self._worker, daemon=True)
         self.thread.start()

From 2bb3a8e775be5e5da4717d6186abfddf098c4879 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Mon, 23 Oct 2023 21:23:55 +0000
Subject: [PATCH 16/22] Rollback to use queue

---
 src/axolotl/utils/dataloader.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 3f22a7d69..4cc5cc696 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -4,7 +4,7 @@
 import logging
 import math
 import time
-from collections import deque
+from queue import Queue
 from threading import Thread
 from typing import Any, Callable, List, Union
 
@@ -183,21 +183,21 @@ def __init__(
 
         # maxsize is maximum number of samples in queue
         self.prefetch_max = prefetch_max
-        self.queue: deque = deque(maxlen=prefetch_max)
+        self.queue: Queue = Queue(maxsize=prefetch_max)
         self.thread = Thread(target=self._worker, daemon=True)
         self.thread.start()
 
     def _worker(self):
         for sample in self._internal_batch_generator():
             while True:
-                if len(self.queue) == self.prefetch_max:
+                if self.queue.full():
                     time.sleep(1)
                 else:
                     break
-            self.queue.append(sample)
+            self.queue.put(sample)
 
         # stop the queue when worker is done
-        self.queue.append(None)
+        self.queue.put(None)
 
     def __iter__(self):
         if hasattr(self.sampler, "set_epoch"):
@@ -206,7 +206,8 @@ def __iter__(self):
 
         while True:
-            item = self.queue.popleft()
+            item = self.queue.get()
+
             if item is None:
                 break
             yield item

From d5d91dbc78da62ab42316c79b4161584acd64d39 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Mon, 23 Oct 2023 21:52:15 +0000
Subject: [PATCH 17/22] Fix multi-epoch training

---
 src/axolotl/utils/dataloader.py | 24 +++++++++++++-----------
 src/axolotl/utils/trainer.py    |  7 ++++++-
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 4cc5cc696..8722030d5 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -153,6 +153,7 @@ def __init__(
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
         prefetch_max: int = 1000,
+        num_epochs: int = 1,
     ):
         # Dataset
         self.dataset = dataset
@@ -171,6 +172,7 @@ def __init__(
         self.seq_max_length = seq_max_length
         self.batch_max_length = batch_size * seq_max_length
         self.collate_fn = collate_fn
+        self.num_epochs = num_epochs
 
         self.num_replicas = 1
         self.rank = 0
@@ -188,16 +190,18 @@ def __init__(
         self.thread.start()
 
     def _worker(self):
-        for sample in self._internal_batch_generator():
-            while True:
-                if self.queue.full():
-                    time.sleep(1)
-                else:
-                    break
-            self.queue.put(sample)
+        LOG.info(f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}")
+        for epoch in range(self.num_epochs):
+            for sample in self._internal_batch_generator():
+                while True:
+                    if self.queue.full():
+                        time.sleep(1)
+                    else:
+                        break
+                self.queue.put(sample)
 
-        # stop the queue when worker is done
-        self.queue.put(None)
+            # stop the queue when epoch is done
+            self.queue.put(None)
 
     def __iter__(self):
         if hasattr(self.sampler, "set_epoch"):
@@ -212,8 +216,6 @@ def __iter__(self):
                 break
             yield item
 
-        LOG.info("DATALOADER FINISHED")
-
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 820202b80..d6262a85b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -178,7 +178,8 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
-    def __init__(self, *args, bench_data_collator=None, **kwargs):
+    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         super().__init__(*args, **kwargs)
 
@@ -249,6 +250,7 @@ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataload
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=self.num_epochs,
             )
         )
         return super().get_train_dataloader()
@@ -272,6 +274,7 @@ def get_eval_dataloader(
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=self.num_epochs,
             )
         )
         return super().get_eval_dataloader(eval_dataset)
@@ -501,6 +504,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
             sample_packing_seq_len_multiplier=cfg.micro_batch_size,
             device_count=int(os.environ.get("WORLD_SIZE", 1)),
+            num_epochs=cfg.num_epochs,
         )
         data_loader_len = data_loader.len_w_stats()
         actual_eff = data_loader.efficiency()
@@ -771,6 +775,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             **data_collator_kwargs,
         ),
         callbacks=callbacks,
+        num_epochs=cfg.num_epochs,
         **trainer_kwargs,
     )
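Note (patch 17): the worker now produces every epoch up front and enqueues one None sentinel per epoch, so each call to __iter__ drains exactly one epoch's worth of batches and stops at its sentinel; the trainer drives __iter__ once per epoch. The sentinel arithmetic in miniature:

    from queue import Queue

    NUM_EPOCHS = 2
    q = Queue()

    # producer: batches for every epoch, one sentinel after each epoch
    for _ in range(NUM_EPOCHS):
        for batch in ("a", "b"):
            q.put(batch)
        q.put(None)

    # consumer: each pass stops at the next sentinel
    for epoch in range(NUM_EPOCHS):
        print(f"epoch {epoch}:", list(iter(q.get, None)))
    # epoch 0: ['a', 'b']
    # epoch 1: ['a', 'b']

iter(q.get, None) is the two-argument form of iter: it keeps calling q.get() until the sentinel value comes back.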
From 4ee158481146978aaff9797d573b8f5973e6ee70 Mon Sep 17 00:00:00 2001
From: Casper
Date: Wed, 25 Oct 2023 10:28:45 +0200
Subject: [PATCH 18/22] Add num epochs arg

---
 src/axolotl/core/trainer_builder.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 00a1a0c67..55a1764fc 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
-    def __init__(self, *args, bench_data_collator=None, **kwargs):
+    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         super().__init__(*args, **kwargs)
 
@@ -182,6 +183,7 @@ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataload
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=self.num_epochs,
             )
         )
         return super().get_train_dataloader()
@@ -205,6 +207,7 @@ def get_eval_dataloader(
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=self.num_epochs,
             )
         )
         return super().get_eval_dataloader(eval_dataset)
@@ -680,6 +683,7 @@ def build(self, total_num_steps):
             **data_collator_kwargs,
         ),
         callbacks=self.get_callbacks(),
+        num_epochs=self.cfg.num_epochs,
         **trainer_kwargs,
     )
     trainer = self.hook_post_create_trainer(trainer)
From c564819b9d0edb90ec95a9d44407fe0216a217b6 Mon Sep 17 00:00:00 2001
From: Casper
Date: Wed, 25 Oct 2023 10:32:45 +0200
Subject: [PATCH 19/22] Start thread in __iter__

---
 src/axolotl/utils/dataloader.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 8722030d5..f747dcac0 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -187,10 +187,11 @@ def __init__(
         self.prefetch_max = prefetch_max
         self.queue: Queue = Queue(maxsize=prefetch_max)
         self.thread = Thread(target=self._worker, daemon=True)
-        self.thread.start()
 
     def _worker(self):
-        LOG.info(f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}")
+        LOG.info(
+            f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
+        )
         for epoch in range(self.num_epochs):
             for sample in self._internal_batch_generator():
                 while True:
@@ -208,6 +209,9 @@ def __iter__(self):
             new_epoch = self.sampler.epoch + 1
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
+
+        if not self.thread.is_alive:
+            self.thread.start()
 
         while True:
             item = self.queue.get()

From 62bd93b33533552588f2e08307d5258e95cad528 Mon Sep 17 00:00:00 2001
From: Casper
Date: Wed, 25 Oct 2023 10:34:39 +0200
Subject: [PATCH 20/22] Formatting

---
 src/axolotl/utils/dataloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index f747dcac0..2c3371e46 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -209,7 +209,7 @@ def __iter__(self):
             new_epoch = self.sampler.epoch + 1
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
-    
+
         if not self.thread.is_alive:
             self.thread.start()

From 78823d0abf410e724d7f066ee403e929a2a218b2 Mon Sep 17 00:00:00 2001
From: Casper
Date: Wed, 25 Oct 2023 10:39:50 +0200
Subject: [PATCH 21/22] Use is_alive correctly

---
 src/axolotl/utils/dataloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 2c3371e46..88e36eca8 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -210,7 +210,7 @@ def __iter__(self):
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
 
-        if not self.thread.is_alive:
+        if not self.thread.is_alive():
             self.thread.start()
 
         while True:
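Note (patch 21): the bug being fixed is easy to reintroduce. Without parentheses, thread.is_alive is a bound method object, and any bound method is truthy, so the `if not self.thread.is_alive:` from patch 19 could never fire and the worker thread was never started:

    import threading

    t = threading.Thread(target=lambda: None)
    print(bool(t.is_alive))  # True  -- a bound method is always truthy
    print(t.is_alive())      # False -- the actual liveness check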
From e1bd784dbb0811b16bfce3af0306475fe363fcf7 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Wed, 25 Oct 2023 09:08:32 +0000
Subject: [PATCH 22/22] Simplify loading thread

---
 src/axolotl/utils/dataloader.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 88e36eca8..54c95db78 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -186,7 +186,7 @@ def __init__(
         # maxsize is maximum number of samples in queue
         self.prefetch_max = prefetch_max
         self.queue: Queue = Queue(maxsize=prefetch_max)
-        self.thread = Thread(target=self._worker, daemon=True)
+        self.thread = None
 
     def _worker(self):
@@ -210,7 +210,8 @@ def __iter__(self):
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
 
-        if not self.thread.is_alive():
+        if self.thread is None:
+            self.thread = Thread(target=self._worker, daemon=True)
            self.thread.start()
 
         while True:
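Note (final state): after patch 22, construction is cheap (self.thread = None) and the producer thread is created and started lazily on the first __iter__; later iterations reuse it and simply continue draining the queue, one epoch per pass. A hedged usage sketch — train_dataset and sequence_len are assumed to come from the surrounding training setup, and all keyword arguments are the ones introduced across this series:

    from transformers.data import default_data_collator
    from axolotl.utils.dataloader import MultipackDistributedDataloader

    dataloader = MultipackDistributedDataloader(
        dataset=train_dataset,
        collate_fn=default_data_collator,
        seq_max_length=sequence_len,
        batch_size=1,
        sampler=None,
        packing_efficiency_estimate=1.0,
        sample_packing_seq_len_multiplier=1,
        device_count=1,
        num_epochs=2,
    )

    for epoch in range(2):
        for batch in dataloader:  # first pass starts the daemon producer thread
            ...                   # each pass ends at that epoch's None sentinel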