Skip to content

Commit

Permalink
gather/broadcast the max value of the packing efficiency automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Aug 23, 2023
1 parent d5dcf9c commit d45b5d3
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 11 deletions.
89 changes: 89 additions & 0 deletions src/axolotl/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from contextlib import contextmanager

import torch
import torch.distributed as dist
from accelerate import Accelerator

Expand Down Expand Up @@ -53,3 +54,91 @@ def zero_first(is_main):
yield
if is_main: # then rank 0 waits after it has run the context
barrier()


def compute_and_broadcast(fn):  # pylint: disable=invalid-name
    """
    Compute a value using the function 'fn' on the main process only and
    broadcast the result to all other ranks.

    This function contains collective operations (a barrier followed by a
    broadcast from rank 0), so every rank must call it.

    Args:
    - fn (callable): A zero-argument function returning an int or float. It is
      invoked on the main process only and should not have any side effects.

    Returns:
    - The computed value on every rank: an int when the broadcast value is
      integral, otherwise a float. The value is transported as float32, so it
      is subject to float32 precision.
    """
    # Place the tensor on the CUDA device currently selected for this process.
    # The global rank is not a valid device index once the job spans several
    # nodes (e.g. rank 8 on a node with 8 GPUs).
    # NOTE(review): assumes the launcher has already bound each process to its
    # own device (torch.cuda.set_device) - confirm against the entry point.
    device = torch.device("cuda", torch.cuda.current_device())

    if is_main_process():
        value_tensor = torch.tensor(fn(), device=device, dtype=torch.float32)
    else:
        # Placeholder that the broadcast below overwrites in place.
        value_tensor = torch.tensor(0.0, device=device, dtype=torch.float32)

    # Broadcast the tensor to all processes.
    # NOTE(review): src=0 presumes is_main_process() identifies rank 0 - confirm.
    barrier()
    dist.broadcast(value_tensor, src=0)

    # Convert the tensor back to a plain Python int or float.
    value = value_tensor.item()
    if value.is_integer():
        return int(value)
    return value


def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on all ranks and gather the results on rank 0.

    This function contains a collective operation (gather to rank 0), so every
    rank must call it.

    Args:
    - fn (callable): A zero-argument function returning an int or float. It is
      invoked on every rank and should not have any side effects.
    - world_size (int, optional): Total number of processes in the current
      distributed setup; determines how many receive buffers the main process
      allocates. Default is 1.

    Returns:
    - On the main process, a list with one value per rank (int when the value
      is integral, otherwise float). On every other rank, None. Values are
      transported as float32, so they are subject to float32 precision.
    """
    # Place the tensor on the CUDA device currently selected for this process.
    # The global rank is not a valid device index once the job spans several
    # nodes (e.g. rank 8 on a node with 8 GPUs).
    # NOTE(review): assumes the launcher has already bound each process to its
    # own device (torch.cuda.set_device) - confirm against the entry point.
    device = torch.device("cuda", torch.cuda.current_device())
    value_tensor = torch.tensor(fn(), device=device, dtype=torch.float32)

    on_main = is_main_process()

    # Only the destination rank supplies receive buffers; others pass None.
    # NOTE(review): dst=0 presumes is_main_process() identifies rank 0 - confirm.
    gathered_tensors = (
        [torch.zeros_like(value_tensor) for _ in range(world_size)]
        if on_main
        else None
    )
    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    if not on_main:
        return None

    # Convert tensors back to plain Python ints or floats.
    values = (tensor.item() for tensor in gathered_tensors)
    return [int(value) if value.is_integer() else value for value in values]


def reduce_and_broadcast(fn1, fn2):
    """
    Evaluate 'fn1' on every rank, combine the per-rank results with 'fn2',
    and hand the combined result back to all ranks.

    Args:
    - fn1 (callable): Zero-argument function producing this rank's value.
    - fn2 (callable): Reduction that maps a list of values to a single value.

    Returns:
    - The reduced value. Outside a distributed run this is simply
      fn2([fn1()]); otherwise it is whatever compute_and_broadcast returns.
    """
    if is_distributed():
        # Collect one value per rank via the gather helper.
        per_rank_values = gather_from_all_ranks(
            fn1, world_size=dist.get_world_size()
        )

        def _reduce_gathered():
            # Handed to compute_and_broadcast, which decides where this runs
            # and distributes the result.
            return fn2(per_rank_values)

        return compute_and_broadcast(_reduce_gathered)

    # Single-process fallback: reduce over the one local value.
    return fn2([fn1()])
36 changes: 25 additions & 11 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

import bitsandbytes as bnb
import numpy as np
import torch.cuda
import torch.distributed as dist
import transformers
from datasets import Dataset, set_caching_enabled
from torch import nn
Expand All @@ -31,6 +32,7 @@
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_distributed, reduce_and_broadcast
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
Expand Down Expand Up @@ -331,7 +333,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
)
else:
sampler = RandomSampler(train_dataset)
if cfg.world_size > 1 and is_distributed():
sampler = DistributedSampler(
train_dataset,
num_replicas=cfg.world_size,
rank=dist.get_rank(),
seed=cfg.seed or 42,
)
else:
sampler = RandomSampler(train_dataset)

data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
Expand All @@ -349,18 +360,21 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.floor(
data_loader_len
* cfg.micro_batch_size
* cfg.num_epochs
// cfg.batch_size
)
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))

def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return max(estimates)

sample_packing_eff_est = reduce_and_broadcast(
lambda: math.ceil(actual_eff * 100.0) / 100.0,
calc_sample_packing_eff_est,
)
sample_packing_eff_est = math.ceil(sample_packing_eff_est * 100.0) / 100.0
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
)
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
cfg.sample_packing_eff_est = sample_packing_eff_est
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
Expand Down

0 comments on commit d45b5d3

Please sign in to comment.