Add unit test verifying compatibility with huggingface models (#1352)
Summary:
Pull Request resolved: #1352

Our current unit tests for LLM Attribution use mocked models that are similar to huggingface transformer models (e.g., Llama, Llama2) but may have some unexpected differences, such as [this](https://discuss.pytorch.org/t/trying-to-explain-zephyr-generative-llm/195262/3). To validate coverage and stay compatible with future changes to these models, we would like to add tests that use huggingface models directly and verify their compatibility with LLM Attribution, which will help us quickly catch any breaking changes.

So far we only test the model type `LlamaForCausalLM`.
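
The new test loads a tiny random-weight Llama checkpoint from the huggingface hub and runs LLM Attribution against it directly. A minimal sketch of that pattern (the imports, checkpoint name, and template mirror the test added below; this snippet is not part of the diff):

```python
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny random-weight Llama checkpoint, used only for interface-compatibility checks.
name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model.eval()

# Wrap the huggingface model with a perturbation-based attribution method.
llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)

# Attribute the target string to the four template values "a", "c", "d", "f".
inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_attr.attribute(inp, "m n o p q")
print(res.seq_attr.shape)  # one attribution score per template value, i.e. (4,)
```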

Reviewed By: vivekmig

Differential Revision: D62894898
DianjingLiu authored and facebook-github-bot committed Sep 19, 2024
1 parent fc910e5 commit 1be5d86
Showing 3 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions scripts/install_via_conda.sh
@@ -37,6 +37,7 @@ fi
# install other deps
conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug==2.2.2
conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress
conda install -q -y transformers

# install captum
python setup.py develop
2 changes: 2 additions & 0 deletions scripts/install_via_pip.sh
@@ -65,3 +65,5 @@ fi
if [[ $DEPLOY == true ]]; then
pip install beautifulsoup4 ipython nbconvert==5.6.1 --progress-bar off
fi

pip install transformers --progress-bar off
91 changes: 91 additions & 0 deletions tests/attr/test_llm_attr_hf_compatibility.py
@@ -0,0 +1,91 @@
#!/usr/bin/env python3


from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

HAS_HF = True
try:
    # pyre-fixme[21]: Could not find a module corresponding to import `transformers`
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
    HAS_HF = False


@parameterized_class(
    ("device", "use_cached_outputs"),
    (
        [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
        if torch.cuda.is_available()
        else [("cpu", True), ("cpu", False)]
    ),
)
class TestLLMAttrHFCompatibility(BaseTest):
    # pyre-fixme[13]: Attribute `device` is never initialized.
    device: str
    # pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized.
    use_cached_outputs: bool

    def setUp(self) -> None:
        if not HAS_HF:
            self.skipTest("transformers package not found, skipping tests")
        super().setUp()

    # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension
    @parameterized.expand(
        [
            (
                AttrClass,
                n_samples,
            )
            for AttrClass, n_samples in zip(
                (FeatureAblation, ShapleyValueSampling, ShapleyValues),  # AttrClass
                (None, 1000, None),  # n_samples
            )
        ]
    )
    def test_llm_attr_hf_compatibility(
        self,
        AttrClass: Type[PerturbationAttribution],
        n_samples: Optional[int],
    ) -> None:
        attr_kws: Dict[str, int] = {}
        if n_samples is not None:
            attr_kws["n_samples"] = n_samples

        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        llm = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )

        llm.to(self.device)
        llm.eval()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
        res = llm_attr.attribute(
            inp,
            "m n o p q",
            use_cached_outputs=self.use_cached_outputs,
            # pyre-fixme[6]: In call `LLMAttribution.attribute`,
            #  for 4th positional argument, expected
            #  `Optional[typing.Callable[..., typing.Any]]` but got `int`.
            **attr_kws,  # type: ignore
        )
        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(len(res.output_tokens), 5)
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
