Fixing sequence parallel error conditions and adding type float for microbatch_size in typehints (#3139)

* fixing the current error condition

* fixing error condition

* temporarily adding debug info

* temporarily adding debug info

* fixing divide by 0 error

* fixing a bug

* fixing a bug

* reverting changes from prev commits

* Update state.py

* Update state.py

* Update trainer.py
ShashankMosaicML committed Mar 25, 2024
1 parent c91be1f commit 99b86dd
Showing 4 changed files with 32 additions and 24 deletions.
composer/core/data_spec.py: 11 changes (7 additions, 4 deletions)
@@ -80,15 +80,18 @@ def _check_list_is_primitives(l):
return True


- def _default_split_batch(batch: Any, microbatch_size: int) -> Sequence:
+ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequence:
"""Splits batch into chunks of size `microbatch_size` for gradient accumulation.
Works with tensors, dictionaries of tensors, (x, y) tuples, and lists where ``batch`` is the 2nd dimension.
Args:
batch (Any): output from the dataloader.
- microbatch_size (int): Size of microbatches to batch into.
+ microbatch_size (int | float): Size of microbatches to batch into.
"""
+ if isinstance(microbatch_size, float):
+ raise ValueError('_default_split_batch does not support floating point microbatch_size.')
+
if isinstance(batch, torch.Tensor): # check for a single stack of tensors
return _split_tensor(batch, microbatch_size)
elif isinstance(batch, Mapping): # check for dictionary (hf style)
@@ -154,7 +157,7 @@ class DataSpec:
normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
the batch is not modified.
- split_batch ((Batch, int) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
+ split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
split a batch (the first parameter) into microbatches of a given size (the second parameter). If
the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, Tuple, or List, then
this function must be specified.
@@ -180,7 +183,7 @@ def __init__(
num_samples: Optional[int] = None,
num_tokens: Optional[int] = None,
device_transforms: Optional[Callable[[Batch], Batch]] = None,
- split_batch: Optional[Callable[[Batch, int], Sequence[Batch]]] = None,
+ split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None,
) -> None:
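With this guard in place, the default splitter only handles integer microbatch sizes; a fractional device_train_microbatch_size (as used with sequence parallelism, where each rank holds part of a single sample) needs a user-supplied split_batch. A minimal sketch of wiring one into a DataSpec, assuming the fractional case means the per-device batch is already a single chunk; the toy dataloader and passthrough splitter below are illustrative, not part of this commit:

from typing import Any, List, Union

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import DataSpec


def passthrough_split_batch(batch: Any, microbatch_size: Union[int, float]) -> List[Any]:
    # With a fractional microbatch size the per-device batch already corresponds to a
    # single (partial) sample, so return it unsplit instead of slicing it.
    return [batch]


dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
data_spec = DataSpec(
    dataloader=DataLoader(dataset, batch_size=1),
    split_batch=passthrough_split_batch,
)

Passing split_batch here overrides _default_split_batch, so the new float guard is never hit for this DataSpec.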
composer/core/state.py: 6 changes (3 additions, 3 deletions)
@@ -218,7 +218,7 @@ class State(Serializable):
``rank_zero_seed + dist.get_global_rank()``.
run_name (str): The name for this training run.
device (Device): The device used by this process. The trainer moves the model and loaded data to this device.
- device_train_microbatch_size (int, optional): The microbatch size for each device during training.
+ device_train_microbatch_size (int | float, optional): The microbatch size for each device during training.
auto_microbatching (bool, optional): Whether automatic microbatching is enabled.
train_dataloader (Iterable, optional): Dataloader used for training
evaluators (Evaluator | Evaluators, optional): :class:`.Evaluator` used for evaluation.
@@ -308,7 +308,7 @@ class State(Serializable):
eval_timestamp (Timestamp): The timestamp for the current evaluation dataloader. This timestamp is reset
before the dataloader is evaluated. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
``0``.
- device_train_microbatch_size (int): The size of each train microbatch per device.
+ device_train_microbatch_size (int | float): The size of each train microbatch per device.
loss (torch.Tensor | Sequence[torch.Tensor] | Dict[Any, torch.Tensor]): The most recently computed loss.
model (torch.nn.Module): The training model.
@@ -381,7 +381,7 @@ def __init__(
max_duration: Optional[Union[str, Time[int]]] = None,

# data configurations
- device_train_microbatch_size: Optional[int] = None,
+ device_train_microbatch_size: Optional[Union[int, float]] = None,
auto_microbatching: bool = False,

# dataloaders
composer/datasets/in_context_learning_evaluation.py: 11 changes (8 additions, 3 deletions)
@@ -630,17 +630,20 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
return batch

- def split_batch(self, batch: Any, microbatch_size: int) -> List[Dict[str, Any]]:
+ def split_batch(self, batch: Any, microbatch_size: Union[int, float]) -> List[Dict[str, Any]]:
"""
Handling for certain specialty columns that must be split into batches in different formats.
Args:
batch (Dict): Batch of data
- microbatch_size (int): Size of microbatches
+ microbatch_size (int | float): Size of microbatches
Returns:
List: List of chunked batches
"""
+ if isinstance(microbatch_size, float):
+ raise ValueError('InContextLearningDataset does not support float microbatch sizes')
+
# Don't split kwargs that don't change
# Normally split torch tensors
# List split lists of strings
@@ -1028,7 +1031,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
def get_num_samples_in_batch(self, batch) -> int:
return batch['input_ids'].shape[0] // self.num_choices

- def split_batch(self, batch: Any, microbatch_size: int) -> List[Dict[str, Any]]:
+ def split_batch(self, batch: Any, microbatch_size: Union[int, float]) -> List[Dict[str, Any]]:
"""
Split batch while ensuring all continuations are in the same microbatch.
@@ -1044,6 +1047,8 @@ def split_batch(self, batch: Any, microbatch_size: int) -> List[Dict[str, Any]]:
Returns:
list: List of chunked batches
"""
+ if isinstance(microbatch_size, float):
+ raise ValueError('InContextLearningMultipleChoiceTaskDataset does not support float microbatch sizes')
chunked = {}
for k, v in batch.items():
if k in self.static_keys:
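Both in-context-learning dataset classes now fail fast on float microbatch sizes instead of mis-splitting. For reference, a standalone sketch of the guard-then-split pattern these methods follow (an illustrative re-implementation, not the library code): tensors and lists are chunked by microbatch_size, while static entries are copied into every chunk.

from typing import Any, Dict, List, Union

import torch


def split_batch_sketch(batch: Dict[str, Any], microbatch_size: Union[int, float]) -> List[Dict[str, Any]]:
    # Reject fractional sizes up front, mirroring the checks added in this commit.
    if isinstance(microbatch_size, float):
        raise ValueError('This splitter does not support float microbatch sizes.')
    chunked: Dict[str, Any] = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            chunked[k] = torch.split(v, microbatch_size)  # split along the batch dimension
        elif isinstance(v, list):
            chunked[k] = [v[i:i + microbatch_size] for i in range(0, len(v), microbatch_size)]
        else:
            chunked[k] = None  # static value: reused in every microbatch below
    num_chunks = max(len(c) for c in chunked.values() if c is not None)
    return [{k: (batch[k] if c is None else c[i]) for k, c in chunked.items()} for i in range(num_chunks)]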
composer/trainer/trainer.py: 28 changes (14 additions, 14 deletions)
@@ -240,7 +240,7 @@ def _set_evaluator_interval_and_subset_num_batches(
)


- def _is_auto_microbatching(device_train_microbatch_size: Optional[Union[int, str]], device: Device):
+ def _is_auto_microbatching(device_train_microbatch_size: Optional[Union[int, float, str]], device: Device):
if device_train_microbatch_size == 'auto':
warnings.warn((
"`device_train_microbatch_size='auto'` may potentially fail with unexpected "
@@ -260,10 +260,10 @@ def _is_auto_microbatching(device_train_microbatch_size: Optional[Union[int, str


def _get_initial_device_train_microbatch_size(
- device_train_microbatch_size: Optional[Union[int, str]],
+ device_train_microbatch_size: Optional[Union[int, float, str]],
auto_microbatching: bool,
train_dataloader: Optional[Iterable],
- ) -> Optional[int]:
+ ) -> Optional[Union[int, float]]:
"""Sets initial value of device_train_microbatch_size.
If auto_microbatching, sets initial `device_train_microbatch_size` to per rank batch size. If
@@ -406,10 +406,10 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
'Auto microbatching on evaluators is not compatible with sequence parallelism. '
'Please manually set device_eval_microbatch_size or disable sequence parallelism .',
)
- if isinstance(evaluator.dataloader.get_num_samples_in_batch, int) and hasattr(
+ if hasattr(
evaluator.dataloader,
'seq_parallel_world_size',
- ) and evaluator.dataloader.get_num_samples_in_batch * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
+ ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.batch_size * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)
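Here, and in the matching checks in Trainer.__init__ and Trainer.fit below, the corrected condition only applies when sequence parallelism is active (seq_parallel_world_size > 1), and then requires that the per-device (micro)batch size times the sequence parallel world size equals one, i.e. the group as a whole processes exactly one sample at a time. With hypothetical numbers:

seq_parallel_world_size = 4          # ranks that jointly process one sequence
device_train_microbatch_size = 0.25  # each rank holds 1/4 of a single sample

# Passes the check: the sequence parallel group covers exactly one sample.
assert device_train_microbatch_size * seq_parallel_world_size == 1

# A whole sample per rank would fail: 1 * 4 != 1, so the trainer raises ValueError.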
@@ -904,7 +904,7 @@ class Trainer:
training on GPU)
precision_config (Optional[Dict[str, Any]]): The config for FP8 scaling strategy. See parameters for
`DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
- device_train_microbatch_size (Union[int, str), optional): The number of samples to process on each device per
+ device_train_microbatch_size (Union[int, float, str), optional): The number of samples to process on each device per
microbatch during training. Gradients are summed over the microbatches per device. If set to ``auto``,
dynamically decreases device_train_microbatch_size if microbatch is too large for GPU. (default: ``None``)
@@ -1043,7 +1043,7 @@ def __init__(
device: Optional[Union[str, Device]] = None,
precision: Optional[Union[str, Precision]] = None,
precision_config: Optional[Dict[str, Any]] = None,
- device_train_microbatch_size: Optional[Union[int, str]] = None,
+ device_train_microbatch_size: Optional[Union[int, float, str]] = None,

# Reproducibility
seed: Optional[int] = None,
@@ -1114,10 +1114,10 @@ def __init__(
auto_microbatching = _is_auto_microbatching(device_train_microbatch_size, device=device)
if auto_microbatching and train_dataloader is not None and hasattr(train_dataloader, 'seq_parallel_world_size'):
raise ValueError('`device_train_microbatch_size="auto"` is not compatible with sequence parallelism.')
- if isinstance(device_train_microbatch_size, int) and train_dataloader is not None and hasattr(
+ if train_dataloader is not None and hasattr(
train_dataloader,
'seq_parallel_world_size',
- ) and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1: # type: ignore
+ ) and train_dataloader.seq_parallel_world_size > 1 and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)
@@ -1908,7 +1908,7 @@ def fit(
eval_interval: Union[int, str, Time, Callable[[State, Event], bool]] = 1,

# Numerics
- device_train_microbatch_size: Optional[Union[int, str]] = None,
+ device_train_microbatch_size: Optional[Union[int, float, str]] = None,
precision: Optional[Union[str, Precision]] = None,
):
"""Train the model.
@@ -2017,7 +2017,7 @@ def fit(
eval_dataloader (Iterable | DataSpec | Evaluator | Sequence[Evaluator], optional): See :class:`.Trainer`.
eval_subset_num_batches (int, optional): See :class:`.Trainer`.
eval_interval (int | str | Time | (State, Event) -> bool, optional): See :class:`.Trainer`.
- device_train_microbatch_size (int | str, optional): See :class:`.Trainer`.
+ device_train_microbatch_size (int | float | str, optional): See :class:`.Trainer`.
precision (Precision | str, optional): See :class:`.Trainer`.
"""
# Check Optimizer
@@ -2161,10 +2161,10 @@ def fit(
'seq_parallel_world_size',
):
raise ValueError('`device_train_microbatch_size="auto"` is not compatible with sequence parallelism.')
- if isinstance(device_train_microbatch_size, int) and train_dataloader is not None and hasattr(
+ if train_dataloader is not None and hasattr(
train_dataloader,
'seq_parallel_world_size',
- ) and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1: # type: ignore
+ ) and train_dataloader.seq_parallel_world_size > 1 and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)
@@ -2312,7 +2312,7 @@ def _accumulate_time_across_ranks(
)
dist.all_reduce(sample_token_tensor, reduce_operation='SUM')
if isinstance(num_samples, float):
- sample_token_tensor_int = sample_token_tensor.to(torch.int)
+ sample_token_tensor_int = sample_token_tensor.round().to(torch.int)
if torch.any(torch.abs(sample_token_tensor_int - sample_token_tensor) > 1e-4):
raise ValueError('The sums of samples and tokens across ranks should each be integers.')
sample_token_tensor = sample_token_tensor_int
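The final change addresses floating point drift when fractional per-rank sample and token counts are summed across ranks: truncating with .to(torch.int) can knock an almost-integer sum down by one and trip the integrity check above, whereas rounding first recovers the intended integer. A small illustration with made-up values:

import torch

# An all-reduced sum of fractional per-rank counts can land just below an integer.
sample_token_tensor = torch.tensor([3.9999999, 8.0], dtype=torch.float64)

truncated = sample_token_tensor.to(torch.int)        # tensor([3, 8]) -- 3 is off by ~1
rounded = sample_token_tensor.round().to(torch.int)  # tensor([4, 8])

print(torch.any(torch.abs(truncated - sample_token_tensor) > 1e-4))  # True: would raise
print(torch.any(torch.abs(rounded - sample_token_tensor) > 1e-4))    # False: passes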
