add support for alpaca reflect training #2

Merged: 1 commit, Apr 18, 2023

45 changes: 45 additions & 0 deletions configs/vicuna_13B_4bit_reflect.yml
@@ -0,0 +1,45 @@
base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_4bit: true
gptq_groupsize: 128
gptq_model_v1: false
datasets:
# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
- path: data/alpaca_reflect_pruned.jsonl
type: reflection
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-reflect
batch_size: 8
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
local_rank:
flash_attention: true
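
For reference, each line of data/alpaca_reflect_pruned.jsonl must supply the five fields that AlpacaReflectionPTStrategy reads further down in this diff ("input" may be empty or absent). A hypothetical record, for illustration only:

{"instruction": "Name a prime number between 10 and 20.", "input": "", "output": "15 is a prime between 10 and 20.", "reflection": "15 is divisible by 3 and 5, so it is not prime.", "corrected": "13 is a prime between 10 and 20."}

Assuming the repository's usual entry point, training with this config would then be launched as python scripts/finetune.py configs/vicuna_13B_4bit_reflect.yml.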
14 changes: 12 additions & 2 deletions scripts/finetune.py
@@ -37,9 +37,9 @@
ShareGPTPromptTokenizingStrategy,
LLAMA_DEFAULT_PAD_TOKEN,
GPTeacherPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
)
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter

logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -395,6 +395,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
)
trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)

# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
@@ -540,6 +541,15 @@ def train(
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "reflection":
ds_strategy = AlpacaReflectionPTStrategy(
ReflectAlpacaPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
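
To sanity-check this wiring outside the training loop, here is a minimal sketch. It assumes the PromptTokenizingStrategy base class simply stores the prompter, tokenizer, train_on_inputs, and sequence_len arguments, which is consistent with how the new branch calls it; the tokenizer repo is the one named in the config and stands in for any Llama tokenizer.

from transformers import LlamaTokenizer

from axolotl.prompt_tokenizers import AlpacaReflectionPTStrategy
from axolotl.prompters import ReflectAlpacaPrompter

tokenizer = LlamaTokenizer.from_pretrained("anon8231489123/vicuna-13b-GPTQ-4bit-128g")
strategy = AlpacaReflectionPTStrategy(
    ReflectAlpacaPrompter(),
    tokenizer,
    False,  # train_on_inputs, as in the config above
    2048,   # sequence_len
)
tokenized = strategy.tokenize_prompt({
    "instruction": "Summarize the paper.",
    "input": "",
    "output": "a draft abstract",
    "reflection": "a critique of the draft",
    "corrected": "the revised final response",
})
# tokenized now holds input_ids, attention_mask, and labels, with the
# user-prompt positions in labels masked to -100.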
61 changes: 61 additions & 0 deletions src/axolotl/prompt_tokenizers.py
@@ -107,6 +107,67 @@ def parse_instruction_fields(self, prompt) -> (str, str, str):
)


class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
raise NotImplementedError

def tokenize_prompt(self, prompt):
instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = self.prompter.build_prompt(
instruction,
input,
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]

return tokenized_full_prompt

def _build_full_prompt(self, instruction, input, output, reflection, corrected):
return self.prompter.build_prompt(
instruction,
input,
output,
reflection,
corrected,
)

def _tokenize(self, prompt, add_eos_token=True):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)

result["labels"] = result["input_ids"].copy()
return result


class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
prompt["reflection"],
prompt["corrected"],
)


class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
try:
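
To make the masking in tokenize_prompt concrete: with train_on_inputs disabled, every label position covered by the user prompt is overwritten with -100, so the loss only sees the output, reflection, and corrected spans. A minimal sketch of the same slicing with hypothetical token ids:

user_prompt_len = 4                           # tokens in the instruction/input prefix
input_ids = [1, 887, 526, 385, 450, 1234, 2]  # hypothetical ids; 2 = eos
labels = input_ids.copy()
labels = [-100] * user_prompt_len + labels[user_prompt_len:]
# labels -> [-100, -100, -100, -100, 450, 1234, 2]
# Hugging Face loss functions ignore index -100, so only the completion
# tokens (and the trailing eos) contribute to the loss.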
29 changes: 29 additions & 0 deletions src/axolotl/prompters.py
@@ -35,6 +35,35 @@ class GPTeacherPrompter(AlpacaPrompter):
...


class ReflectAlpacaPrompter:
prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
response_split = "### Response:"

def build_prompt(
self,
instruction: str,
input: Union[None, str] = None,
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(instruction=instruction, input=input)
else:
res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected:
label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
res = f"{res}{label}"
return res

def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()


class SeparatorStyle(Enum):
"""Different separator style."""

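
To make the templates concrete: with no input, ReflectAlpacaPrompter().build_prompt(instruction, output=..., reflection=..., corrected=...) assembles the three labeled sections in order. With the long preamble elided as [...], a fully labeled prompt renders as:

Below is an instruction that describes a task. [...]

### Instruction:
Name a prime number between 10 and 20.

### Response:
15 is a prime between 10 and 20.

### Agent Reflection:
15 is divisible by 3 and 5, so it is not prime.

### Final Response:
13 is a prime between 10 and 20.

Note that get_response splits on the first "### Response:" marker, so on a fully labeled generation it returns the draft response together with the Agent Reflection and Final Response sections, not just the final answer.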