Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notebooks #584

Merged
merged 6 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tutorials/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,14 @@ The datasets used in the tutorials are represented with their respective logo:
| [Stanford sentiment treebank](https://nlp.stanford.edu/sentiment/index.html) | <img width="25" alt="nlp-logo_half_size" src="https://user-images.githubusercontent.com/3244249/152540890-c8e1e37d-f0cc-4f84-80a4-2c59176cbf4c.png">|

The models used in the tutorials are available at [tutorials/models](https://github.com/dianna-ai/dianna/tree/main/tutorials/models).


## Colab
The tutorials can also be run directly in Google Colab, by clicking on the links/buttons below, or for a general demo here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/demo.ipynb).

| modality \ method | RISE | LIME | KernelSHAP |
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the table, but as a user all the Open in Colab buttons do not help me identify the notebooks. Could you change these to reflect the title or theme of the notebook?

I don't think it's possible to customize the badges (maybe https://shields.io has some). Otherwise, could you just put the link, for example:

RISE - MNIST dataset, RISE - Imagenet dataset

|-------------------|------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------|
| images | [mnist](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_mnist.ipynb), [imagenet](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_imagenet.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_images.ipynb) | [mnist](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/kernelshap_mnist.ipynb), [geometric shapes](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/kernelshap_geometric_shapes.ipynb) |
| text | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_text.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_text.ipynb) | - |
| timeseries | [weather](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_timeseries_weather.ipynb) | [weather](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_timeseries_weather.ipynb), [coffee](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_timeseries_coffee.ipynb) | - |

35 changes: 31 additions & 4 deletions tutorials/demo.ipynb

Large diffs are not rendered by default.

112 changes: 87 additions & 25 deletions tutorials/kernelshap_geometric_shapes.ipynb

Large diffs are not rendered by default.

106 changes: 82 additions & 24 deletions tutorials/kernelshap_mnist.ipynb

Large diffs are not rendered by default.

65 changes: 39 additions & 26 deletions tutorials/lime_images.ipynb

Large diffs are not rendered by default.

116 changes: 93 additions & 23 deletions tutorials/lime_text.ipynb

Large diffs are not rendered by default.

36 changes: 35 additions & 1 deletion tutorials/lime_timeseries_coffee.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,40 @@
"*NOTE*: This tutorial is still work-in-progress, the final results need to be improved by tweaking the LIME parameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Colab Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://github.com/raw/dianna-ai/dianna/main/tutorials/'\n",
" paths_to_download = ['data/coffee_train.csv', 'data/coffee_test.csv', 'models/coffee.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -442,7 +476,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.9.12"
},
"orig_nbformat": 4
},
Expand Down
36 changes: 35 additions & 1 deletion tutorials/lime_timeseries_weather.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,40 @@
"*NOTE*: This tutorial is still work-in-progress, the final results need to be improved by tweaking the LIME parameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Colab Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://github.com/raw/dianna-ai/dianna/main/tutorials/'\n",
" paths_to_download = ['models/season_prediction_model_temp_max_binary.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -1065,7 +1099,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.9.12"
},
"orig_nbformat": 4
},
Expand Down
104 changes: 69 additions & 35 deletions tutorials/rise_imagenet.ipynb

Large diffs are not rendered by default.

76 changes: 61 additions & 15 deletions tutorials/rise_mnist.ipynb

Large diffs are not rendered by default.

124 changes: 102 additions & 22 deletions tutorials/rise_text.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,35 @@
"*NOTE*: This tutorial is still work-in-progress, the final results need to be improved by tweaking the RISE parameters"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "40dc5e32",
"metadata": {},
"source": [
"#### Colab Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "236ca562",
"metadata": {},
"outputs": [],
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://github.com/raw/dianna-ai/dianna/main/tutorials/'\n",
" paths_to_download = ['data/movie_reviews_word_vectors.txt', 'models/movie_review_model.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
]
},
{
"cell_type": "markdown",
"id": "a5cf6f82-c1c7-4814-ae0f-5a1c0b8578f6",
Expand All @@ -27,10 +56,19 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "34b556d8-5337-44dc-8efe-14d1dff6f011",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
Expand All @@ -48,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "c616916c-78ef-48d0-a744-b25b37b62a3f",
"metadata": {},
"outputs": [],
Expand All @@ -71,20 +109,62 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "486540bd-2676-4dfa-bbe8-ee8aa289acd3",
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting en-core-web-sm==3.2.0\n",
" Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0-py3-none-any.whl (13.9 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.9/13.9 MB 2.2 MB/s eta 0:00:00\n",
"Requirement already satisfied: spacy<3.3.0,>=3.2.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from en-core-web-sm==3.2.0) (3.2.4)\n",
"Requirement already satisfied: click<8.1.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (8.0.4)\n",
"Requirement already satisfied: blis<0.8.0,>=0.4.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.7.7)\n",
"Requirement already satisfied: numpy>=1.15.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.21.6)\n",
"Requirement already satisfied: jinja2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.1.1)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.8 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.0.9)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.0.6)\n",
"Requirement already satisfied: packaging>=20.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (21.3)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.0.6)\n",
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.3.0)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.27.1)\n",
"Requirement already satisfied: pathy>=0.3.5 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.6.1)\n",
"Requirement already satisfied: setuptools in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (62.1.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.0.6)\n",
"Requirement already satisfied: thinc<8.1.0,>=8.0.12 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (8.0.15)\n",
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.0.7)\n",
"Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.4.3)\n",
"Requirement already satisfied: typer<0.5.0,>=0.3.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.4.1)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (4.64.0)\n",
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.0.2)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.9.1)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.8.2)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from packaging>=20.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.0.8)\n",
"Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from pathy>=0.3.5->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (5.2.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (4.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.3)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.26.9)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.0.12)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2021.10.8)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from jinja2->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.1.1)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the package via spacy.load('en_core_web_sm')\n"
]
}
],
"source": [
"# ensure the tokenizer for english is available\n",
"spacy.cli.download('en_core_web_sm')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "555842c5-3f82-4f63-93bb-696645d4b447",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -126,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "443e8a99-6fa3-4a73-9311-2fbe0251c2b1",
"metadata": {},
"outputs": [],
Expand All @@ -152,7 +232,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "7fc6ebcb-2328-4c06-ae67-c5590032eb69",
"metadata": {},
"outputs": [],
Expand All @@ -162,7 +242,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "7c0bfd7d-df1d-4981-b714-496bc16b9347",
"metadata": {},
"outputs": [
Expand All @@ -177,23 +257,23 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Explaining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:17<00:00, 1.72s/it]\n"
"Explaining: 100%|██████████| 10/10 [00:03<00:00, 2.75it/s]\n"
]
},
{
"data": {
"text/plain": [
"[('A', 0, 0.7158780014514923),\n",
" ('delectable', 1, 0.913871341049671),\n",
" ('and', 2, 0.6892129376530648),\n",
" ('intriguing', 3, 1.0620161551237106),\n",
" ('thriller', 4, 0.840078490972519),\n",
" ('filled', 5, 0.6051010835170746),\n",
" ('with', 6, 0.6926153092086315),\n",
" ('surprises', 7, 0.6697717276215553)]"
"[('A', 0, 0.5653130280971527),\n",
" ('delectable', 1, 0.8641307824850082),\n",
" ('and', 2, 0.7081780250370502),\n",
" ('intriguing', 3, 1.004394978582859),\n",
" ('thriller', 4, 0.9396217280626297),\n",
" ('filled', 5, 0.6516930902004242),\n",
" ('with', 6, 0.7476113395392894),\n",
" ('surprises', 7, 0.7425235873460769)]"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -217,14 +297,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "0136005d-a22f-43a0-80da-4ec1f283f870",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<html><body><span style=\"background:rgba(255, 0, 0, 0.54)\">A</span> <span style=\"background:rgba(255, 0, 0, 0.69)\">delectable</span> <span style=\"background:rgba(255, 0, 0, 0.52)\">and</span> <span style=\"background:rgba(255, 0, 0, 0.80)\">intriguing</span> <span style=\"background:rgba(255, 0, 0, 0.63)\">thriller</span> <span style=\"background:rgba(255, 0, 0, 0.46)\">filled</span> <span style=\"background:rgba(255, 0, 0, 0.52)\">with</span> <span style=\"background:rgba(255, 0, 0, 0.50)\">surprises</span></body></html>"
"<mark style=\"background-color: hsl(0, 100%, 72%, 0.8); line-height:1.75\">A</mark> <mark style=\"background-color: hsl(0, 100%, 57%, 0.8); line-height:1.75\">delectable</mark> <mark style=\"background-color: hsl(0, 100%, 65%, 0.8); line-height:1.75\">and</mark> <mark style=\"background-color: hsl(0, 100%, 50%, 0.8); line-height:1.75\">intriguing</mark> <mark style=\"background-color: hsl(0, 100%, 54%, 0.8); line-height:1.75\">thriller</mark> <mark style=\"background-color: hsl(0, 100%, 68%, 0.8); line-height:1.75\">filled</mark> <mark style=\"background-color: hsl(0, 100%, 63%, 0.8); line-height:1.75\">with</mark> <mark style=\"background-color: hsl(0, 100%, 64%, 0.8); line-height:1.75\">surprises</mark>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down Expand Up @@ -263,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down
Loading