Skip to content

Commit

Permalink
Add training callback to send predictions to WandB table (#521)
Browse files Browse the repository at this point in the history
* WIP Add training callback to send predictions to WandB table

* WIP improve wandb table reporting callback

* WIP improve wandb table reporting callback (cont)

* Add VSCode launching for debugging

* Add tiny llama example

* WIP attempt to improve post-eval prediction generation for table

* WIP attempt to improve post-eval prediction generation for table - part 2

* WIP batch generation

* WIP attempt to handle sample_packing using position_ids for wandb prediction table

* WIP add code for debugging

* Fix sample_packing support for wandb prediction table

* Clean up code for PR review

* Add eval_table_size, eval_table_max_new_tokens configs & clean up code

* Clean up PR, delete VSCode config, add tiny-llama example

* Add eval_table_size, eval_table_max_new_tokens documentation. Fix linting/formatting
  • Loading branch information
Glavin001 committed Sep 13, 2023
1 parent 2f586d1 commit 5b67ea9
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time
max_steps:

eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128

# save model as safetensors (require safetensors package)
save_safetensors:

Expand Down
2 changes: 2 additions & 0 deletions examples/llama-2/lora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
Expand Down
1 change: 1 addition & 0 deletions examples/llama-2/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
save_steps:
debug:
deepspeed:
Expand Down
69 changes: 69 additions & 0 deletions examples/llama-2/tiny-llama.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
4 changes: 3 additions & 1 deletion src/axolotl/monkeypatch/llama_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape

if cu_seqlens is not None and max_seqlen is not None:
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
Expand Down Expand Up @@ -261,6 +261,8 @@ def flashattn_forward(
if attention_mask is not None
else None,
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
Expand Down
191 changes: 191 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
import pandas as pd
import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
GenerationConfig,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
Expand Down Expand Up @@ -323,3 +326,191 @@ def on_evaluate(
metrics[key] = val

return BenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False

def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
eval_table_size = self.cfg.eval_table_size

if eval_table_size <= 0:
return control

trainer.model.eval()
device = torch.device(self.cfg.device)

# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

def logits_to_tokens(logits) -> str:
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)
return predicted_token_ids

def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges

def log_table_from_dataloader(name: str, table_dataloader):
table = wandb.Table(
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0

for batch in tqdm(table_dataloader):
if row_index > eval_table_size:
break

batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])

(_, batch_logits, _) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)

prompt_token_ids_list = []
pred_step_token_ids_list = []
completion_token_ids_list = []

for input_ids_all, labels_all, pos_ids, logits in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
batch_logits,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue

input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

pred_step_token_ids = logits_to_tokens(
logits[start : end + 1]
)[tokens_with_loss]
pred_step_token_ids_list.append(pred_step_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
pred_step_texts = tokenizer.batch_decode(
pred_step_token_ids_list, skip_special_tokens=True
)

with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)

prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)

predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)

for (
prompt_text,
completion_text,
prediction_text,
pred_step_text,
) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table.add_data(
row_index,
prompt_text,
completion_text,
prediction_text,
pred_step_text,
)
row_index += 1

wandb.run.log({f"{name} - Predictions vs Ground Truth": table})

if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)

return control

return LogPredictionCallback
2 changes: 2 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def normalize_config(cfg):
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ def load_model(
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len

Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
Expand Down Expand Up @@ -703,6 +704,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
**trainer_kwargs,
)

if cfg.use_wandb and cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))

if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

Expand Down

0 comments on commit 5b67ea9

Please sign in to comment.