diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5f0885484..1dbff114e 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,7 @@ jobs: - cuda: "118" cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9514208b1..87b308362 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: @@ -80,7 +80,7 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: diff --git a/docker/Dockerfile b/docker/Dockerfile index 41915de83..81a08bc8b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets -RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ else \ diff --git a/setup.py b/setup.py index 42fd22df1..fe4d2cfad 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ """setup.py for axolotl""" +from importlib.metadata import PackageNotFoundError, version + from setuptools import find_packages, setup @@ -22,12 +24,13 @@ def parse_requirements(): # Handle standard packages _install_requires.append(line) - # TODO(wing) remove once xformers release supports torch 2.1.0 - if "torch==2.1.0" in _install_requires: - _install_requires.pop(_install_requires.index("xformers>=0.0.22")) - _install_requires.append( - "xformers @ git+https://github.com/facebookresearch/xformers.git@main" - ) + try: + torch_version = version("torch") + if torch_version.startswith("2.1.1"): + _install_requires.pop(_install_requires.index("xformers==0.0.22")) + _install_requires.append("xformers==0.0.23") + except PackageNotFoundError: + pass return _install_requires, _dependency_links