Added gradio support
Stillerman committed Nov 2, 2023
1 parent 2e71ff0 commit 7ec7c4a
Showing 4 changed files with 91 additions and 4 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -96,6 +96,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
    --lora_model_dir="./lora-out"

# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
    --lora_model_dir="./lora-out" --gradio
```
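The `--gradio` flag serves the same inference entry point through a Gradio web UI instead of the interactive terminal prompt (see `do_inference_gradio` in `src/axolotl/cli/__init__.py` in this commit).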

## Installation
@@ -918,6 +922,10 @@ Pass the appropriate flag to the inference command:
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
    --base_model="./completed-model" --prompter=None --load_in_8bit=True
```
- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
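This starts a local Gradio server; because `demo.launch` is called with `share=True`, a temporary public share link is printed as well.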

Please use `--sample_packing False` if you have it enabled and receive an error similar to the one below:

1 change: 1 addition & 0 deletions requirements.txt
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
gradio
77 changes: 76 additions & 1 deletion src/axolotl/cli/__init__.py
@@ -1,12 +1,14 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import importlib
import logging
import os
import random
import sys
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union

import gradio as gr
import torch
import yaml
@@ -16,7 +18,7 @@
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer
from transformers import GenerationConfig, TextStreamer, TextIteratorStreamer

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -152,6 +154,79 @@ def do_inference(
    print("=" * 40)
    print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))

def do_inference_gradio(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

    for token, symbol in default_tokens.items():
        # If the token isn't already specified in the config, add it
        if not (cfg.special_tokens and token in cfg.special_tokens):
            tokenizer.add_special_tokens({token: symbol})

    prompter_module = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )

    if cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

        set_model_mem_id(model, tokenizer)
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )

    model = model.to(cfg.device)

    def generate(instruction):
        if not instruction:
            return
        if prompter_module:
            # Wrap the raw instruction in the configured prompt template
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextIteratorStreamer(tokenizer)
            generation_kwargs = dict(
                inputs=batch["input_ids"].to(cfg.device),
                generation_config=generation_config,
                streamer=streamer,
            )

            # Run generation on a background thread so the streamer can be
            # consumed from this one as tokens arrive
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            # Yield the accumulated text so Gradio re-renders the growing reply
            all_text = ""
            for new_text in streamer:
                all_text += new_text
                yield all_text

    demo = gr.Interface(fn=generate, inputs="text", outputs="text")
    demo.launch(show_api=False, share=True)


def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
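The core of the new Gradio path is the thread-plus-streamer pattern: `model.generate` blocks until generation finishes, so it runs on a worker thread while `transformers.TextIteratorStreamer` yields decoded text chunks to the main thread. A minimal standalone sketch of that pattern, using the `gpt2` checkpoint purely as a stand-in model:

```python
# Standalone sketch of the Thread + TextIteratorStreamer pattern above.
# The `gpt2` checkpoint is an arbitrary stand-in, not what axolotl loads.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until done, so run it on a worker thread and
# consume the streamer from the main thread as tokens arrive.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```

Passing `skip_prompt=True` keeps the echoed prompt out of the stream; the commit's `generate` constructs the streamer without it, so the Gradio output begins with the prompt text.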
9 changes: 6 additions & 3 deletions src/axolotl/cli/inference.py
@@ -6,11 +6,11 @@
import fire
import transformers

from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
from axolotl.cli import do_inference, do_inference_gradio, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +21,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
    )
    parsed_cli_args.inference = True

    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
    if gradio:
        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


if __name__ == "__main__":
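`do_cli` is dispatched with `fire` (imported at the top of this file), so the new `gradio=False` keyword becomes a CLI flag: passing bare `--gradio` flips it to `True`. A small sketch of that behavior, with `demo.py` as a hypothetical file name:

```python
# demo.py -- hypothetical stand-in for axolotl's CLI entry points;
# illustrates how python-fire turns keyword defaults into flags.
import fire


def do_cli(config: str = "examples/", gradio: bool = False):
    # `python demo.py --gradio` prints: config=examples/ gradio=True
    print(f"config={config} gradio={gradio}")


if __name__ == "__main__":
    fire.Fire(do_cli)
```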
