Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vram fix attempt #1164

Merged
merged 6 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ def load_tokenized_prepared_datasets(
(
str(cfg.sequence_len)
+ "@"
+ str(cfg.sample_packing)
+ "@"
+ str(cfg.eval_sample_packing)
+ "@"
+ str(cfg.group_by_length)
+ "@"
+ "|".join(
sorted(
[
Expand Down Expand Up @@ -162,7 +168,7 @@ def load_tokenized_prepared_datasets(
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
)

if cfg.seed:
Expand Down
12 changes: 6 additions & 6 deletions src/axolotl/utils/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
def get_dataset_lengths(dataset):
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
elif "position_ids" in dataset.data.column_names:
position_ids = dataset.data.column("position_ids")
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
return lengths
54 changes: 27 additions & 27 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,33 @@ def drop_long_seq(sample, sequence_len=2048):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
if cfg.is_preprocess:
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

# Phi doesn't want the attention_mask feature when training
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
):
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")

train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)

if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
Expand All @@ -130,33 +157,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
load_from_cache_file=not cfg.is_preprocess,
)

if cfg.group_by_length or cfg.sample_packing:
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
)

# Phi doesn't want the attention_mask feature when training
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
):
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")

return train_dataset, eval_dataset


Expand Down
Loading