Feat: Added Gradio support #812

Merged 3 commits on Nov 5, 2023
8 changes: 8 additions & 0 deletions README.md
@@ -96,6 +96,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out"

# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
```

## Installation
@@ -918,6 +922,10 @@ Pass the appropriate flag to the train command:
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True
```
-- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```

Please use `--sample_packing False` if you have it enabled and receive an error similar to the one below:

1 change: 0 additions & 1 deletion gitbook/README.md
@@ -1,2 +1 @@
# Page

1 change: 1 addition & 0 deletions requirements.txt
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
gradio
84 changes: 83 additions & 1 deletion src/axolotl/cli/__init__.py
@@ -6,8 +6,10 @@
import random
import sys
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union

import gradio as gr
import torch
import yaml

@@ -16,7 +18,7 @@
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -153,6 +155,86 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


def do_inference_gradio(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

    for token, symbol in default_tokens.items():
        # If the token isn't already specified in the config, add it
        if not (cfg.special_tokens and token in cfg.special_tokens):
            tokenizer.add_special_tokens({token: symbol})

    prompter_module = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )

    if cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

        set_model_mem_id(model, tokenizer)
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )

    model = model.to(cfg.device)

    def generate(instruction):
        if not instruction:
            return
        if prompter_module:
            # pylint: disable=stop-iteration-return
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextIteratorStreamer(tokenizer)
            generation_kwargs = {
                "inputs": batch["input_ids"].to(cfg.device),
                "generation_config": generation_config,
                "streamer": streamer,
            }

            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            all_text = ""

            for new_text in streamer:
                all_text += new_text
                yield all_text

    demo = gr.Interface(fn=generate, inputs="text", outputs="text")
    demo.launch(show_api=False, share=True)
Collaborator commented on `demo.launch(show_api=False, share=True)`:

Suggested change: `demo.queue().launch(show_api=False, share=True)`

Ran into this error: `ValueError: Queue needs to be enabled! You may get this error by either 1) passing a function that uses the yield keyword into an interface without enabling the queue or 2) defining an event that cancels another event without enabling the queue. Both can be solved by calling .queue() before .launch().` The above change seems to fix it.
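For readers unfamiliar with this Gradio behavior, here is a minimal standalone sketch (not part of the PR) of the pattern the error message describes: when the `Interface` function is a generator, the queue must be enabled before launching, which is exactly what the suggested `.queue()` call does.

```python
import gradio as gr

def stream_echo(prompt):
    # A generator fn: yields partial outputs, which Gradio streams to the UI.
    text = ""
    for word in prompt.split():
        text += word + " "
        yield text

demo = gr.Interface(fn=stream_echo, inputs="text", outputs="text")
# Without .queue(), launching this generator-backed interface raises the
# "Queue needs to be enabled!" ValueError quoted above.
demo.queue().launch(show_api=False)
```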

Contributor Author replied:
Weird that it didn't show up for me... Added it anyway, and also added a `--gradio_title` param.
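The follow-up commit with those changes is not shown in this diff. As an illustration only (the names and wiring below are assumptions, not the committed code), a `--gradio_title` option would most likely just be forwarded to the `title` keyword that `gr.Interface` already accepts:

```python
import gradio as gr

def build_demo(generate_fn, gradio_title="Axolotl Gradio Interface"):
    # Hypothetical helper: forwards a CLI-provided title to the Gradio app.
    return gr.Interface(
        fn=generate_fn,
        inputs="text",
        outputs="text",
        title=gradio_title,
    )

def echo(text):
    # Placeholder generator standing in for the real generate() closure.
    yield text

demo = build_demo(echo, gradio_title="My fine-tuned model")
demo.queue().launch(show_api=False, share=True)
```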



def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

14 changes: 11 additions & 3 deletions src/axolotl/cli/inference.py
@@ -6,11 +6,16 @@
import fire
import transformers

from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
from axolotl.cli import (
    do_inference,
    do_inference_gradio,
    load_cfg,
    print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
    )
    parsed_cli_args.inference = True

    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
    if gradio:
        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


if __name__ == "__main__":
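As a closing note on the CLI change: the file imports `fire`, so `do_cli` is presumably exposed through python-fire, which turns each keyword parameter into a command-line option; adding `gradio=False` above is what makes the new `--gradio` flag available. A minimal sketch of that mechanism (file name and paths are illustrative, not from the repo):

```python
import fire

def do_cli(config: str = "examples/", gradio: bool = False, **kwargs):
    # fire maps a bare --gradio flag on the command line to gradio=True here.
    mode = "gradio" if gradio else "terminal"
    print(f"running inference with {config} in {mode} mode")

if __name__ == "__main__":
    fire.Fire(do_cli)

# Example invocations (paths are placeholders):
#   python inference_demo.py examples/your_config.yml            -> terminal mode
#   python inference_demo.py examples/your_config.yml --gradio   -> gradio mode
```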