CI add inference test for llama-2-7b-chat-hf (#144)
Signed-off-by: jiafu zhang <jiafu.zhang@intel.com>
jiafuzha committed Aug 21, 2023
1 parent 0803a01 commit 2f61886
Showing 4 changed files with 80 additions and 5 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml
@@ -0,0 +1,39 @@
name: Chatbot inference on llama-2-7b-chat-hf

on:
  workflow_call:

# If there is a new commit, the previous jobs will be canceled
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  inference:
    name: inference test
    runs-on: lms-lab
    steps:
      - name: Checkout
        uses: actions/checkout@v2

      - name: Build Docker Image
        run: docker build ./ --target cpu --build-arg http_proxy="$HTTP_PROXY_IMAGE_BUILD" --build-arg https_proxy="$HTTPS_PROXY_IMAGE_BUILD" -f workflows/chatbot/inference/docker/Dockerfile -t chatbotinfer:latest && yes | docker container prune && yes | docker image prune

      - name: Start Docker Container
        run: |
          cid=$(docker ps -q --filter "name=chatbotinfer")
          if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid; fi
          docker run -tid -v /mnt/DP_disk1/huggingface/cache/:/root/.cache/huggingface/hub -v .:/root/chatbot -e http_proxy="$HTTP_PROXY_CONTAINER_RUN" -e https_proxy="$HTTPS_PROXY_CONTAINER_RUN" --name="chatbotinfer" --hostname="chatbotinfer-container" chatbotinfer:latest
      - name: Run Inference Test
        run: |
          docker exec "chatbotinfer" bash -c "cd /root/chatbot && source activate && conda activate chatbot-demo; python workflows/chatbot/inference/generate.py --base_model_path \"meta-llama/Llama-2-7b-chat-hf\" --hf_access_token \"$HF_ACCESS_TOKEN\" --instructions \"Transform the following sentence into one that shows contrast. The tree is rotten.\" "
      - name: Stop Container
        if: success() || failure()
        run: |
          cid=$(docker ps -q --filter "name=chatbotinfer")
          if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid; fi
      - name: Test Summary
        run: echo "Inference completed successfully"
27 changes: 27 additions & 0 deletions .github/workflows/chatbot-test.yml
@@ -0,0 +1,27 @@
name: Chat Bot Test

on:
  pull_request:
    branches:
      - main
    paths:
      - './requirements.txt'
      - '.github/workflows/chatbot-test.yml'
      - '.github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml'
      - 'intel_extension_for_transformers/**'
      - 'workflows/chatbot/inference/**'
      - 'workflows/dlsa/**'
      - 'workflows/hf_finetuning_and_inference_nlp/**'

  workflow_dispatch:

# If there is a new commit, the previous jobs will be canceled
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:

  call-inference-llama-2-7b-chat-hf:
    uses: ./.github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml

3 changes: 1 addition & 2 deletions workflows/chatbot/inference/docker/Dockerfile
@@ -65,11 +65,10 @@ RUN source activate && conda activate chatbot-demo && \
    conda install astunparse ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses -y && \
    conda install jemalloc gperftools -c conda-forge -y && \
    conda install pytorch torchvision torchaudio cpuonly -c pytorch && \
-   pip install farm-haystack==1.14.0 && \
    pip install intel_extension_for_pytorch && \
    pip install optimum-intel && \
    pip install transformers diffusers accelerate SentencePiece peft evaluate nltk datasets && \
-   pip install fastapi uvicorn sse_starlette bottle gevent pymysql && \
+   pip install uvicorn sse_starlette bottle gevent pymysql && \
    pip install schema && \
    pip install datasets torch transformers sentencepiece peft evaluate nltk rouge_score && \
    cd /root/chatbot && git clone https://github.com/intel/intel-extension-for-transformers.git \
16 changes: 13 additions & 3 deletions workflows/chatbot/inference/generate.py
@@ -104,6 +104,12 @@ def parse_args():
        default=128,
        help="The maximum number of new tokens to generate.",
    )
+    parser.add_argument(
+        "--hf_access_token",
+        type=str,
+        default=None,
+        help="Huggingface token to access model",
+    )
    parser.add_argument(
        "--num_beams",
        type=int,
@@ -331,6 +337,7 @@ def load_model(
    use_cache=True,
    peft_path=None,
    use_deepspeed=False,
+    hf_access_token=None,
):
    """
    Load the model and initialize the tokenizer.
@@ -363,11 +370,12 @@
        tokenizer_name,
        use_fast=False if (re.search("llama", model_name, re.IGNORECASE)
            or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)) else True,
+        token=hf_access_token,
    )
    if re.search("flan-t5", model_name, re.IGNORECASE):
        with smart_context_manager(use_deepspeed=use_deepspeed):
            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name, low_cpu_mem_usage=True
+                model_name, low_cpu_mem_usage=True, token=hf_access_token
            )
    elif (re.search("mpt", model_name, re.IGNORECASE)
        or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)):
@@ -380,6 +388,7 @@
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                torchscript=cpu_jit,
+                token=hf_access_token,
            )
    elif (
        re.search("gpt", model_name, re.IGNORECASE)
@@ -390,7 +399,7 @@
    ):
        with smart_context_manager(use_deepspeed=use_deepspeed):
            model = AutoModelForCausalLM.from_pretrained(
-                model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+                model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, token=hf_access_token
            )
    else:
        raise ValueError(
@@ -468,7 +477,7 @@ def load_model(
        from models.mpt.mpt_trace import jit_trace_mpt_7b, MPTTSModelForCausalLM

        model = jit_trace_mpt_7b(model)
-        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=hf_access_token)
        model = MPTTSModelForCausalLM(
            model, config, use_cache=use_cache, model_dtype=torch.bfloat16
        )
@@ -957,6 +966,7 @@ def main():
        use_cache=args.use_kv_cache,
        peft_path=args.peft_model_path,
        use_deepspeed=True if use_deepspeed and args.habana else False,
+        hf_access_token=args.hf_access_token
    )

    if args.habana:
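For reference, here is a minimal sketch of how the new --hf_access_token flag reaches the Hugging Face loaders. The names and argument handling below are illustrative only; the actual wiring lives in parse_args(), load_model(), and main() in generate.py, as shown in the diff above.

    # Illustrative sketch: the CLI token is forwarded to each from_pretrained
    # call so gated checkpoints such as meta-llama/Llama-2-7b-chat-hf can be
    # downloaded during the CI inference test.
    import argparse

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, required=True)
    parser.add_argument("--hf_access_token", type=str, default=None,
                        help="Huggingface token to access model")
    args = parser.parse_args()

    # Load tokenizer and model, passing the token through to Hugging Face Hub.
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model_path, use_fast=False, token=args.hf_access_token
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        token=args.hf_access_token,
    )

In CI the token itself is supplied through the HF_ACCESS_TOKEN variable used by the workflow's "Run Inference Test" step.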
