VRAM fix attempt (#1164) [skip ci]
* revert the order of the filter/drop_long step and calculate max_input_len only during preprocessing

* revert some changes to the packing preparation to allow more flexibility

* prepare the dataset for packing during the pre-processing step

* prepare the dataset hash based on the sample packing settings too

* enclose the None check

* just cast values straight to string for the dataset hash
winglian committed Jan 23, 2024
1 parent 802f966 commit 32580c1
Showing 3 changed files with 40 additions and 34 deletions.
8 changes: 7 additions & 1 deletion src/axolotl/utils/data.py
@@ -116,6 +116,12 @@ def load_tokenized_prepared_datasets(
(
str(cfg.sequence_len)
+ "@"
+ str(cfg.sample_packing)
+ "@"
+ str(cfg.eval_sample_packing)
+ "@"
+ str(cfg.group_by_length)
+ "@"
+ "|".join(
sorted(
[
@@ -162,7 +168,7 @@ def load_tokenized_prepared_datasets(
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
)

if cfg.seed:
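The data.py change folds the packing-related settings into the prepared-dataset cache key, so toggling sample_packing, eval_sample_packing, or group_by_length invalidates a previously prepared dataset instead of silently reusing it; the reworded warning nudges users toward the dedicated preprocess step (typically python -m axolotl.cli.preprocess your_config.yml) rather than tokenizing at train time. Below is a minimal, hypothetical sketch of such a config-aware cache key — dataset_cache_key and the md5 digest are illustrative assumptions, not the exact axolotl implementation:

```python
# Hypothetical sketch of a config-aware cache key for a prepared dataset.
# Every setting that changes the tokenized/packed output is cast to str
# (so None and booleans hash deterministically) and folded into one digest.
from hashlib import md5


def dataset_cache_key(sequence_len, sample_packing, eval_sample_packing,
                      group_by_length, dataset_names):
    raw = "@".join(
        [
            str(sequence_len),
            str(sample_packing),
            str(eval_sample_packing),
            str(group_by_length),
            "|".join(sorted(dataset_names)),
        ]
    )
    return md5(raw.encode("utf-8")).hexdigest()


# Flipping sample_packing changes the key, forcing a re-prepare.
print(dataset_cache_key(2048, True, None, False, ["alpaca", "dolly"]))
print(dataset_cache_key(2048, False, None, False, ["alpaca", "dolly"]))
```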
12 changes: 6 additions & 6 deletions src/axolotl/utils/samplers/utils.py
@@ -7,11 +7,11 @@
def get_dataset_lengths(dataset):
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
elif "position_ids" in dataset.data.column_names:
position_ids = dataset.data.column("position_ids")
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
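A runnable sketch of the new lookup order in get_dataset_lengths, using a plain dict of columns as a stand-in for the pyarrow-backed dataset.data table (an assumed simplification): prefer a precomputed length column, then derive lengths from position_ids (last id + 1), and only fall back to counting input_ids, which avoids the pandas round trip the old code paid for.

```python
import numpy as np


def get_lengths(columns: dict) -> np.ndarray:
    # prefer an explicit, precomputed length column when present
    if "length" in columns:
        return np.array(columns["length"])
    # position_ids run 0..L-1 for a sample, so the last id + 1 is its length
    if "position_ids" in columns:
        return np.array([row[-1] + 1 for row in columns["position_ids"]])
    # last resort: count tokens directly; object dtype keeps ragged rows intact
    input_ids = columns["input_ids"]
    return np.vectorize(len)(np.array(input_ids, dtype=object))


print(get_lengths({"position_ids": [[0, 1, 2], [0, 1, 2, 3, 4]]}))  # [3 5]
print(get_lengths({"input_ids": [[5, 6], [7, 8, 9]]}))              # [2 3]
```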
54 changes: 27 additions & 27 deletions src/axolotl/utils/trainer.py
@@ -109,6 +109,33 @@ def drop_long_seq(sample, sequence_len=2048):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
if cfg.is_preprocess:
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

# Phi doesn't want the attention_mask feature when training
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
):
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")

train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)

if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
@@ -130,33 +157,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
load_from_cache_file=not cfg.is_preprocess,
)

if cfg.group_by_length or cfg.sample_packing:
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)

# Phi doesn't want the attention_mask feature when training
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
):
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")

return train_dataset, eval_dataset


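The trainer.py reordering moves the over-length filter and the attention_mask column drop ahead of the group_by_length / sample-packing preparation, and computes the max_input_len debug stat only during an explicit preprocessing pass (cfg.is_preprocess) instead of at training time. A simplified, assumption-laden sketch of that flow follows — drop_long_seq here is a stand-in that only checks input_ids length, and plain lists replace the datasets.Dataset filter/map API:

```python
from functools import partial

import numpy as np


def drop_long_seq(sample, sequence_len=2048):
    # keep only samples that still fit the configured context window
    return 0 < len(sample["input_ids"]) <= sequence_len


def process_for_packing(samples, sequence_len=2048, is_preprocess=False):
    drop_long = partial(drop_long_seq, sequence_len=sequence_len)

    # the max-length stat is informational, so it is only computed while
    # preprocessing, mirroring the `if cfg.is_preprocess:` guard in the diff
    if is_preprocess:
        max_input_len = int(np.max([len(s["input_ids"]) for s in samples]))
        print(f"max_input_len: {max_input_len}")

    # filter over-long rows before any packing / group-by-length preparation
    return [s for s in samples if drop_long(s)]


samples = [{"input_ids": list(range(n))} for n in (128, 4096, 512)]
print(len(process_for_packing(samples, sequence_len=2048, is_preprocess=True)))  # 2
```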
