Skip to content

Commit

Permalink
gather/broadcast the max value of the packing efficiency automatically (
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 17, 2023
1 parent ab534d7 commit b15b19e
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 12 deletions.
88 changes: 88 additions & 0 deletions src/axolotl/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,91 @@ def broadcast_dict(vals: dict):
vals = pickle.loads(data_byte) # nosec

return vals


def compute_and_broadcast(fn): # pylint: disable=invalid-name
"""
Compute a value using the function 'fn' only on the specified rank (default is 0).
The value is then broadcasted to all other ranks.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that computes the value. Default is 0.
Returns:
- The computed value (int or float).
"""
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
else:
value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor

# Broadcast the tensor to all processes.
barrier()
dist.broadcast(value_tensor, src=0)

# Convert the tensor back to its original type (int or float)
if value_tensor == value_tensor.int():
return int(value_tensor.item())
return float(value_tensor.item())


def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()

# Placeholder tensor for gathering results
if is_main_process():
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
else:
gathered_tensors = None

dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

if is_main_process():
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None


def reduce_and_broadcast(fn1, fn2):
"""
Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
and then broadcast the reduced result to all ranks.
Args:
- fn1 (callable): A function that computes the value on each rank.
- fn2 (callable): A reduction function that takes a list of values and returns a single value.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- The reduced and broadcasted value.
"""

# Gather values from all ranks using fn1
if not is_distributed():
return fn2([fn1()])

gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())

# Use compute_and_broadcast to compute the reduced value on the main process
# and then broadcast it to all ranks
return compute_and_broadcast(lambda: fn2(gathered_values))
44 changes: 32 additions & 12 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
import transformers
from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR
Expand All @@ -35,7 +36,12 @@
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.distributed import (
is_distributed,
is_main_process,
reduce_and_broadcast,
zero_first,
)
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

LOG = logging.getLogger("axolotl")
Expand Down Expand Up @@ -456,7 +462,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
)
else:
sampler = RandomSampler(train_dataset)
if cfg.world_size > 1 and is_distributed():
sampler = DistributedSampler(
train_dataset,
num_replicas=cfg.world_size,
rank=dist.get_rank(),
seed=cfg.seed or 42,
)
else:
sampler = RandomSampler(train_dataset)

data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
Expand All @@ -474,18 +489,23 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.floor(
data_loader_len
* cfg.micro_batch_size
* cfg.num_epochs
// cfg.batch_size
)
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))

def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return max(estimates)

sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: actual_eff,
calc_sample_packing_eff_est,
)
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
)
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
)
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
cfg.sample_packing_eff_est = sample_packing_eff_est
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
Expand Down

0 comments on commit b15b19e

Please sign in to comment.