Merge pull request #185 from jrzaurin/flash_attention
Flash attention
jrzaurin authored Aug 4, 2023
2 parents cd1ff79 + 67439c4 commit 2ef478c
Showing 67 changed files with 7,302 additions and 1,293 deletions.
22 changes: 14 additions & 8 deletions README.md
@@ -130,26 +130,33 @@ passed through a series of ResNet blocks built with dense layers.
3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)

+Two simpler attention-based models that we call:
+
+4. **ContextAttentionMLP**: MLP with an attention mechanism "on top" that is based on
+[Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf)
+5. **SelfAttentionMLP**: MLP with an attention mechanism that is a simplified
+version of a transformer block that we refer to as "query-key self-attention".
+
The ``Tabformer`` family, i.e. Transformers for Tabular data:

-4. **TabTransformer**: details on the TabTransformer can be found in
+6. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
-5. **SAINT**: Details on SAINT can be found in
+7. **SAINT**: Details on SAINT can be found in
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
-6. **FT-Transformer**: details on the FT-Transformer can be found in
+8. **FT-Transformer**: details on the FT-Transformer can be found in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
-7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
+9. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
on the FastFormer can be found in
[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)
-8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
+10. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)

And probabilistic DL models for tabular data based on
[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424):

-9. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
-10. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model
+11. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
+12. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model

Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
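
For context, this diff only names the mechanism; a minimal sketch of what "query-key self-attention" over a set of column embeddings can look like in plain PyTorch is shown below. This is a generic illustration, not the library's actual `SelfAttentionMLP` code.

```python
# Hedged sketch: a simplified "query-key" self-attention block, assuming one
# embedding per tabular column; pytorch-widedeep's own implementation may differ.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class QueryKeySelfAttention(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, n_cols, embed_dim)
        q, k = self.q_proj(x), self.k_proj(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(x.size(-1))
        # no value projection: the simplified block just re-weights the inputs
        return F.softmax(scores, dim=-1) @ x


out = QueryKeySelfAttention(embed_dim=16)(torch.randn(32, 8, 16))  # -> (32, 8, 16)
```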
@@ -196,7 +203,6 @@ using `Wide` and `DeepDense` and default settings.
Building a wide (linear) and deep model with ``pytorch-widedeep``:

```python
-import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
```
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-1.3.1
+1.3.2
84 changes: 84 additions & 0 deletions examples/scripts/adult_census_linear_and_flash_attention.py
@@ -0,0 +1,84 @@
from time import time

from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep, TabTransformer
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor

# use_cuda = torch.cuda.is_available()

df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
target_colname = "income_label"

cat_embed_cols = []
for col in df.columns:
    # categorical: object dtype or low cardinality, excluding the target
    if (df[col].dtype == "O" or df[col].nunique() < 200) and col != target_colname:
        cat_embed_cols.append(col)

train, test = train_test_split(
    df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
)

with_cls_token = True
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, with_attention=True, with_cls_token=with_cls_token
)

X_tab_train = tab_preprocessor.fit_transform(train)
X_tab_test = tab_preprocessor.transform(test)
target = train[target_colname].values


tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
)

linear_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_linear_attention=True,
)

flash_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_flash_attention=True,
)

s_model = WideDeep(deeptabular=tab_transformer)
l_model = WideDeep(deeptabular=linear_tab_transformer)
f_model = WideDeep(deeptabular=flash_tab_transformer)

for name, model in [("standard", s_model), ("linear", l_model), ("flash", f_model)]:
    trainer = Trainer(
        model,
        objective="binary",
        metrics=[Accuracy],
    )

    s = time()
    trainer.fit(
        X_tab=X_tab_train,
        target=target,
        n_epochs=1,
        batch_size=64,
        val_split=0.2,
    )
    e = time() - s
    print(f"{name} attention time: {round(e, 3)} secs")
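
The script above only reports wall-clock time for the three attention variants. As a hedged sketch of what flash attention buys, the snippet below shows PyTorch's fused scaled-dot-product attention next to the standard formulation (assuming PyTorch >= 2.0; whether this PR routes through this exact API is not shown in the excerpt).

```python
import torch
import torch.nn.functional as F

# toy shapes: (batch, n_heads, seq_len, head_dim)
q, k, v = (torch.randn(64, 2, 15, 8) for _ in range(3))

# fused kernel: on supported GPUs this can dispatch to a flash-attention
# implementation that never materializes the (seq_len x seq_len) matrix
out = F.scaled_dot_product_attention(q, k, v)

# standard attention, for comparison: builds the full attention matrix
attn = (q @ k.transpose(-2, -1) / q.size(-1) ** 0.5).softmax(dim=-1)
out_ref = attn @ v

assert torch.allclose(out, out_ref, atol=1e-5)  # same result, different memory profile
```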
@@ -10,7 +10,7 @@

from pytorch_widedeep.datasets import load_movielens100k

-data, user, items = load_movielens100k(as_frame=True)
+data, users, items = load_movielens100k(as_frame=True)

# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:
# list_of_genres = items.columns.tolist()[-19:]
@@ -37,7 +37,7 @@
]


-# adding a column with the number of movies watched per user
+# adding a column with the number of movies watched per users
dataset = data.sort_values(["user_id", "timestamp"]).reset_index(drop=True)
dataset["one"] = 1
dataset["num_watched"] = dataset.groupby("user_id")["one"].cumsum()
@@ -61,6 +61,9 @@
)
dataset["prev_movies"] = dataset["prev_movies"].apply(lambda x: x.split())

+# Adding user feats
+dataset = dataset.merge(users, on="user_id", how="left")
+
# Adding a genre_rate as the mean of all movies rated for a given genre per
# user
dataset = dataset.merge(items[["movie_id"] + list_of_genres], on="movie_id", how="left")
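
The diff is collapsed before the computation itself, so what follows is a hypothetical sketch of one way to derive per-user genre means with pandas; the toy frame and the `*_rate` column names are assumptions, not the code hidden below the fold.

```python
import pandas as pd

# toy frame mirroring the merged dataset: one row per (user, movie) rating,
# with 0/1 genre indicator columns as in the MovieLens items table
df = pd.DataFrame(
    {
        "user_id": [1, 1, 2],
        "rating": [4, 2, 5],
        "Action": [1, 0, 1],
        "Comedy": [1, 1, 0],
    }
)

for genre in ["Action", "Comedy"]:
    # mean rating over the movies this user rated within the genre
    genre_means = (
        df[df[genre] == 1]
        .groupby("user_id")["rating"]
        .mean()
        .rename(f"{genre}_rate")
        .reset_index()
    )
    df = df.merge(genre_means, on="user_id", how="left")

print(df[["user_id", "Action_rate", "Comedy_rate"]].drop_duplicates())
```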
@@ -47,7 +47,7 @@
).to_list()
y_train = df_train.target.values.astype(int)

-df_test_user_item = df_train[["user_id", "movie_id", "rating"]]
+df_test_user_item = df_test[["user_id", "movie_id", "rating"]]
test_movies_sequences = df_test.prev_movies.apply(
lambda x: [int(el) for el in x]
).to_list()
@@ -89,7 +89,7 @@
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
-    mlp_hidden_dims=[1024, 512, 256],
+    mlp_hidden_dims=[512, 256],
mlp_activation="relu",
)

@@ -124,7 +124,7 @@
"X_text": X_test_text,
"target": y_test,
},
-    n_epochs=10,
-    batch_size=521,
+    n_epochs=2,
+    batch_size=32,
shuffle=False,
)
3 changes: 3 additions & 0 deletions mkdocs/mkdocs.yml
@@ -56,6 +56,9 @@ nav:
- 16_Self-Supervised Pre-Training pt 1: examples/16_Self_Supervised_Pretraning_pt1.ipynb
- 16_Self-Supervised Pre-Training pt 2: examples/16_Self_Supervised_Pretraning_pt2.ipynb
- 17_Using_a_huggingface_model: examples/17_Usign_a_hugging_face_model.ipynb
+- 18_feature_importance_via_attention_weights: examples/18_feature_importance_via_attention_weights.ipynb
+- 19_wide_and_deep_for_recsys_pt1: examples/19_wide_and_deep_for_recsys_pt1.ipynb
+- 19_wide_and_deep_for_recsys_pt2: examples/19_wide_and_deep_for_recsys_pt2.ipynb
- Contributing: contributing.md

theme:
48 changes: 48 additions & 0 deletions mkdocs/site/404.html
@@ -739,6 +739,12 @@
@@ -1012,6 +1018,48 @@

+<li class="md-nav__item">
+  <a href="/examples/18_feature_importance_via_attention_weights.html" class="md-nav__link">
+    18_feature_importance_via_attention_weights
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="/examples/19_wide_and_deep_for_recsys_pt1.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt1
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="/examples/19_wide_and_deep_for_recsys_pt2.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt2
+  </a>
+</li>
</nav>
</li>
52 changes: 50 additions & 2 deletions mkdocs/site/contributing.html
@@ -743,6 +743,12 @@
@@ -1016,6 +1022,48 @@

+<li class="md-nav__item">
+  <a href="examples/18_feature_importance_via_attention_weights.html" class="md-nav__link">
+    18_feature_importance_via_attention_weights
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="examples/19_wide_and_deep_for_recsys_pt1.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt1
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="examples/19_wide_and_deep_for_recsys_pt2.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt2
+  </a>
+</li>
</nav>
</li>
@@ -1095,7 +1143,7 @@ <h1>Contributing</h1>
<nav class="md-footer__inner md-grid" aria-label="Footer" >


-<a href="examples/17_Usign_a_hugging_face_model.html" class="md-footer__link md-footer__link--prev" aria-label="Previous: 17_Using_a_huggingface_model" rel="prev">
+<a href="examples/19_wide_and_deep_for_recsys_pt2.html" class="md-footer__link md-footer__link--prev" aria-label="Previous: 19_wide_and_deep_for_recsys_pt2" rel="prev">
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12Z"/></svg>
</div>
Expand All @@ -1104,7 +1152,7 @@ <h1>Contributing</h1>
<span class="md-footer__direction">
Previous
</span>
-          17_Using_a_huggingface_model
+          19_wide_and_deep_for_recsys_pt2
</div>
</div>
</a>
