various bugfixes #856

Merged 4 commits on Nov 15, 2023
examples/llama-2/tiny-llama.yml (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
-base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
 
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
src/axolotl/core/trainer_builder.py (4 additions, 4 deletions)

@@ -543,16 +543,16 @@ def build(self, total_num_steps):
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
 
-        if self.cfg.eval_steps:
+        if self.cfg.val_set_size == 0:
+            # no eval set, so don't eval
+            training_arguments_kwargs["evaluation_strategy"] = "no"
+        elif self.cfg.eval_steps:
             training_arguments_kwargs["evaluation_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
         elif self.cfg.evaluation_strategy:
             training_arguments_kwargs[
                 "evaluation_strategy"
             ] = self.cfg.evaluation_strategy
-        elif self.cfg.val_set_size == 0:
-            # no eval set, so don't eval
-            training_arguments_kwargs["evaluation_strategy"] = "no"
         else:
             # we have an eval set, but no steps defined, default to use epoch
             training_arguments_kwargs["evaluation_strategy"] = "epoch"
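The reordering above makes the empty-validation-set case win: as the old branch order shows, setting eval_steps while val_set_size was 0 selected the "steps" strategy even though there was nothing to evaluate. A minimal sketch of the corrected precedence, with a simplified stand-in for the config object (the Cfg class and pick_evaluation_strategy helper are illustrative, not axolotl's API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    """Illustrative stand-in for the relevant axolotl config fields."""

    val_set_size: float = 0.0
    eval_steps: Optional[int] = None
    evaluation_strategy: Optional[str] = None


def pick_evaluation_strategy(cfg: Cfg) -> str:
    # an empty validation set now wins over everything else
    if cfg.val_set_size == 0:
        return "no"
    if cfg.eval_steps:
        return "steps"
    if cfg.evaluation_strategy:
        return cfg.evaluation_strategy
    # eval set present but no steps configured: fall back to per-epoch eval
    return "epoch"


print(pick_evaluation_strategy(Cfg(val_set_size=0, eval_steps=500)))     # "no" (was "steps" before the reorder)
print(pick_evaluation_strategy(Cfg(val_set_size=0.05, eval_steps=500)))  # "steps"
```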
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py (2 additions, 0 deletions)

@@ -25,6 +25,8 @@ def sdp_attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py (2 additions, 0 deletions)

@@ -29,6 +29,8 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
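Both hunks above make the same change: the patched sdp and xformers attention forwards now accept a padding_mask keyword (and arbitrary extra keywords) that newer callers may pass, and simply ignore them instead of raising a TypeError for an unexpected keyword argument. A minimal sketch of that failure mode and the fix, using toy functions rather than the real attention code:

```python
# Toy stand-ins; the real code patches LlamaAttention.forward on the model.
def patched_forward_old(hidden_states, attention_mask=None, use_cache=False):
    return hidden_states


def patched_forward_new(
    hidden_states,
    attention_mask=None,
    use_cache=False,
    padding_mask=None,  # accepted but unused, mirroring the hunks above
    **kwargs,  # swallows any future extra keywords
):
    return hidden_states


hidden = [[0.1, 0.2]]
try:
    patched_forward_old(hidden, padding_mask=None)
except TypeError as err:
    print("old signature breaks:", err)

print("new signature works:", patched_forward_new(hidden, padding_mask=None))
```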
src/axolotl/prompters.py (1 addition, 1 deletion)

@@ -75,7 +75,7 @@ def _build_result(self, instruction, input_text, output):
         else:
             res = (
                 self.system_format.format(system=self.system_no_input_prompt)
-                if self.system_prompt
+                if self.system_no_input_prompt
                 else ""
             ) + self.turn_no_input_format.format(instruction=instruction)
         if output:
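The corrected guard now checks the same attribute it renders: in the no-input branch the system header should depend on system_no_input_prompt rather than on system_prompt. A minimal sketch of that branch with made-up template strings (the class and values are illustrative; only the condition mirrors the hunk):

```python
class TinyPrompter:
    """Illustrative stand-in; only the no-input branch from the hunk is reproduced."""

    system_format = "### System:\n{system}\n"
    turn_no_input_format = "### Instruction:\n{instruction}\n### Response:\n"
    system_prompt = "You are a helpful assistant."  # used by the with-input branch
    system_no_input_prompt = ""  # empty, so no system header should be emitted

    def build_no_input(self, instruction):
        return (
            self.system_format.format(system=self.system_no_input_prompt)
            if self.system_no_input_prompt  # corrected condition from the hunk
            else ""
        ) + self.turn_no_input_format.format(instruction=instruction)


prompter = TinyPrompter()
# With the old `if self.system_prompt` check this would have emitted an empty
# "### System:" header; with the corrected check the header is skipped entirely.
print(prompter.build_no_input("Summarize the changelog."))
```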
src/axolotl/utils/samplers/multipack.py (12 additions, 9 deletions)

@@ -181,13 +181,16 @@ def _len_est(self):
         )
 
         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
-        return (
-            world_size
-            * math.floor(
-                0.99
-                * lengths_sum_per_device
-                / self.packing_efficiency_estimate
-                // self.batch_max_len
-            )
-            - 1
+        return min(
+            1,
+            (
+                world_size
+                * math.floor(
+                    0.99
+                    * lengths_sum_per_device
+                    / self.packing_efficiency_estimate
+                    // self.batch_max_len
+                )
+                - 1
+            ),
         )
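The arithmetic inside the clamp is unchanged from the old return expression. A small sketch with made-up numbers shows what that estimate evaluates to before the min(...) is applied (the values are illustrative, not axolotl defaults):

```python
import math

# Illustrative numbers only.
world_size = 2                      # data-parallel workers
lengths_sum_per_device = 1_000_000  # tokens each device will draw
packing_efficiency_estimate = 0.9   # fraction of each packed batch filled with real tokens
batch_max_len = 4096                # tokens per packed batch

raw_estimate = (
    world_size
    * math.floor(
        0.99 * lengths_sum_per_device / packing_efficiency_estimate // batch_max_len
    )
    - 1
)
print(raw_estimate)  # 535: the batch-count estimate before the clamp in the hunk above
```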
src/axolotl/utils/trainer.py (23 additions, 22 deletions)

@@ -142,31 +142,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
 
 
 def calculate_total_num_steps(cfg, train_dataset):
-    if not cfg.total_num_tokens:
-        total_num_tokens = np.sum(
-            train_dataset.data.column("input_ids")
-            .to_pandas()
-            .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
-            .values
-        )
-        LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-        cfg.total_num_tokens = total_num_tokens
-
-    if not cfg.total_supervised_tokens:
-        total_supervised_tokens = (
-            train_dataset.data.column("labels")
-            .to_pandas()
-            .apply(lambda x: np.sum(np.array(x) != -100))
-            .sum()
-        )
-        LOG.debug(
-            f"`total_supervised_tokens: {total_supervised_tokens}`",
-            main_process_only=True,
-        )
-        cfg.total_supervised_tokens = total_supervised_tokens
 
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails
+        if not cfg.total_num_tokens:
+            total_num_tokens = np.sum(
+                train_dataset.data.column("input_ids")
+                .to_pandas()
+                .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
+                .values
+            )
+            LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
+            cfg.total_num_tokens = total_num_tokens
+
+        if not cfg.total_supervised_tokens:
+            total_supervised_tokens = (
+                train_dataset.data.column("labels")
+                .to_pandas()
+                .apply(lambda x: np.sum(np.array(x) != -100))
+                .sum()
+            )
+            LOG.debug(
+                f"`total_supervised_tokens: {total_supervised_tokens}`",
+                main_process_only=True,
+            )
+            cfg.total_supervised_tokens = total_supervised_tokens
+
         if cfg.sample_packing_eff_est:
             total_num_steps = (
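The token counting itself is unchanged; it is only moved under if cfg.sample_packing:, so the two sums are computed only when sample packing actually needs them. A minimal sketch of the supervised-token count on a fabricated labels column, mirroring the np.sum(np.array(x) != -100) logic above:

```python
import numpy as np
import pandas as pd

# Fabricated stand-in for train_dataset.data.column("labels").to_pandas();
# -100 marks positions that are masked out of the loss.
labels = pd.Series(
    [
        [-100, -100, 15, 22, 7],  # 3 supervised tokens
        [-100, 4, 9],             # 2 supervised tokens
    ]
)

total_supervised_tokens = labels.apply(lambda x: np.sum(np.array(x) != -100)).sum()
total_num_tokens = labels.apply(len).sum()  # labels and input_ids have equal lengths

print(total_supervised_tokens)  # 5
print(total_num_tokens)         # 8
```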