Trim examples ahead of time for auto packing #994

Merged
merged 14 commits on Feb 27, 2024
llmfoundry/data/packing.py: 46 changes (27 additions & 19 deletions)
@@ -73,14 +73,13 @@ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
'attention_mask',
'bidirectional_mask',
]

# Cut everything down to size
sizes, trimmed_examples = [], []
for idx in range(batch['attention_mask'].shape[0]):
size, trimmed_example = _extract_trim_batch_idx(batch, idx)
sizes.append(size)
trimmed_examples.append(trimmed_example)
sizes, trimmed_examples = _trim_batch(batch)
return self.pack_trimmed_examples(trimmed_examples, sizes)

def pack_trimmed_examples(self, trimmed_examples: List[Dict[str, torch.Tensor]],
                          sizes: List[int]) -> Dict[str, torch.Tensor]:
# Apply our CS 101 bin packing algorithm.
packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(
sizes=sizes,
@@ -102,6 +101,18 @@ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return batch
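
The "CS 101 bin packing" comment above refers to a greedy first-fit pass over example lengths. As a point of reference only, here is a minimal, self-contained sketch of plain first-fit; the names first_fit and bin_capacity are illustrative, and this is not the _first_fit_bin_packing implementation used in packing.py.

from typing import List

def first_fit(sizes: List[int], bin_capacity: int) -> List[List[int]]:
    """Greedy first-fit: place each size into the first bin with room, else open a new bin."""
    bins: List[List[int]] = []
    for size in sizes:
        for current_bin in bins:
            if sum(current_bin) + size <= bin_capacity:
                current_bin.append(size)
                break
        else:
            bins.append([size])
    return bins

# Example: packing trimmed example lengths into bins of 8 tokens.
# first_fit([5, 3, 4, 2, 6], bin_capacity=8) -> [[5, 3], [4, 2], [6]]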


def _trim_batch(
batch: Dict[str, torch.Tensor]
) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]:
# Cut everything down to size
sizes, trimmed_examples = [], []
for idx in range(batch['attention_mask'].shape[0]):
size, trimmed_example = _extract_trim_batch_idx(batch, idx)
sizes.append(size)
trimmed_examples.append(trimmed_example)
return sizes, trimmed_examples
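
For readers skimming the diff: _trim_batch collects a per-example token count plus a padding-free copy of each example, so packing operates on true lengths rather than padded lengths. The body of _extract_trim_batch_idx is not shown in this hunk, so the following is only an illustrative sketch of the trimming idea, assuming right-padded rows whose attention_mask marks the real tokens.

from typing import Dict, Tuple
import torch

def trim_example(example: Dict[str, torch.Tensor]) -> Tuple[int, Dict[str, torch.Tensor]]:
    """Drop trailing padding, keeping only the positions the attention mask marks as real."""
    size = int(example['attention_mask'].sum())
    return size, {k: v[:size] for k, v in example.items()}

# Usage: a batch of two right-padded rows of length 4.
batch = {
    'input_ids': torch.tensor([[5, 6, 0, 0], [7, 8, 9, 0]]),
    'attention_mask': torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]]),
}
sizes, trimmed = [], []
for idx in range(batch['attention_mask'].shape[0]):
    size, example = trim_example({k: v[idx] for k, v in batch.items()})
    sizes.append(size)
    trimmed.append(example)
# sizes == [2, 3]; trimmed[0]['input_ids'] is tensor([5, 6])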


def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
example = {k: v[idx] for k, v in batch.items()}
@@ -385,18 +396,14 @@ def profile_packing(
# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))

def split_big_batch(raw_batch_size: int) -> List:
input_ids = big_batch['input_ids'].split(raw_batch_size)
batches = [{'input_ids': x} for x in input_ids]

for key in big_batch.keys():
if key == 'input_ids':
continue
for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
batches[idx].update({key: split})
return batches
# Cut everything down to size
sizes, trimmed_examples = _trim_batch(big_batch)

def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
# Copy trimmed examples so that the dicts are not shared between profiling runs.
trimmed_examples_copy = [te.copy() for te in trimmed_examples]

# Create the packing collator.
packer = BinPackCollator(
collator=lambda x: x,
target_batch_size=device_batch_size,
@@ -406,10 +413,11 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
max_leftover_bins_to_keep=max_leftovers_to_keep)

# Simulate feeding the packing collator a bunch of data
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
for idx in range(0, len(trimmed_examples_copy), raw_batch_size):
batch = trimmed_examples_copy[idx:idx + raw_batch_size]
if len(batch) < device_batch_size:
continue
packer.pack(batch)
packer.pack_trimmed_examples(batch, sizes[idx:idx + raw_batch_size])

if packer.n_packed_examples == 0:
log.debug(
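A note on the profiling change above: trimming now happens once per call to profile_packing, and every candidate raw batch size reuses the same trimmed_examples list, taking shallow copies of the per-example dicts so one profiling run cannot mutate another's inputs. The sketch below only illustrates that reuse pattern; run_one_profile and the (2, 4, 8) sweep are hypothetical names, not part of profile_packing.

from typing import Dict, List
import torch

def run_one_profile(trimmed_examples: List[Dict[str, torch.Tensor]],
                    sizes: List[int],
                    raw_batch_size: int,
                    device_batch_size: int) -> int:
    """Count the tokens that one profiling pass would offer to the packer."""
    # Shallow-copy each example dict so this run cannot alias another run's batches.
    examples = [example.copy() for example in trimmed_examples]
    offered_tokens = 0
    for idx in range(0, len(examples), raw_batch_size):
        batch = examples[idx:idx + raw_batch_size]
        if len(batch) < device_batch_size:
            continue  # skip the ragged tail batch, as the profiler does
        # A real run would call packer.pack_trimmed_examples(batch, sizes[idx:idx + raw_batch_size]).
        offered_tokens += sum(sizes[idx:idx + raw_batch_size])
    return offered_tokens

# Trim once, then sweep candidate raw batch sizes over the same trimmed examples:
# for raw_batch_size in (2, 4, 8):
#     run_one_profile(trimmed_examples, sizes, raw_batch_size, device_batch_size=2)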