Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add training callback to send predictions to WandB table #521

Merged
merged 16 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Remote Attach",
"type": "python",
"request": "attach",
"connect": {
"host": "0.0.0.0",
"port": 5678
},
"pathMappings": [
{
"localRoot": "${workspaceFolder}",
"remoteRoot": "/workspace/axolotl/"
}
]
},
{
"name": "train",
"type": "python",
"request": "launch",
"module": "accelerate.commands.launch",
"args": [
"${workspaceFolder}/scripts/finetune.py",
// "${file}",
"${workspaceFolder}/examples/llama-2/tiny-random.yml",
], // other args comes after train.py
"console": "integratedTerminal",
// "env": {"CUDA_LAUNCH_BLOCKING": "1"}
},
]
}
11 changes: 7 additions & 4 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# version: '3.8'
services:
axolotl:
build:
context: .
dockerfile: ./docker/Dockerfile
# build:
# context: .
# dockerfile: ./docker/Dockerfile
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved
image: winglian/axolotl:main-py3.10-cu118-2.0.1
volumes:
- .:/workspace/axolotl
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
- WANDB_API_KEY=${WANDB_API_KEY}
ports:
- "5678:5678"
deploy:
resources:
reservations:
devices:
- driver: nvidia
# count: 1
count: 1
capabilities: [gpu]
command: tail -f /dev/null
97 changes: 97 additions & 0 deletions examples/llama-2/tiny-random.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# anushehchaudry/llama-2-tiny-random
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved
# base_model: anushehchaudry/llama-2-tiny-random
# base_model_config: anushehchaudry/llama-2-tiny-random

# base_model: JackFram/llama-68m
# base_model_config: JackFram/llama-68m

base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
# - path: mhenrichsen/alpaca_2k_test
# type: alpaca
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
# - path: Glavin001/startup-interviews
# type: alpaca
dataset_prepared_path: last_run_prepared
# val_set_size: 0.01
val_set_size: 0.001
# val_set_size: 0.1
# output_dir: ./lora-out
# output_dir: ./lora-2-out
output_dir: ./lora-5-out

# sequence_len: 4096
sequence_len: 2048
# sequence_len: 256
# sequence_len: 512
# sample_packing: true
sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project: test-issue-490
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
# micro_batch_size: 2
micro_batch_size: 16
# micro_batch_size: 24
# micro_batch_size: 24
# num_epochs: 3
# num_epochs: 0.001
# num_epochs: 0.01
num_epochs: 1
# num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 10
# eval_steps: 20
# eval_steps: 2
# eval_steps: 1
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
5 changes: 5 additions & 0 deletions scripts/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars

# import debugpy
# debugpy.listen(('0.0.0.0', 5678))
# debugpy.wait_for_client()
# debugpy.breakpoint()
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
Expand Down
178 changes: 178 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from typing import TYPE_CHECKING, Dict, List

import itertools
import evaluate
import numpy as np
import pandas as pd
Expand All @@ -15,13 +16,16 @@
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
GenerationConfig,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

import wandb
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
Expand Down Expand Up @@ -317,3 +321,177 @@ def on_evaluate(
trainer.log(results)

return BenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
LOG.info("log_prediction_callback_factory")

class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False

def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState,
control: TrainerControl,
model,
# tokenizer,
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved
train_dataloader,
eval_dataloader,
**kwargs,
):
LOG.info("=" * 80)
LOG.info("logging predictions")

trainer.model.eval()
device = torch.device(self.cfg.device)

def logits_to_tokens(logits) -> str:
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)
return predicted_token_ids

def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i-1))
start = i
ranges.append((start, len(lst)-1)) # for the last range
return ranges

def log_table_from_dataloader(name: str, table_dataloader):

# Initialize an empty wandb.Table
table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion 1", "Predicted Completion 2"])

batch_index = 0
for batch in tqdm(table_dataloader, total=len(table_dataloader)):
# For each batch I want prompt, completion, 2x predictions

# (loss, logits, labels) = trainer.prediction_step(
# (batch_loss, batch_logits, batch_labels) = trainer.prediction_step(
# trainer.model,
# batch,
# prediction_loss_only=False,
# )

batch_labels = batch['labels'].to(device)
batch_input_ids = batch['input_ids'].to(device)

if 'position_ids' in batch:
batch_pos_ids = batch['position_ids'].tolist()
else:
batch_pos_ids = [None] * len(batch['input_ids'])

prompt_token_ids_list = []
completion_token_ids_list = []
# completion_texts = []
# prediction_texts = []

# for input_ids in batch['input_ids']:
# for batch_item_idx, (input_ids, labels) in enumerate(zip(batch['input_ids'], logits, labels)):
# for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'].to(device), batch_logits, batch_labels)):
for batch_item_idx, (input_ids_all, labels_all, pos_ids) in enumerate(zip(batch_input_ids, batch_labels, batch_pos_ids)):

if pos_ids is None:
pos_ranges = [(0, len(input_ids_all)-1)]
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue

input_ids = input_ids_all[start:end+1]
labels = labels_all[start:end+1]
# input_ids[start:end] = tokenizer.pad_token_id

tokens_without_loss = (labels == IGNORE_INDEX)
tokens_with_loss = (labels != IGNORE_INDEX)
tokens_exclude_padding = (input_ids != tokenizer.pad_token_id)

prompt_token_includes = tokens_without_loss & tokens_exclude_padding

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
# completion_text = tokenizer.decode(completion_token_ids)
# completion_texts.append(completion_text)

# completion_logit = logits[tokens_with_loss]
# predicted_tokens = logits_to_tokens(completion_logit)
# prediction_text = tokenizer.decode(predicted_tokens)
# prediction_texts.append(prediction_text)

prompt_texts = tokenizer.batch_decode(prompt_token_ids_list, skip_special_tokens=True)
completion_texts = tokenizer.batch_decode(completion_token_ids_list, skip_special_tokens=True)

with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=32,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

encoding = tokenizer(prompt_texts, padding=True, return_tensors='pt').to(self.cfg.device)
new_predictions = trainer.model.generate(**encoding, generation_config=generation_config) # FIXME: when sample_packing=True then error: "TypeError: varlen_fwd(): incompatible function arguments."
Glavin001 marked this conversation as resolved.
Show resolved Hide resolved

new_prediction_all_tokens = new_predictions["sequences"].cpu().tolist()
new_prediction_without_prompt_tokens_list = []
for prompt_token_ids, new_prediction_tokens in zip(prompt_token_ids_list, new_prediction_all_tokens):
new_prediction_without_prompt_tokens = new_prediction_tokens[len(prompt_token_ids):]
new_prediction_without_prompt_tokens_list.append(new_prediction_without_prompt_tokens)

new_predicted_texts = tokenizer.batch_decode(new_prediction_without_prompt_tokens_list, skip_special_tokens=True)

# for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)):
for i, (prompt_text, completion_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, new_predicted_texts)):
prediction_text = ""
table.add_data(i, prompt_text, completion_text, prediction_text, new_predicted_text)

batch_index += 1

wandb.run.log({ f"{name} - Predictions vs Ground Truth": table })

# log_table_from_dataloader("Train", train_dataloader)
# log_table_from_dataloader("Train", train_dataloader)

# # Get first 10 records from train_dataloader as a new dataloader
# train_data_subset = [next(iter(train_dataloader)) for _ in range(10)]
# train_dataloader_subset = torch.utils.data.DataLoader(train_data_subset, batch_size=train_dataloader.batch_size, shuffle=False)
# log_table_from_dataloader("Train Subset", train_dataloader_subset)

log_table_from_dataloader("Eval", eval_dataloader)

return control

return LogPredictionCallback


def group_sublists_by(lst: List[int], value: int) -> List[List[int]]:
"""
Group sublists of lst by value
"""
grouped = []
for key, group in itertools.groupby(lst, lambda x: x == value):
if key:
grouped.append(list(group))
return grouped
4 changes: 2 additions & 2 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,10 @@ def load_model(
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len

Expand Down
Loading
Loading