Torch2 (#177) #178

Merged 7 commits on May 19, 2023.
8 changes: 6 additions & 2 deletions .github/workflows/pr-cpu.yaml
@@ -19,8 +19,12 @@ jobs:
     strategy:
       matrix:
         include:
-        - name: 'cpu'
-          container: mosaicml/pytorch:latest
+        - name: 'cpu-latest'
+          container: mosaicml/pytorch:latest  # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
           markers: 'not gpu'
           pytest_command: 'coverage run -m pytest'
+        - name: 'cpu-2.0.1'
+          container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
+          markers: 'not gpu'
+          pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
8 changes: 6 additions & 2 deletions .github/workflows/pr-gpu.yaml
@@ -19,8 +19,12 @@ jobs:
     strategy:
       matrix:
         include:
-        - name: 'gpu'
-          container: mosaicml/pytorch:latest
+        - name: 'gpu-latest'
+          container: mosaicml/pytorch:latest  # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
           markers: 'gpu'
           pytest_command: 'coverage run -m pytest'
+        - name: 'gpu-2.0.1'
+          container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
+          markers: 'gpu'
+          pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
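Both workflow diffs above partition the test suite with pytest markers: the CPU jobs pass `markers: 'not gpu'` and the GPU jobs pass `markers: 'gpu'`, presumably handed to `pytest -m`. A minimal, illustrative sketch of a marker-gated test (the tests here are hypothetical, not files from this PR, and assume the `gpu` marker is registered in the project's pytest config):

```python
# test_example.py -- illustrative only, not part of this PR.
import pytest
import torch


@pytest.mark.gpu  # collected by `pytest -m gpu`, skipped under `-m 'not gpu'`
def test_matmul_on_gpu():
    a = torch.randn(4, 4, device='cuda')
    assert torch.allclose(a @ torch.eye(4, device='cuda'), a)


def test_matmul_on_cpu():  # unmarked, so it runs in the CPU ('not gpu') jobs
    a = torch.randn(4, 4)
    assert torch.allclose(a @ torch.eye(4), a)
```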
7 changes: 4 additions & 3 deletions .github/workflows/release.yaml
@@ -32,10 +32,11 @@ jobs:
           PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)"
         fi
 
-        # Remove the xentropy-cuda-lib dependency as PyPI does not support direct installs. The
-        # error message for importing FusedCrossEntropy gives instructions on how to install if a
-        # user tries to use it without this dependency.
+        # Remove the xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not support
+        # direct installs. The error message for importing FusedCrossEntropy gives instructions
+        # on how to install if a user tries to use it without this dependency.
         sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py
+        sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py
 
         python -m pip install --upgrade build twine
         python -m build
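For context, the `sed` commands above delete direct-reference requirement strings from `setup.py`, since PyPI rejects release metadata that pins dependencies to git URLs. A hedged sketch of what such entries might look like; the git refs and the extras key are placeholders, and only the repository URLs come from the patterns above:

```python
# Hypothetical excerpt of a setup.py extras table using PEP 508 direct
# references. PyPI refuses uploads whose metadata contains these, which is
# why the release workflow strips them before `python -m build`.
extras_require = {
    'gpu': [
        'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@<SOME_REF>',
        'triton-pre-mlir@git+https://github.com/vchiley/triton.git@<SOME_REF>#subdirectory=python',
    ],
}
```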
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -1,5 +1,6 @@
 default_language_version:
   python: python3
+exclude: llmfoundry/models/layers/flash_attn_triton.py
 repos:
 - repo: https://github.com/google/yapf
   rev: v0.32.0
2 changes: 2 additions & 0 deletions README.md
@@ -76,6 +76,8 @@ Here's what you need to get started with our LLM stack:
 
 # Installation
 
+This assumes you already have PyTorch and CMake installed.
+
 To get started, clone this repo and install the requirements:
 
 <!--pytest.mark.skip-->
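The new prerequisite note names PyTorch and CMake. As an illustrative aside, not part of the README, a two-line check that the PyTorch half of that prerequisite is met:

```python
# Illustrative prerequisite check: confirm PyTorch is importable and
# report whether a CUDA device is visible.
import torch

print('torch', torch.__version__, '| cuda available:', torch.cuda.is_available())
```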
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/__init__.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from llmfoundry.models.layers import flash_attn_triton
 from llmfoundry.models.layers.attention import (
     ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
     attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
@@ -9,6 +10,7 @@
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
 
 __all__ = [
+    'flash_attn_triton',
     'scaled_multihead_dot_product_attention',
     'flash_attn_fn',
     'triton_flash_attn_fn',
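A short usage sketch of the import path this re-export enables (illustrative, not from the diff):

```python
# After this PR the Triton flash-attention kernels ship inside the package
# itself and are importable from the layers subpackage directly.
from llmfoundry.models.layers import flash_attn_triton
```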
10 changes: 6 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -207,11 +207,13 @@ def triton_flash_attn_fn(
     multiquery=False,
 ):
     try:
-        from flash_attn import flash_attn_triton  # type: ignore
+        from llmfoundry.models.layers import flash_attn_triton  # type: ignore
     except:
-        raise RuntimeError(
-            'Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202'
-        )
+        raise ValueError(
+            'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
+            'and `pip install .[gpu]` if installing from source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
+            'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
+            'Note: (1) requires you have CMake and PyTorch already installed.')
 
     check_valid_inputs(query, key, value)
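The rewritten guard fails with actionable install instructions instead of a bare RuntimeError. Calling code that would rather degrade gracefully could probe for the requirement itself; here is a minimal sketch, not part of this PR, with a hypothetical helper name:

```python
# Hypothetical helper: choose an attention implementation at runtime,
# mirroring option (2) from the error message above -- fall back to the
# slower pure-PyTorch path when the vendored Triton kernels can't import.
def pick_attn_impl() -> str:
    try:
        from llmfoundry.models.layers import flash_attn_triton  # noqa: F401
        return 'triton'
    except Exception:
        return 'torch'
```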