From 1c89a8bf51de4673931edb74835ccbb8046774f0 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Sat, 29 Jun 2024 15:22:03 +0000
Subject: [PATCH] disable padding in multipack

Signed-off-by: Yu Chin Fabian Lim
---
 src/instructlab/training/main_ds.py       | 2 +-
 src/instructlab/training/token_dataset.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index fe3b368c..f9b311a1 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -501,7 +501,7 @@ def main(args):
         avg_sample_len=dataset.get_lengths().mean(),
         effective_batch_size=args.effective_batch_size,
         max_batch_len_per_gpu=args.max_batch_len,
-        is_padding=not args.is_granite,
+        is_padding=not (args.is_granite or args.flatten_with_posid),
         dataset=dataset,
         pad_id=tokenizer.pad_token_id,
         seed=args.seed,
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index 57f8f6f6..9cbd21b7 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -97,6 +97,7 @@ def setup_dataloader(
     world_size = int(os.environ["WORLD_SIZE"])
     lengths = dataset.get_lengths()
 
+
     if sampler == "multipack":
         sampler = MultipackDistributedBatchSampler(
             batch_max_length=packing_max_batch_len,
@@ -104,7 +105,7 @@ def setup_dataloader(
             num_replicas=world_size,
             rank=rank,
             seed=seed,
-            padding=not is_granite,
+            padding=not (is_granite or flatten_with_posid),
         )
         sampler = {"batch_sampler": sampler}
     elif sampler == "distributed":