Add training callback to send predictions to WandB table #521

Merged 16 commits on Sep 13, 2023

Changes from 13 commits
37 changes: 37 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,37 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Remote Attach",
"type": "python",
"request": "attach",
"connect": {
"host": "0.0.0.0",
"port": 5678
},
"pathMappings": [
{
"localRoot": "${workspaceFolder}",
"remoteRoot": "/workspace/axolotl/"
}
],
"justMyCode": false
},
{
"name": "train",
"type": "python",
"request": "launch",
"module": "accelerate.commands.launch",
"args": [
"${workspaceFolder}/scripts/finetune.py",
// "${file}",
"${workspaceFolder}/examples/llama-2/tiny-random.yml",
], // other args come after train.py
"console": "integratedTerminal",
// "env": {"CUDA_LAUNCH_BLOCKING": "1"}
},
]
}
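The "Python: Remote Attach" configuration above pairs with a debugpy listener running inside the container on port 5678 (the same lines appear, commented out, in scripts/finetune.py further down in this diff, and the port is published in docker-compose.yaml). A minimal sketch of the script side, assuming debugpy is installed in the container:

# Sketch: start a debugpy server so VS Code can attach via "Python: Remote Attach".
# Assumes debugpy is installed and port 5678 is exposed (see docker-compose.yaml).
import debugpy

debugpy.listen(("0.0.0.0", 5678))  # listen on all interfaces inside the container
debugpy.wait_for_client()          # block until the VS Code debugger attaches
debugpy.breakpoint()               # pause here once the client is connected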
11 changes: 7 additions & 4 deletions docker-compose.yaml
@@ -1,20 +1,23 @@
# version: '3.8'
services:
axolotl:
build:
context: .
dockerfile: ./docker/Dockerfile
# build:
# context: .
# dockerfile: ./docker/Dockerfile
image: winglian/axolotl:main-py3.10-cu118-2.0.1
volumes:
- .:/workspace/axolotl
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
- WANDB_API_KEY=${WANDB_API_KEY}
ports:
- "5678:5678"
deploy:
resources:
reservations:
devices:
- driver: nvidia
# count: 1
count: 1
capabilities: [gpu]
command: tail -f /dev/null
101 changes: 101 additions & 0 deletions examples/llama-2/tiny-random.yml
@@ -0,0 +1,101 @@
# anushehchaudry/llama-2-tiny-random
# base_model: anushehchaudry/llama-2-tiny-random
# base_model_config: anushehchaudry/llama-2-tiny-random

# base_model: JackFram/llama-68m
# base_model_config: JackFram/llama-68m

base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
# - path: mhenrichsen/alpaca_2k_test
# type: alpaca
# - path: teknium/GPT4-LLM-Cleaned
# type: alpaca
- path: Glavin001/startup-interviews
type: alpaca
dataset_prepared_path: last_run_prepared
# val_set_size: 0.01
val_set_size: 0.02
# val_set_size: 0.05
# val_set_size: 0.001
# val_set_size: 0.1
# output_dir: ./lora-out
# output_dir: ./lora-2-out
output_dir: ./lora-6-out

# sequence_len: 4096
# sequence_len: 2048
# sequence_len: 256
# sequence_len: 512
sequence_len: 1024
sample_packing: true
# sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project: test-issue-490
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
# micro_batch_size: 2
micro_batch_size: 16
# micro_batch_size: 24
# micro_batch_size: 24
# num_epochs: 3
# num_epochs: 0.001
# num_epochs: 0.01
# num_epochs: 1
# num_epochs: 5
num_epochs: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
# eval_steps: 10
# eval_steps: 20
eval_steps: 2
# eval_steps: 1
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
5 changes: 5 additions & 0 deletions scripts/finetune.py
@@ -28,6 +28,11 @@
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars

# import debugpy
# debugpy.listen(('0.0.0.0', 5678))
# debugpy.wait_for_client()
# debugpy.breakpoint()

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
2 changes: 1 addition & 1 deletion src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -194,7 +194,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape

if cu_seqlens is not None and max_seqlen is not None:
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
149 changes: 149 additions & 0 deletions src/axolotl/utils/callbacks.py
@@ -15,13 +15,16 @@
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
GenerationConfig,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

import wandb
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
@@ -317,3 +320,149 @@ def on_evaluate(
trainer.log(results)

return BenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False

def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState,
control: TrainerControl,
train_dataloader,
eval_dataloader,
**kwargs,
):
trainer.model.eval()
device = torch.device(self.cfg.device)

def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
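# For illustration (not part of the PR): with sample packing, position_ids reset
# to 0 at the start of each packed example, e.g. [0, 1, 2, 3, 0, 1, 2] yields
# find_ranges(...) == [(0, 3), (4, 6)] -- one (start, end) range per example.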

def log_table_from_dataloader(name: str, table_dataloader):
table = wandb.Table(
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion",
]
)
row_index = 0
max_new_tokens = 128

for batch in tqdm(table_dataloader, total=len(table_dataloader)):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])

prompt_token_ids_list = []
completion_token_ids_list = []

for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids, batch_labels, batch_pos_ids
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue

input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id

prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)

with torch.no_grad():
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)

Collaborator: trainer.prediction_step(...) might be easier to use.

Contributor Author: It was. However, I had that previously and it appeared to output strange predictions. At one point I actually had both to compare, and only model.generate was useful. I'm struggling to find the WandB report; I'll try again and see how it goes.

Contributor Author: Top is trainer.prediction_step and bottom is trainer.model.generate: [image comparing the two outputs] Not sure if this is correct, though. I originally showed both; I could add it again?

Contributor Author: Added both now. [image]

prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)

predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)

for prompt_text, completion_text, prediction_text in zip(
prompt_texts, completion_texts, predicted_texts
):
table.add_data(
row_index, prompt_text, completion_text, prediction_text
)
row_index += 1

wandb.run.log({f"{name} - Predictions vs Ground Truth": table})

log_table_from_dataloader("Eval", eval_dataloader)

return control

return LogPredictionCallback
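For context on the review thread above: trainer.prediction_step runs teacher-forced inference over the full packed batch and returns per-position logits, while model.generate produces a free-running completion from the prompt alone, which is what the table is meant to display. A rough sketch of the difference, using a hypothetical helper that is not part of this PR and assumes the batch dict already contains input_ids, attention_mask, and labels:

# Hypothetical comparison of the two approaches discussed above; not part of the PR.
import torch

def compare_prediction_modes(trainer, tokenizer, batch, prompt_texts, max_new_tokens=128):
    # Teacher-forced: logits for every position of the (possibly packed) sequence.
    # Decoding argmax(logits) yields next-token predictions shifted by one position,
    # which can look "strange" compared to a real completion.
    _, logits, _ = trainer.prediction_step(
        trainer.model, batch, prediction_loss_only=False
    )
    if isinstance(logits, tuple):  # some models return extra outputs alongside logits
        logits = logits[0]
    teacher_forced = tokenizer.batch_decode(
        logits.argmax(dim=-1), skip_special_tokens=True
    )

    # Free-running: generate new tokens from the prompt only (what the callback logs).
    encoding = tokenizer(prompt_texts, padding=True, return_tensors="pt").to(
        trainer.model.device
    )
    with torch.no_grad():
        generated = trainer.model.generate(**encoding, max_new_tokens=max_new_tokens)
    free_running = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return teacher_forced, free_running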
4 changes: 2 additions & 2 deletions src/axolotl/utils/models.py
@@ -340,10 +340,10 @@ def load_model(
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len

5 changes: 5 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -30,6 +30,7 @@
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -719,6 +720,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
**trainer_kwargs,
)

if cfg.use_wandb:
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))

if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
