Add training callback to send predictions to WandB table #521

Merged · 16 commits · Sep 13, 2023
3 changes: 3 additions & 0 deletions README.md
@@ -543,6 +543,9 @@ eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time
max_steps:

eval_table_size: # approximate number of predictions sent to wandb, depending on batch size. Enabled when above 0. Default is 0
eval_table_max_new_tokens: # maximum number of new tokens to generate for each prediction sent to wandb. Default is 128

# save model as safetensors (require safetensors package)
save_safetensors:

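For illustration only (not part of this diff; assumes PyYAML and that unset keys simply read back as None), a minimal sketch of how the two new options fall back to their documented defaults, mirroring the normalize_config change further down in this PR:

import yaml

cfg = yaml.safe_load("""
eval_steps: 20
# eval_table_size and eval_table_max_new_tokens intentionally left unset
""")

eval_table_size = cfg.get("eval_table_size") or 0  # table logging stays disabled at 0
eval_table_max_new_tokens = cfg.get("eval_table_max_new_tokens") or 128

print(eval_table_size, eval_table_max_new_tokens)  # -> 0 128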
2 changes: 2 additions & 0 deletions examples/llama-2/lora.yml
@@ -55,6 +55,8 @@ flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
1 change: 1 addition & 0 deletions examples/llama-2/qlora.yml
@@ -57,6 +57,7 @@ flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
save_steps:
debug:
deepspeed:
69 changes: 69 additions & 0 deletions examples/llama-2/tiny-llama.yml
@@ -0,0 +1,69 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
4 changes: 3 additions & 1 deletion src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -194,7 +194,7 @@ def flashattn_forward(
    # only on first autoregressive step q,k,v have same seqlen
    is_causal = key_states.shape == query_states.shape

    if cu_seqlens is not None and max_seqlen is not None:
    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
        # special handling using sample packing
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
@@ -262,6 +262,8 @@ def flashattn_forward(
                if attention_mask is not None
                else None,
            )
            if q_unpad.dtype != kv_unpad.dtype:
                kv_unpad = kv_unpad.to(q_unpad.dtype)
            output_unpad = flash_attn_varlen_kvpacked_func(
                q_unpad,
                kv_unpad,
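Illustrative aside (not part of this diff): the two added lines guard against a dtype mismatch between the unpadded query and the packed key/value tensors, since the flash-attn varlen kernels expect both inputs in the same precision. A minimal sketch of the same guard, with made-up tensor shapes and dtypes:

import torch

# Under mixed precision the q and kv paths can end up in different dtypes,
# and the kernel call would reject the pair.
q_unpad = torch.randn(16, 32, 128, dtype=torch.bfloat16)
kv_unpad = torch.randn(16, 2, 32, 128, dtype=torch.float16)

if q_unpad.dtype != kv_unpad.dtype:
    kv_unpad = kv_unpad.to(q_unpad.dtype)  # cast kv to match q before the varlen call

assert q_unpad.dtype == kv_unpad.dtype == torch.bfloat16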
191 changes: 191 additions & 0 deletions src/axolotl/utils/callbacks.py
@@ -11,10 +11,13 @@
import pandas as pd
import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
    GenerationConfig,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
@@ -317,3 +320,191 @@ def on_evaluate(
            trainer.log(results)

    return BenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
    class LogPredictionCallback(TrainerCallback):
        """Callback to log prediction values during each evaluation"""

        def __init__(self, cfg):
            self.cfg = cfg
            self.logged = False

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
            state: TrainerState,
            control: TrainerControl,
            train_dataloader,  # pylint: disable=unused-argument
            eval_dataloader,
            **kwargs,  # pylint: disable=unused-argument
        ):
            eval_table_size = self.cfg.eval_table_size

            if eval_table_size <= 0:
                return control

            trainer.model.eval()
            device = torch.device(self.cfg.device)

            # pylint: disable=duplicate-code
            generation_config = GenerationConfig(
                max_new_tokens=self.cfg.eval_table_max_new_tokens,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )

            def logits_to_tokens(logits) -> str:
                probabilities = torch.softmax(logits, dim=-1)
                # Get the predicted token ids (the ones with the highest probability)
                predicted_token_ids = torch.argmax(probabilities, dim=-1)
                return predicted_token_ids

            def find_ranges(lst):
                ranges = []
                start = 0
                for i in range(1, len(lst)):
                    if lst[i] == 0:
                        ranges.append((start, i - 1))
                        start = i
                end = len(lst) - 1
                ranges.append((start, end))
                return ranges

            def log_table_from_dataloader(name: str, table_dataloader):
                table = wandb.Table(
                    columns=[
                        "id",
                        "Prompt",
                        "Correct Completion",
                        "Predicted Completion (model.generate)",
                        "Predicted Completion (trainer.prediction_step)",
                    ]
                )
                row_index = 0

                for batch in tqdm(table_dataloader):
                    if row_index > eval_table_size:
                        break

                    batch_labels = batch["labels"].to(device)
                    batch_input_ids = batch["input_ids"].to(device)

                    if "position_ids" in batch:
                        batch_pos_ids = batch["position_ids"].tolist()
                    else:
                        batch_pos_ids = [None] * len(batch["input_ids"])

                    (_, batch_logits, _) = trainer.prediction_step(
                        trainer.model,
                        batch,
                        prediction_loss_only=False,
                    )

                    prompt_token_ids_list = []
                    pred_step_token_ids_list = []
                    completion_token_ids_list = []

                    for input_ids_all, labels_all, pos_ids, logits in zip(
                        batch_input_ids,
                        batch_labels,
                        batch_pos_ids,
                        batch_logits,
                    ):
                        if pos_ids is None:
                            pos_ranges = [(0, len(input_ids_all) - 1)]
                        else:
                            pos_ranges = find_ranges(pos_ids)

                        for pos_range in pos_ranges:
                            start, end = pos_range
                            if start == end:
                                continue

                            input_ids = input_ids_all[start : end + 1]
                            labels = labels_all[start : end + 1]

                            tokens_without_loss = labels == IGNORE_INDEX
                            tokens_with_loss = labels != IGNORE_INDEX
                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
                            prompt_token_includes = (
                                tokens_without_loss & tokens_exclude_padding
                            )

                            prompt_token_ids = input_ids[prompt_token_includes]
                            prompt_token_ids_list.append(prompt_token_ids)

                            completion_token_ids = input_ids[tokens_with_loss]
                            completion_token_ids_list.append(completion_token_ids)

                            pred_step_token_ids = logits_to_tokens(
                                logits[start : end + 1]
                            )[tokens_with_loss]
                            pred_step_token_ids_list.append(pred_step_token_ids)

                    prompt_texts = tokenizer.batch_decode(
                        prompt_token_ids_list, skip_special_tokens=True
                    )
                    completion_texts = tokenizer.batch_decode(
                        completion_token_ids_list, skip_special_tokens=True
                    )
                    pred_step_texts = tokenizer.batch_decode(
                        pred_step_token_ids_list, skip_special_tokens=True
                    )

                    with torch.no_grad():
                        prompt_encoding = tokenizer(
                            prompt_texts, padding=True, return_tensors="pt"
                        ).to(self.cfg.device)
                        predictions = trainer.model.generate(
                            **prompt_encoding, generation_config=generation_config
                        )

Review thread on the trainer.model.generate(...) call above:

Collaborator: trainer.prediction_step(...) might be easier to use.

Contributor Author: It was. However, I had that previously and it appeared to output strange predictions. At one point I actually had both to compare, and only model.generate was useful. I'm struggling to find the WandB report. I'll try again and see how it goes.

Contributor Author: Top is trainer.prediction_step and bottom is trainer.model.generate: [image] Not sure if this is correct, though? I originally showed both; I could add it too?

Contributor Author: Added both now.

Contributor Author: [image]

                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
                    prediction_without_prompt_tokens_list = []
                    for prompt_token_ids, prediction_tokens in zip(
                        prompt_token_ids_list, prediction_all_tokens
                    ):
                        prediction_without_prompt_tokens = prediction_tokens[
                            len(prompt_token_ids) :
                        ]
                        prediction_without_prompt_tokens_list.append(
                            prediction_without_prompt_tokens
                        )

                    predicted_texts = tokenizer.batch_decode(
                        prediction_without_prompt_tokens_list, skip_special_tokens=True
                    )

                    for (
                        prompt_text,
                        completion_text,
                        prediction_text,
                        pred_step_text,
                    ) in zip(
                        prompt_texts, completion_texts, predicted_texts, pred_step_texts
                    ):
                        table.add_data(
                            row_index,
                            prompt_text,
                            completion_text,
                            prediction_text,
                            pred_step_text,
                        )
                        row_index += 1

                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})

            if is_main_process():
                log_table_from_dataloader("Eval", eval_dataloader)

            return control

    return LogPredictionCallback
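Illustrative aside (not part of this diff): the table logs two prediction columns because, as the review thread above discusses, trainer.prediction_step yields teacher-forced predictions (argmax over logits computed with the ground-truth prefix as context), while model.generate decodes freely from the prompt alone. A minimal sketch of the contrast; the function names are illustrative, and it assumes a causal LM plus a GenerationConfig with return_dict_in_generate=True, as used in the callback:

import torch


def teacher_forced_token_ids(logits: torch.Tensor) -> torch.Tensor:
    # logits: [seq_len, vocab_size] as returned via trainer.prediction_step; every
    # position is conditioned on the true previous tokens, so errors cannot compound.
    return torch.argmax(torch.softmax(logits, dim=-1), dim=-1)


def free_running_token_ids(model, tokenizer, prompt, generation_config):
    # model.generate feeds its own predictions back in, which is closer to how the
    # model will behave at inference time.
    encoding = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**encoding, generation_config=generation_config)
    # Strip the prompt tokens, as the callback does, keeping only the continuation;
    # indexing "sequences" assumes return_dict_in_generate=True.
    return out["sequences"][0, encoding["input_ids"].shape[1] :]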
2 changes: 2 additions & 0 deletions src/axolotl/utils/config.py
@@ -48,6 +48,8 @@ def normalize_config(cfg):
    )
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    cfg.eval_table_size = cfg.eval_table_size or 0
    cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
    choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
4 changes: 2 additions & 2 deletions src/axolotl/utils/models.py
@@ -340,10 +340,10 @@ def load_model(
    if (
        hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings
        and cfg.sequence_len >= model.config.max_position_embeddings
        and cfg.sequence_len > model.config.max_position_embeddings
    ):
        LOG.warning(
            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
            f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
        )
        model.config.max_position_embeddings = cfg.sequence_len

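Illustrative aside (not part of this diff): with the old >=, a config whose sequence_len exactly equaled the model's max_position_embeddings logged an "increasing" warning and then wrote back the same value; the strict > only fires on an actual increase, and the new message also includes the old value. A trivial sketch of the changed condition, with made-up numbers:

max_position_embeddings = 4096  # illustrative model default

for sequence_len in (2048, 4096, 8192):
    if sequence_len > max_position_embeddings:  # previously >=, which also fired on equality
        print(
            f"increasing model.config.max_position_embeddings "
            f"from {max_position_embeddings} to {sequence_len}"
        )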
5 changes: 5 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -30,6 +30,7 @@
    SaveBetterTransformerModelCallback,
    SavePeftModelCallback,
    bench_eval_callback_factory,
    log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -719,6 +720,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
        **trainer_kwargs,
    )

    if cfg.use_wandb and cfg.eval_table_size > 0:
        LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
        trainer.add_callback(LogPredictionCallback(cfg))

    if cfg.do_bench_eval:
        trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

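Illustrative aside (not part of this diff): the factory pattern used above exists because the callback needs the live trainer and tokenizer (for prediction_step, generate, and decoding), which TrainerCallback hooks do not receive directly; the class is created inside a closure and then registered. A stripped-down sketch of the same wiring using only standard transformers API, with illustrative names:

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


def make_log_prediction_callback(trainer, tokenizer):
    class LogPredictionCallback(TrainerCallback):
        def __init__(self, cfg):
            self.cfg = cfg

        def on_evaluate(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs,
        ):
            # `trainer` and `tokenizer` are captured from the enclosing scope,
            # which is what the factory buys over a plain callback class.
            _ = (trainer, tokenizer)
            return control

    return LogPredictionCallback


# Usage, assuming `trainer`, `tokenizer`, and `cfg` already exist:
#   LogPredictionCallback = make_log_prediction_callback(trainer, tokenizer)
#   trainer.add_callback(LogPredictionCallback(cfg))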