attention_mask not needed for training (#642)
* attention_mask not needed for training

* specifically don't use attention mask for phi

* use a different check for phi

* small fixes since phi removed some values from their config
winglian committed Sep 27, 2023
1 parent d887ad8 commit e8cbf50
Showing 3 changed files with 11 additions and 8 deletions.
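In short, the packing pipeline now receives the tokenizer so it can tell when a phi model (which ships a CodeGenTokenizer) is being trained and drop the attention_mask column from the tokenized datasets. A rough sketch of the idea, assuming Hugging Face datasets.Dataset objects that already carry an attention_mask column; the helper names below are illustrative, not axolotl's:

from datasets import Dataset

def is_phi_tokenizer(tokenizer) -> bool:
    # phi-1 / phi-1.5 reuse CodeGenTokenizer, so the class name is a usable proxy
    return "CodeGenTokenizer" in tokenizer.__class__.__name__

def maybe_drop_attention_mask(dataset: Dataset, tokenizer) -> Dataset:
    # remove_columns returns a new Dataset without the named column
    if is_phi_tokenizer(tokenizer) and "attention_mask" in dataset.column_names:
        dataset = dataset.remove_columns("attention_mask")
    return dataset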
8 changes: 2 additions & 6 deletions src/axolotl/models/phi/modeling_mixformer_sequential.py
@@ -711,12 +711,8 @@ def __init__(
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
         self.block_idx = block_idx
 
-        self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
-        mlp_cls = mlp.pop("mlp_cls")
-        if mlp_cls == "fused_mlp":
-            self.mlp = FusedMLP(config=config, **mlp)
-        else:
-            self.mlp = MLP(config=config, **mlp)
+        self.mixer = MHA(config, layer_idx=block_idx)
+        self.mlp = MLP(config)
 
     def forward(
         self,
2 changes: 1 addition & 1 deletion src/axolotl/utils/data.py
@@ -76,7 +76,7 @@ def prepare_dataset(cfg, tokenizer):

     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
-            cfg, train_dataset, eval_dataset
+            cfg, train_dataset, eval_dataset, tokenizer
         )
     if cfg.max_steps:
         total_num_steps = min(
9 changes: 8 additions & 1 deletion src/axolotl/utils/trainer.py
@@ -397,7 +397,7 @@ def disable_datasets_caching():
     set_caching_enabled(True)
 
 
-def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
     with zero_first(is_main_process()):
         train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
@@ -414,6 +414,13 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
                 eval_dataset = eval_dataset.map(
                     add_position_ids, num_proc=os.cpu_count()
                 )
+
+        # Phi doesn't want the attention_mask feature when training
+        if "CodeGenTokenizer" in tokenizer.__class__.__name__:
+            train_dataset = train_dataset.remove_columns("attention_mask")
+            if eval_dataset:
+                eval_dataset = eval_dataset.remove_columns("attention_mask")
+
     return train_dataset, eval_dataset


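For reference, a call against the new signature would look roughly like the snippet below; cfg, the datasets, and tokenizer stand in for the objects prepare_dataset() already has in scope:

# hypothetical call site mirroring prepare_dataset() after this change
train_dataset, eval_dataset = process_datasets_for_packing(
    cfg, train_dataset, eval_dataset, tokenizer
)
# with a phi/CodeGenTokenizer tokenizer the returned datasets no longer
# contain an attention_mask column; other models keep theirs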

