add intel xpu support for TGI
Signed-off-by: xiaolil1 <xiaoli.liu@intel.com>
Signed-off-by: ganyi <yi.gan@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi committed Jan 24, 2024
1 parent da27fbd commit 1fe5c05
Showing 13 changed files with 259 additions and 74 deletions.
72 changes: 72 additions & 0 deletions Dockerfile_intel
@@ -0,0 +1,72 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base

USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80


# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_common.txt && \
pip install ".[accelerate, peft]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# Final image
FROM base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
6 changes: 5 additions & 1 deletion server/text_generation_server/models/cache_manager.py
@@ -2,6 +2,7 @@
 import torch
 
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -24,7 +25,10 @@ def __init__(
         self.repeat_slots = repeat_slots
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = self.block_size // element_size
+        if IS_XPU_SYSTEM:
+            x = 1
+        else:
+            x = self.block_size // element_size
 
         self.kv_cache = [
             (
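For context on the x factor above: in the vLLM-style paged KV cache that TGI uses, the key cache packs x consecutive elements of the head dimension together so the CUDA kernels can issue vectorized 16-byte loads, while the XPU path keeps the keys unpacked (x = 1), presumably to match the layout the IPEX attention kernels expect. A minimal sketch of the resulting shapes, assuming the usual (num_blocks, num_heads, head_size // x, block_size, x) key layout and (num_blocks, num_heads, head_size, block_size) value layout from cache_manager.py (the allocation itself is truncated out of this diff):

import torch

BLOCK_SIZE = 16

def kv_cache_shapes(num_blocks, num_heads, head_size, dtype, is_xpu):
    # Mirrors the x arithmetic in CacheManager.__init__ shown above;
    # the key/value shapes are an assumption based on the vLLM layout.
    element_size = torch.tensor([], dtype=dtype).element_size()
    x = 1 if is_xpu else BLOCK_SIZE // element_size
    key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
    value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)
    return key_shape, value_shape

# float16 on CUDA: x = 16 // 2 = 8 -> key cache (1024, 32, 16, 16, 8)
# float16 on XPU:  x = 1           -> key cache (1024, 32, 128, 16, 1)
print(kv_cache_shapes(1024, 32, 128, torch.float16, is_xpu=False))
print(kv_cache_shapes(1024, 32, 128, torch.float16, is_xpu=True))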
28 changes: 20 additions & 8 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -30,7 +30,7 @@
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -679,7 +679,10 @@ def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
     def warmup(self, batch: FlashCausalLMBatch):
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_XPU_SYSTEM:
+            torch.xpu.empty_cache()
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -697,20 +700,29 @@ def warmup(self, batch: FlashCausalLMBatch):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        torch.cuda.synchronize(self.device)
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.synchronize(self.device)
+        elif IS_XPU_SYSTEM:
+            torch.xpu.synchronize(self.device)
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        free_memory = max(
-            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-        )
+            free_memory = max(
+                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+            )
+        elif IS_XPU_SYSTEM:
+            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
+            free_memory = int(total_gpu_memory *0.5)
+        else:
+            raise NotImplementedError("FlashModel is only available on GPU")
 
         num_blocks = (
             int(free_memory // total_cache_size)
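To make the block-budget arithmetic in warmup concrete, here is a small self-contained sketch that replays the calculation shown above with explicit inputs. The model dimensions are made up for illustration, and MEMORY_FRACTION is taken as 1.0 (in TGI it normally comes from the CUDA_MEMORY_FRACTION environment variable).

import torch

BLOCK_SIZE = 16          # tokens per KV-cache block
MEMORY_FRACTION = 1.0    # stand-in for text_generation_server.utils.dist.MEMORY_FRACTION

def num_cache_blocks(num_layers, num_kv_heads, head_size, dtype,
                     total_gpu_memory, total_free_memory, is_xpu=False):
    # Replays the free-memory / block-count arithmetic from warmup() above.
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    # Keys and values are stored separately, hence the factor 2, summed over layers.
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    if is_xpu:
        # The XPU path in this commit simply budgets half of total device memory.
        free_memory = int(total_gpu_memory * 0.5)
    else:
        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )
    return int(free_memory // total_cache_size)

# Example: a 7B-like model (32 layers, 32 KV heads, head_size 128) in float16
# on a 16 GiB device with 10 GiB free -> 1280 blocks, i.e. room for ~20k cached tokens.
print(num_cache_blocks(32, 32, 128, torch.float16,
                       total_gpu_memory=16 * 1024**3,
                       total_free_memory=10 * 1024**3))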
4 changes: 4 additions & 0 deletions server/text_generation_server/models/flash_llama.py
@@ -19,6 +19,7 @@

 tracer = trace.get_tracer(__name__)
 
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -34,6 +35,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
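The IS_XPU_SYSTEM flag is imported from text_generation_server/utils/import_utils.py, which sits in the part of the diff that is not loaded here. A plausible minimal sketch of such a detection helper is shown below; this is an assumption about its shape, not the actual file contents.

import torch

try:
    # Importing IPEX is what registers the torch.xpu device backend.
    import intel_extension_for_pytorch  # noqa: F401
except ImportError:
    pass

IS_CUDA_SYSTEM = torch.cuda.is_available() and torch.version.cuda is not None
IS_ROCM_SYSTEM = torch.cuda.is_available() and torch.version.hip is not None
IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()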
6 changes: 5 additions & 1 deletion server/text_generation_server/models/flash_mistral.py
@@ -34,6 +34,7 @@
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 
 # Adds windowing logic to FlashCausalLMBatch
@@ -302,8 +303,11 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
-            raise NotImplementedError("FlashLlama is only available on GPU")
+            raise NotImplementedError("FlashMistral is only available on GPU")
 
         tokenizer = LlamaTokenizerFast.from_pretrained(
             model_id,
5 changes: 4 additions & 1 deletion server/text_generation_server/models/flash_neox.py
@@ -14,7 +14,7 @@
     weight_files,
     Weights,
 )
-
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
@@ -31,6 +31,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
5 changes: 4 additions & 1 deletion server/text_generation_server/models/flash_rw.py
@@ -15,7 +15,7 @@
     weight_files,
     Weights,
 )
-
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
@@ -32,6 +32,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
4 changes: 4 additions & 0 deletions server/text_generation_server/models/flash_santacoder.py
@@ -18,6 +18,7 @@
     Weights,
 )
 
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
@@ -34,6 +35,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
9 changes: 8 additions & 1 deletion server/text_generation_server/utils/dist.py
@@ -57,7 +57,14 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
         options = None
 
if WORLD_SIZE == 1:
Expand Down