From c6bffd46e64c00d63fd0426690aa407bc228057d Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 19 Jun 2024 20:59:29 -0400
Subject: [PATCH] Multipack padding fix (#19)

For non-dolomite models, packing_max_batch_len was being calculated
incorrectly. This calculates the necessary increase to make sure that
average batch size is similar to the specified effective_batch_size.

---------

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/main_ds.py           |   4 +
 src/instructlab/training/multipack_sampler.py | 163 +++++++++++++++++-
 src/instructlab/training/token_dataset.py     |  91 +---------
 src/instructlab/training/utils.py             |  89 ++++++++++
 4 files changed, 253 insertions(+), 94 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index c40606b2..eeb0c077 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -461,6 +461,10 @@ def main(args):
         avg_sample_len=dataset.get_lengths().mean(),
         effective_batch_size=args.effective_batch_size,
         max_batch_len_per_gpu=args.max_batch_len,
+        is_padding=not args.is_granite,
+        dataset=dataset,
+        pad_id=tokenizer.pad_token_id,
+        seed=args.seed,
     )
     args.samples_per_gpu = (
         args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py
index 0dd7acbf..42fd7d17 100644
--- a/src/instructlab/training/multipack_sampler.py
+++ b/src/instructlab/training/multipack_sampler.py
@@ -25,16 +25,156 @@
 # Standard
 from typing import List, Optional
+import os
 
 # Third Party
-from torch.utils.data import Sampler
+from torch.utils.data import DataLoader, Sampler
 import numba
 import numpy as np
 import torch.distributed as dist
 
+# First Party
+from instructlab.training.utils import make_collate_fn
+
+
+def guess_starting_avg_padding(base_avg, goal, num_gpus, grad_accum, sorted_lengths):
+    """
+    Return a starting middle point for the binary search
+    (to find optimal addition to packing_max_batch_len
+    to account for padding)
+
+    Uses the largest initial bucket to approximate an
+    upper-bound for average padding, should overshoot.
+    """
+    addition = 0
+    packing_max_batch_len = int(
+        (base_avg + addition) * ((goal / num_gpus) / grad_accum)
+    )
+
+    bucket_zero = []
+    max = sorted_lengths[0]
+    sum = 0
+    for length in sorted_lengths:
+        if sum + max <= packing_max_batch_len:
+            sum += max
+            bucket_zero.append(length)
+        else:
+            break
+
+    total_pad = 0
+    for length in bucket_zero:
+        total_pad += max - length
+    addition = round(total_pad / len(bucket_zero))
+    return addition
+
+
+def simulate_buckets(
+    base_avg,
+    goal,
+    num_gpus,
+    grad_accum,
+    pad_id,
+    max_batch_len,
+    lengths,
+    seed,
+    dataset,
+    addition,
+):
+    """
+    Given an addition to packing_max_batch_len, simulate the
+    packing to find the updated average effective batch size.
+    """
+    packing_max_batch_len = int(
+        (base_avg + addition) * ((goal / num_gpus) / grad_accum)
+    )
+
+    collate_fn = make_collate_fn(pad_id, is_granite=False, max_batch_len=max_batch_len)
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+
+    sampler = MultipackDistributedBatchSampler(
+        batch_max_length=packing_max_batch_len,
+        lengths=lengths,
+        num_replicas=world_size,
+        rank=rank,
+        seed=seed,
+        padding=True,
+    )
+    simulation_loader = DataLoader(
+        dataset,
+        batch_sampler=sampler,
+        num_workers=8,
+        collate_fn=collate_fn,
+    )
+
+    avg_ebs = len(dataset) / len(simulation_loader)
+    return avg_ebs
+
+
+def find_padding_max_batch_len_addition(
+    base_avg, goal, dataset, num_gpus, grad_accum, pad_id, max_batch_len, seed
+):
+    """
+    Do a modified binary search to find optimal padding addition for
+    packing_maximum_batch_len. Starts with an upper-bound guess, and
+    increases upper-bound until guess overshoots. Then perform standard
+    binary search until within a threshold for average effective batch
+    size.
+    """
+    lengths = dataset.get_lengths()
+    sorted_lengths = list(lengths)
+    sorted_lengths.sort(reverse=True)
+
+    # Use first default bucket avg padding as starting value for addition
+    addition = guess_starting_avg_padding(
+        base_avg, goal, num_gpus, grad_accum, sorted_lengths
+    )
+
+    # binary search correct addition value from starting value
+    first_over_hit = False
+    l = 0
+    r = 2 * addition
+    while r - l > 1:
+        avg_ebs = simulate_buckets(
+            base_avg,
+            goal,
+            num_gpus,
+            grad_accum,
+            pad_id,
+            max_batch_len,
+            lengths,
+            seed,
+            dataset,
+            addition,
+        )
+
+        # check if simulation resulted in batch sizes close enough to goal and adjust if needed
+        if abs(avg_ebs - goal) <= max(10, round(goal * 0.02)):
+            break
+
+        if avg_ebs > goal:
+            first_over_hit = True
+            r = addition
+        elif avg_ebs < goal:
+            if not first_over_hit:
+                # If the starting midpoint failed to overshoot, increase the bounds of the search
+                r = r * 2
+            else:
+                l = addition
+        addition = l + ((r - l) // 2)
+
+    return addition
+
 
 def find_packing_max_batch_len_and_grad_accum(
-    num_gpus, avg_sample_len, effective_batch_size, max_batch_len_per_gpu
+    num_gpus,
+    avg_sample_len,
+    effective_batch_size,
+    max_batch_len_per_gpu,
+    is_padding,
+    dataset,
+    pad_id,
+    seed,
 ):
     """
     Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length.
@@ -58,12 +198,27 @@ def find_packing_max_batch_len_and_grad_accum(
     without exceeding the per-GPU limit, and the second element is the minimum number of gradient
     accumulation steps required to maintain the effective batch size.
""" + packing_max_batch_len = max_batch_len_per_gpu + 1 grad_accum = 0 while packing_max_batch_len > max_batch_len_per_gpu: grad_accum += 1 - total_micro_batch = effective_batch_size / grad_accum - packing_max_batch_len = int(avg_sample_len * total_micro_batch / num_gpus) + total_micro_batch = (effective_batch_size / grad_accum) / num_gpus + if is_padding: + addition = find_padding_max_batch_len_addition( + avg_sample_len, + effective_batch_size, + dataset, + num_gpus, + grad_accum, + pad_id, + max_batch_len_per_gpu, + seed, + ) + else: + addition = 0 + packing_max_batch_len = int((avg_sample_len + addition) * total_micro_batch) + return packing_max_batch_len, grad_accum diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py index 374714ae..e36b7395 100644 --- a/src/instructlab/training/token_dataset.py +++ b/src/instructlab/training/token_dataset.py @@ -6,11 +6,10 @@ from torch.utils.data import DataLoader, Dataset import numpy as np import torch -import torch.nn.functional as F # First Party from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler -from instructlab.training.utils import log_rank_0 +from instructlab.training.utils import log_rank_0, make_collate_fn class TokenDataset(Dataset): @@ -66,94 +65,6 @@ def get_lengths(self): return np.array([len(self.input_ids[0])] * len(self.input_ids)) -def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000): - rank = int(os.environ["RANK"]) - if is_granite: - - def pad_collate_fn(batch): - lens = np.array([len(item["input_ids"]) for item in batch]) - - cumsum_lens = np.cumsum(lens) - valid_up_to = int((cumsum_lens < max_batch_len).sum()) - total_len = cumsum_lens[valid_up_to - 1] - - batch = batch[:valid_up_to] - input_ids = [x["input_ids"].tolist() for x in batch] - labels = [x["labels"].tolist() for x in batch] - num_loss_counted_tokens = sum( - [(x["labels"] != -100).sum().item() for x in batch] - ) - - print( - f"\033[96m total length: {total_len} dropped: {cumsum_lens[-1] - total_len} " - f"num samples {len(batch)} - rank: {rank} " - f"max len: {lens.max()} min len: {lens.min()} avg len: {lens.mean()} " - f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m" - ) - - return { - "input_ids": input_ids, - "labels": labels, - "num_loss_counted_tokens": num_loss_counted_tokens, - } - - else: - - def pad_collate_fn(batch): - lens = np.array([len(item["input_ids"]) for item in batch]) - max_len = max(lens) - - input_ids = torch.stack( - [ - F.pad( - item["input_ids"], - (max_len - len(item["input_ids"]), 0), - mode="constant", - value=pad_token_id, - ) - for item in batch - ] - ) - labels = torch.stack( - [ - F.pad( - item["labels"], - (max_len - len(item["labels"]), 0), - mode="constant", - value=-100, - ) - for item in batch - ] - ) - num_loss_counted_tokens = (labels != -100).sum() - - attention_mask = torch.stack( - [ - F.pad( - item["attention_mask"], - (max_len - len(item["attention_mask"]), 0), - mode="constant", - value=0, - ) - for item in batch - ] - ) - print( - f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} " - f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} " - f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m" - ) - - return { - "input_ids": input_ids, - "labels": labels, - "num_loss_counted_tokens": num_loss_counted_tokens, - "attention_mask": attention_mask, - } - - return pad_collate_fn - - def setup_dataset( 
     data_path: str,
     mock: bool = False,
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 2bcb1718..6feaa548 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -21,6 +21,7 @@
 from torch.distributed.fsdp import StateDictType
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 
 def add_noisy_embeddings(model, noise_alpha=None):
@@ -75,6 +76,94 @@ def __init__(self, *args, **kwargs):
             break
 
 
+def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):
+    rank = int(os.environ["RANK"])
+    if is_granite:
+
+        def pad_collate_fn(batch):
+            lens = np.array([len(item["input_ids"]) for item in batch])
+
+            cumsum_lens = np.cumsum(lens)
+            valid_up_to = int((cumsum_lens < max_batch_len).sum())
+            total_len = cumsum_lens[valid_up_to - 1]
+
+            batch = batch[:valid_up_to]
+            input_ids = [x["input_ids"].tolist() for x in batch]
+            labels = [x["labels"].tolist() for x in batch]
+            num_loss_counted_tokens = sum(
+                [(x["labels"] != -100).sum().item() for x in batch]
+            )
+
+            print(
+                f"\033[96m total length: {total_len} dropped: {cumsum_lens[-1] - total_len} "
+                f"num samples {len(batch)} - rank: {rank} "
+                f"max len: {lens.max()} min len: {lens.min()} avg len: {lens.mean()} "
+                f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
+            )
+
+            return {
+                "input_ids": input_ids,
+                "labels": labels,
+                "num_loss_counted_tokens": num_loss_counted_tokens,
+            }
+
+    else:
+
+        def pad_collate_fn(batch):
+            lens = np.array([len(item["input_ids"]) for item in batch])
+            max_len = max(lens)
+
+            input_ids = torch.stack(
+                [
+                    F.pad(
+                        item["input_ids"],
+                        (max_len - len(item["input_ids"]), 0),
+                        mode="constant",
+                        value=pad_token_id,
+                    )
+                    for item in batch
+                ]
+            )
+            labels = torch.stack(
+                [
+                    F.pad(
+                        item["labels"],
+                        (max_len - len(item["labels"]), 0),
+                        mode="constant",
+                        value=-100,
+                    )
+                    for item in batch
+                ]
+            )
+            num_loss_counted_tokens = (labels != -100).sum()
+
+            attention_mask = torch.stack(
+                [
+                    F.pad(
+                        item["attention_mask"],
+                        (max_len - len(item["attention_mask"]), 0),
+                        mode="constant",
+                        value=0,
+                    )
+                    for item in batch
+                ]
+            )
+            print(
+                f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} "
+                f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
+                f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
+            )
+
+            return {
+                "input_ids": input_ids,
+                "labels": labels,
+                "num_loss_counted_tokens": num_loss_counted_tokens,
+                "attention_mask": attention_mask,
+            }
+
+    return pad_collate_fn
+
+
 def convert_loss_to_reduce_sum(model, is_granite=False):
     """
     this is necessary because multipack changes the samples per gpu, which biases the gradients to be larger for batches with less samples but longer lengths.
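
Example usage (illustrative sketch, not part of the patch itself): the snippet below mirrors the main_ds.py hunk above and shows how a training script can drive the new padding-aware computation. The wrapper name compute_packing_params and the dataset/tokenizer/args objects are assumed placeholders; only find_packing_max_batch_len_and_grad_accum and its keyword arguments come from this patch, and the padding path additionally expects torch.distributed to be initialized with RANK and WORLD_SIZE set in the environment.

    # Hypothetical wrapper around the call added to main_ds.py in this patch.
    import torch

    from instructlab.training.multipack_sampler import (
        find_packing_max_batch_len_and_grad_accum,
    )


    def compute_packing_params(dataset, tokenizer, args):
        # is_padding=True routes through the new simulation-based binary search;
        # granite/dolomite models, which pack without padding, skip it.
        packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
            num_gpus=torch.distributed.get_world_size(),
            avg_sample_len=dataset.get_lengths().mean(),
            effective_batch_size=args.effective_batch_size,
            max_batch_len_per_gpu=args.max_batch_len,
            is_padding=not args.is_granite,
            dataset=dataset,
            pad_id=tokenizer.pad_token_id,
            seed=args.seed,
        )
        samples_per_gpu = (
            args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
        )
        return packing_max_batch_len, grad_accum, samples_per_gpu

The search in find_padding_max_batch_len_addition stops once the simulated average effective batch size is within max(10, round(goal * 0.02)) of the requested effective_batch_size, so the returned addition keeps the realized batch size close to the configured goal while the outer loop still guarantees packing_max_batch_len never exceeds max_batch_len_per_gpu.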