Multipack padding fix #19

Merged
merged 10 commits on Jun 20, 2024
Changes from all commits
4 changes: 4 additions & 0 deletions src/instructlab/training/main_ds.py
@@ -461,6 +461,10 @@ def main(args):
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
is_padding=not args.is_granite,
dataset=dataset,
pad_id=tokenizer.pad_token_id,
seed=args.seed,
)
args.samples_per_gpu = (
args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
163 changes: 159 additions & 4 deletions src/instructlab/training/multipack_sampler.py
@@ -25,16 +25,156 @@

# Standard
from typing import List, Optional
import os

# Third Party
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, Sampler
import numba
import numpy as np
import torch.distributed as dist

# First Party
from instructlab.training.utils import make_collate_fn


def guess_starting_avg_padding(base_avg, goal, num_gpus, grad_accum, sorted_lengths):
"""
Return a starting midpoint for the binary search
(to find the optimal addition to packing_max_batch_len
to account for padding).

Uses the largest initial bucket to approximate an
upper bound for the average padding; this should overshoot.
"""
addition = 0
packing_max_batch_len = int(
(base_avg + addition) * ((goal / num_gpus) / grad_accum)
)

bucket_zero = []
max = sorted_lengths[0]
sum = 0
for length in sorted_lengths:
if sum + max <= packing_max_batch_len:
sum += max
bucket_zero.append(length)
else:
break

total_pad = 0
for length in bucket_zero:
total_pad += max - length
addition = round(total_pad / len(bucket_zero))
return addition
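
As a toy illustration of the starting guess (hypothetical numbers, not from this patch): every sample in the first bucket is costed at the maximum length, so the average gap between that maximum and the real lengths becomes the initial addition.

# Hypothetical walk-through of guess_starting_avg_padding; values are made up.
sorted_lengths = [100, 90, 80, 40, 30]          # descending sample lengths
base_avg, goal, num_gpus, grad_accum = 70, 8, 2, 1
packing_max_batch_len = int((base_avg + 0) * ((goal / num_gpus) / grad_accum))  # 280

bucket_zero, running, longest = [], 0, sorted_lengths[0]
for length in sorted_lengths:                   # fill bucket zero at worst-case cost
    if running + longest > packing_max_batch_len:
        break
    running += longest
    bucket_zero.append(length)                  # -> [100, 90]

# The average padding inside bucket zero becomes the starting addition.
addition = round(sum(longest - n for n in bucket_zero) / len(bucket_zero))  # -> 5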


def simulate_buckets(
base_avg,
goal,
num_gpus,
grad_accum,
pad_id,
max_batch_len,
lengths,
seed,
dataset,
addition,
):
"""
Given an addition to packing_max_batch_len, simulate the
packing to find the updated average effective batch size.
"""
packing_max_batch_len = int(
(base_avg + addition) * ((goal / num_gpus) / grad_accum)
)

collate_fn = make_collate_fn(pad_id, is_granite=False, max_batch_len=max_batch_len)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

sampler = MultipackDistributedBatchSampler(
batch_max_length=packing_max_batch_len,
lengths=lengths,
num_replicas=world_size,
rank=rank,
seed=seed,
padding=True,
)
simulation_loader = DataLoader(
dataset,
batch_sampler=sampler,
num_workers=8,
collate_fn=collate_fn,
)

avg_ebs = len(dataset) / len(simulation_loader)
return avg_ebs
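
Each candidate addition is scored by rebuilding the sampler and data loader and counting how many batches come out. A minimal sketch of the call, assuming RANK and WORLD_SIZE are exported and that a dataset and tokenizer already exist (the numeric values below are hypothetical):

# Hypothetical call; dataset, tokenizer, and addition are assumed to exist.
avg_ebs = simulate_buckets(
    base_avg=dataset.get_lengths().mean(),
    goal=3840,                        # target effective batch size (made up)
    num_gpus=8,
    grad_accum=2,
    pad_id=tokenizer.pad_token_id,
    max_batch_len=60000,
    lengths=dataset.get_lengths(),
    seed=42,
    dataset=dataset,
    addition=addition,
)
# avg_ebs is len(dataset) / number_of_batches for this candidate addition.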


def find_padding_max_batch_len_addition(
base_avg, goal, dataset, num_gpus, grad_accum, pad_id, max_batch_len, seed
):
"""
Do a modified binary search to find the optimal padding addition for
packing_max_batch_len. Starts with an upper-bound guess and increases
the upper bound until the guess overshoots, then performs a standard
binary search until the average effective batch size is within a
threshold of the goal.
"""
lengths = dataset.get_lengths()
sorted_lengths = list(lengths)
sorted_lengths.sort(reverse=True)

# Use first default bucket avg padding as starting value for addition
addition = guess_starting_avg_padding(
base_avg, goal, num_gpus, grad_accum, sorted_lengths
)

# binary search correct addition value from starting value
first_over_hit = False
l = 0
r = 2 * addition
while r - l > 1:
avg_ebs = simulate_buckets(
base_avg,
goal,
num_gpus,
grad_accum,
pad_id,
max_batch_len,
lengths,
seed,
dataset,
addition,
)

# check if simulation resulted in batch sizes close enough to goal and adjust if needed
if abs(avg_ebs - goal) <= max(10, round(goal * 0.02)):
break

if avg_ebs > goal:
first_over_hit = True
r = addition
elif avg_ebs < goal:
if not first_over_hit:
# If the starting midpoint failed to overshoot, increase the bounds of the search
r = r * 2
else:
l = addition
addition = l + ((r - l) // 2)

return addition
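
The stopping tolerance above scales with the goal, so larger effective batch sizes accept a proportionally larger error. A quick arithmetic check (the effective batch sizes are hypothetical):

# Tolerance used by the search: max(10, round(goal * 0.02)).
for goal in (128, 512, 3840):
    print(goal, max(10, round(goal * 0.02)))    # -> 10, 10, 77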


def find_packing_max_batch_len_and_grad_accum(
num_gpus, avg_sample_len, effective_batch_size, max_batch_len_per_gpu
num_gpus,
avg_sample_len,
effective_batch_size,
max_batch_len_per_gpu,
is_padding,
dataset,
pad_id,
seed,
):
"""
Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length.
@@ -58,12 +198,27 @@ def find_packing_max_batch_len_and_grad_accum(
without exceeding the per-GPU limit, and the second element is the minimum number of gradient
accumulation steps required to maintain the effective batch size.
"""

packing_max_batch_len = max_batch_len_per_gpu + 1
grad_accum = 0
while packing_max_batch_len > max_batch_len_per_gpu:
grad_accum += 1
total_micro_batch = effective_batch_size / grad_accum
packing_max_batch_len = int(avg_sample_len * total_micro_batch / num_gpus)
total_micro_batch = (effective_batch_size / grad_accum) / num_gpus
if is_padding:
addition = find_padding_max_batch_len_addition(
avg_sample_len,
effective_batch_size,
dataset,
num_gpus,
grad_accum,
pad_id,
max_batch_len_per_gpu,
seed,
)
else:
addition = 0
packing_max_batch_len = int((avg_sample_len + addition) * total_micro_batch)

return packing_max_batch_len, grad_accum
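
The main_ds.py hunk at the top of this PR shows the updated call site; a condensed sketch of how the returned pair is consumed there (argument values follow that hunk, and the num_gpus argument is assumed to be the distributed world size):

# Sketch of the call site; args, dataset, and tokenizer are assumed to exist.
packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
    num_gpus=torch.distributed.get_world_size(),   # assumed value
    avg_sample_len=dataset.get_lengths().mean(),
    effective_batch_size=args.effective_batch_size,
    max_batch_len_per_gpu=args.max_batch_len,
    is_padding=not args.is_granite,
    dataset=dataset,
    pad_id=tokenizer.pad_token_id,
    seed=args.seed,
)
args.samples_per_gpu = (
    args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
)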


91 changes: 1 addition & 90 deletions src/instructlab/training/token_dataset.py
@@ -1,4 +1,4 @@
# Standard

import os

# Third Party
@@ -6,11 +6,10 @@
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch
import torch.nn.functional as F

# First Party
from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
from instructlab.training.utils import log_rank_0
from instructlab.training.utils import log_rank_0, make_collate_fn


class TokenDataset(Dataset):
@@ -66,94 +65,6 @@
return np.array([len(self.input_ids[0])] * len(self.input_ids))


def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):
rank = int(os.environ["RANK"])
if is_granite:

def pad_collate_fn(batch):
lens = np.array([len(item["input_ids"]) for item in batch])

cumsum_lens = np.cumsum(lens)
valid_up_to = int((cumsum_lens < max_batch_len).sum())
total_len = cumsum_lens[valid_up_to - 1]

batch = batch[:valid_up_to]
input_ids = [x["input_ids"].tolist() for x in batch]
labels = [x["labels"].tolist() for x in batch]
num_loss_counted_tokens = sum(
[(x["labels"] != -100).sum().item() for x in batch]
)

print(
f"\033[96m total length: {total_len} dropped: {cumsum_lens[-1] - total_len} "
f"num samples {len(batch)} - rank: {rank} "
f"max len: {lens.max()} min len: {lens.min()} avg len: {lens.mean()} "
f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
)

return {
"input_ids": input_ids,
"labels": labels,
"num_loss_counted_tokens": num_loss_counted_tokens,
}

else:

def pad_collate_fn(batch):
lens = np.array([len(item["input_ids"]) for item in batch])
max_len = max(lens)

input_ids = torch.stack(
[
F.pad(
item["input_ids"],
(max_len - len(item["input_ids"]), 0),
mode="constant",
value=pad_token_id,
)
for item in batch
]
)
labels = torch.stack(
[
F.pad(
item["labels"],
(max_len - len(item["labels"]), 0),
mode="constant",
value=-100,
)
for item in batch
]
)
num_loss_counted_tokens = (labels != -100).sum()

attention_mask = torch.stack(
[
F.pad(
item["attention_mask"],
(max_len - len(item["attention_mask"]), 0),
mode="constant",
value=0,
)
for item in batch
]
)
print(
f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} "
f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
)

return {
"input_ids": input_ids,
"labels": labels,
"num_loss_counted_tokens": num_loss_counted_tokens,
"attention_mask": attention_mask,
}

return pad_collate_fn


def setup_dataset(
data_path: str,
mock: bool = False,