ROCm support #295

Open · wants to merge 20 commits into main
144 changes: 144 additions & 0 deletions .github/workflows/build_rocm.yaml
@@ -0,0 +1,144 @@
name: Build and push AMD ROCm docker image to registry
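# Builds the ROCm images on manual dispatch, on pushes/tags to main, and on PRs that touch build-relevant paths.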

on:
workflow_dispatch:
push:
branches:
- 'main'
tags:
- 'v*'
pull_request:
paths:
- ".github/workflows/build.yaml"
# - "integration-tests/**"
- "backends/**"
- "core/**"
- "router/**"
- "Cargo.lock"
- "rust-toolchain.toml"
- "Dockerfile"
branches:
- 'main'

jobs:
build-and-push-image:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
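# Builds run on CPU-only runners; ROCm GPUs are only needed at runtime, not to build the image.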
permissions:
contents: write
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v3

- name: Tailscale
uses: huggingface/tailscale-action@v1
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}

- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
with:
install: true
config-inline: |
[registry."docker.io"]
mirrors = ["registry.github-runners.huggingface.tech"]

- name: Configure sccache
uses: actions/github-script@v6
with:
script: |
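// Forward the Actions cache endpoint and token so sccache's GitHub Actions backend can reuse compile caches.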
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');

- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1

- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Login to internal Container Registry
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
registry: registry.internal.huggingface.tech

- name: Extract metadata (tags, labels) for Docker
id: meta-rocm
uses: docker/metadata-action@v4.3.0
with:
images: |
registry.internal.huggingface.tech/api-inference/text-embeddings-inference
ghcr.io/huggingface/text-embeddings-inference
flavor: |
latest=false
tags: |
type=semver,pattern=rocm-{{version}}
type=semver,pattern=rocm-{{major}}.{{minor}}
type=raw,value=rocm-latest
type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}

- name: Build and push Docker image
id: build-and-push-rocm
uses: docker/build-push-action@v4
with:
context: .
file: Dockerfile-rocm
push: ${{ github.event_name != 'pull_request' }}
platforms: 'linux/amd64'
build-args: |
SCCACHE_GHA_ENABLED=on
ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }}
ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }}
GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
tags: ${{ steps.meta-rocm.outputs.tags }}
labels: ${{ steps.meta-rocm.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max

- name: Extract metadata (tags, labels) for Docker
id: meta-rocm-grpc
uses: docker/metadata-action@v4.3.0
with:
images: |
registry.internal.huggingface.tech/api-inference/text-embeddings-inference
ghcr.io/huggingface/text-embeddings-inference
flavor: |
latest=false
tags: |
type=semver,pattern=rocm-{{version}}-grpc
type=semver,pattern=rocm-{{major}}.{{minor}}-grpc
type=raw,value=rocm-latest-grpc
type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc

- name: Build and push Docker image
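# Same build as above but targeting the grpc stage; tags get a -grpc suffix and the shared rocm build cache is reused (read-only).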
id: build-and-push-rocm-grpc
uses: docker/build-push-action@v4
with:
context: .
target: grpc
file: Dockerfile-rocm
push: ${{ github.event_name != 'pull_request' }}
platforms: 'linux/amd64'
build-args: |
SCCACHE_GHA_ENABLED=on
ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }}
ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }}
GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
tags: ${{ steps.meta-rocm-grpc.outputs.tags }}
labels: ${{ steps.meta-rocm-grpc.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
.idea
target
+__pycache__/
2 changes: 1 addition & 1 deletion Dockerfile
@@ -4,7 +4,7 @@ WORKDIR /usr/src
ENV SCCACHE=0.5.4
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

-# Donwload, configure sccache
+# Download, configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

135 changes: 135 additions & 0 deletions Dockerfile-rocm
@@ -0,0 +1,135 @@
FROM rocm/dev-ubuntu-22.04:6.0.2 AS base-builder

ENV SCCACHE=0.5.4
ENV RUSTC_WRAPPER=/usr/local/bin/sccache
ENV PATH="/root/.cargo/bin:${PATH}"

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
libssl-dev \
pkg-config \
&& rm -rf /var/lib/apt/lists/*

# Download and configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
RUN cargo install cargo-chef --locked
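# cargo-chef lets Docker cache dependency compilation: the planner stage computes a recipe, the builder "cooks" it before the sources are copied.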

FROM base-builder AS planner

WORKDIR /usr/src

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json

FROM base-builder AS builder

ARG CUDA_COMPUTE_CAP=80
ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

WORKDIR /usr/src

COPY --from=planner /usr/src/recipe.json recipe.json
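# Compile dependencies only (no project sources yet) so this layer stays cached across source changes; sccache -s prints cache statistics.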

RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
unzip \
&& rm -rf /var/lib/apt/lists/*

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

COPY proto proto

FROM builder AS http-builder

RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

FROM builder AS grpc-builder

RUN cargo build --release --bin text-embeddings-router -F python -F grpc --no-default-features && sccache -s

FROM rocm/dev-ubuntu-22.04:6.0.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
git \
python3-dev \
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
rocblas-dev \
hiprand-dev \
rocrand-dev \
&& rm -rf /var/lib/apt/lists/*


# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.1'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH

RUN curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
mamba init && \
rm ~/mambaforge.sh

# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
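# The PyTorch ROCm wheels below come from a dedicated index, separate from the CUDA wheels.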

RUN pip install torch==${PYTORCH_VERSION} --index-url https://download.pytorch.org/whl/rocm6.0

ARG DEFAULT_USE_FLASH_ATTENTION=True
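# Build and install the ROCm fork of flash-attention v2 (see backends/python/Makefile-flash-att-v2, added in this PR).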
COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2
RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm

# Install python backend
COPY backends/python/server /tei_backends/python/server
COPY backends/proto /tei_backends/proto
RUN make -C /tei_backends/python/server install

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80 \
USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION

FROM base AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

FROM base

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
21 changes: 21 additions & 0 deletions backends/python/Makefile-flash-att-v2
@@ -0,0 +1,21 @@
flash_att_v2_commit_cuda := v2.5.9.post1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
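# CUDA installs the published flash-attn wheel; ROCm builds the ROCm/flash-attention fork from a pinned commit for gfx90a/gfx942 (MI200/MI300).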

build-flash-attention-v2-cuda:
pip install -U packaging wheel
pip install flash-attn==$(flash_att_v2_commit_cuda)

install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
echo "Flash v2 installed"

build-flash-attention-v2-rocm:
if [ ! -d 'flash-attention-v2' ]; then \
pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
fi

install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
cd flash-attention-v2 && \
GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
10 changes: 8 additions & 2 deletions backends/python/server/pyproject.toml
@@ -15,12 +15,13 @@ grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
safetensors = "^0.3.2"
safetensors = "^0.4.0"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
-torch = { version = "^2.0.1" }
+torch = { version = "==2.3.1" }
+transformers = { version = "^4.39.0" }

[tool.poetry.extras]

@@ -33,6 +34,11 @@ name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu118"
priority = "explicit"

+[[tool.poetry.source]]
+name = "pytorch-gpu-src-rocm"
+url = "https://download.pytorch.org/whl/rocm6.0"
+priority = "explicit"

[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

Expand Down
12 changes: 0 additions & 12 deletions backends/python/server/requirements.txt
@@ -4,20 +4,13 @@ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -27,15 +20,10 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
3 changes: 2 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
@@ -24,6 +24,7 @@ def serve(
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-embeddings-inference.server",
+pooling_mode: Optional[str] = None,
):
# Remove default handler
logger.remove()
@@ -48,7 +49,7 @@
# Downgrade enum into str for easier management later on
dtype = None if dtype is None else dtype.value

-server.serve(model_path, dtype, uds_path)
+server.serve(model_path, dtype, uds_path, pooling_mode)


if __name__ == "__main__":