Skip to content

Commit

Permalink
Enable model_id for TCAV (#811)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #811

1. Added model_id to TCAV and removed the hard-coded DEFAULT_MODEL_ID.
2. Saves CAVs in a model_id subfolder instead of the generic CAVs directory.
3. Creates CAV-related subfolders in advance, before generating CAVs in a single- or multi-processing setup.
4. Added a test case for model_id.
5. Removed av.py from the concept package, since we now use av.py from _utils.
6. Added captum logging to TCAV.

Reviewed By: bilalsal

Differential Revision: D32815352

fbshipit-source-id: ec570450c98a66e183aac9dfd4e2b8b76299ef4f
  • Loading branch information
NarineK authored and facebook-github-bot committed Dec 6, 2021
1 parent 98f88e0 commit e5a7d65
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 241 deletions.
4 changes: 3 additions & 1 deletion captum/_utils/av.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,9 @@ def generate_activation(
]
if len(unsaved_layers) > 0:
layer_act = LayerActivation(model, layer_modules)
new_activations = layer_act.attribute(inputs, additional_forward_args)
new_activations = layer_act.attribute.__wrapped__( # type: ignore
layer_act, inputs, additional_forward_args
)
AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id)

activations: List[Union[Tensor, Tuple[Tensor, ...]]] = []
Expand Down
1 change: 0 additions & 1 deletion captum/concept/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
from captum.concept._core.av import AV # noqa
from captum.concept._core.cav import CAV # noqa
from captum.concept._core.concept import Concept, ConceptInterpreter # noqa
from captum.concept._core.tcav import TCAV # noqa
Expand Down
198 changes: 0 additions & 198 deletions captum/concept/_core/av.py

This file was deleted.

55 changes: 39 additions & 16 deletions captum/concept/_core/cav.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
layer: str,
stats: Dict[str, Any] = None,
save_path: str = "./cav/",
model_id: str = "default_model_id",
) -> None:
r"""
This class encapsulates the instances of CAVs objects, saves them in
Expand All @@ -33,31 +34,37 @@ def __init__(
names will be saved and loaded.
layer (str): The layer where concept activation vectors are
computed using a predefined classifier.
stats (dict): a dictionary that retains information about the CAV
classifier such as CAV weights and accuracies.
stats (dict, optional): a dictionary that retains information about
the CAV classifier such as CAV weights and accuracies.
Ex.: stats = {"weights": weights, "classes": classes,
"accs": accs}, where "weights" are learned
model parameters, "classes" are a list of classes used
by the model to generate the "weights" and "accs"
the classifier training or validation accuracy.
save_path (str): The path where the CAV objects are stored.
save_path (str, optional): The path where the CAV objects are stored.
model_id (str, optional): A unique model identifier associated with
this CAV instance.
"""

self.concepts = concepts
self.layer = layer
self.stats = stats
self.save_path = save_path
self.model_id = model_id

@staticmethod
def assemble_save_path(path: str, concepts: List[Concept], layer: str):
def assemble_save_path(
path: str, model_id: str, concepts: List[Concept], layer: str
) -> str:
r"""
A utility method for assembling filename and its path, from
a concept list and a layer name.
Args:
path (str): A path to be concatenated with the concepts key and
layer name.
model_id (str): A unique model identifier associated with input
`layer` and `concepts`
concepts (list(Concept)): A list of concepts that are concatenated
together and used as a concept key using their ids. These
concept ids are retrieved from TCAV s`Concept` objects.
Expand All @@ -73,13 +80,12 @@ def assemble_save_path(path: str, concepts: List[Concept], layer: str):
layer = "inception4c"
path = "/cavs",
the resulting save path will be:
"/cavs/0-1-2-inception4c.pkl"
"/cavs/default_model_id/0-1-2-inception4c.pkl"
"""

file_name = concepts_to_str(concepts) + "-" + layer + ".pkl"

return os.path.join(path, file_name)
return os.path.join(path, model_id, file_name)

def save(self):
r"""
Expand All @@ -106,15 +112,29 @@ def save(self):
"stats": self.stats,
}

if not os.path.exists(self.save_path):
os.mkdir(self.save_path)

cavs_path = CAV.assemble_save_path(self.save_path, self.concepts, self.layer)

cavs_path = CAV.assemble_save_path(
self.save_path, self.model_id, self.concepts, self.layer
)
torch.save(save_dict, cavs_path)

@staticmethod
def load(cavs_path: str, concepts: List[Concept], layer: str):
def create_cav_dir_if_missing(save_path: str, model_id: str) -> None:
    r"""
    A utility function for creating the directory where the CAVs will
    be stored. CAVs are saved in a folder named by `model_id` under
    `save_path`. Missing intermediate directories are created as well,
    and calling this function when the folder already exists is a no-op.

    Args:
        save_path (str): A root path where the CAVs will be stored
        model_id (str): A unique model identifier associated with the
                CAVs. A folder named `model_id` is created under
                `save_path`. The CAVs are later stored there.

    Raises:
        FileExistsError: If `save_path/model_id` already exists but is
                not a directory.
    """
    cav_model_id_path = os.path.join(save_path, model_id)
    # `exist_ok=True` replaces a separate exists-check followed by a create:
    # with the two-step form, another caller creating the same folder in
    # between the check and the create would make `makedirs` raise.
    os.makedirs(cav_model_id_path, exist_ok=True)

@staticmethod
def load(cavs_path: str, model_id: str, concepts: List[Concept], layer: str):
r"""
Loads CAV dictionary from a pickle file for given input
`layer` and `concepts`.
Expand All @@ -123,6 +143,9 @@ def load(cavs_path: str, concepts: List[Concept], layer: str):
cavs_path (str): The root path where the cavs are stored
in the storage (on the disk).
Ex.: "/cavs"
model_id (str): A unique model identifier associated with the
CAVs. There exist a folder named `model_id` under
`cavs_path` path. The CAVs are loaded from this folder.
concepts (list[Concept]): A List of concepts for which
we would like to load the cavs.
layer (str): The layer name. Ex.: "inception4c". In case of nested
Expand All @@ -133,10 +156,10 @@ def load(cavs_path: str, concepts: List[Concept], layer: str):
cav(CAV): An instance of a CAV class, containing the respective CAV
score per concept and layer. An example of a path where the
cavs are loaded from is:
"/cavs/0-1-2-inception4c.pkl"
"/cavs/default_model_id/0-1-2-inception4c.pkl"
"""

cavs_path = CAV.assemble_save_path(cavs_path, concepts, layer)
cavs_path = CAV.assemble_save_path(cavs_path, model_id, concepts, layer)

if os.path.exists(cavs_path):
save_dict = torch.load(cavs_path)
Expand Down
4 changes: 2 additions & 2 deletions captum/concept/_core/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def __init__(
self.data_iter = data_iter

@property
def identifier(self):
def identifier(self) -> str:
return "%s-%s" % (self.name, self.id)

def __repr__(self):
def __repr__(self) -> str:
return "Concept(%r, %r)" % (self.id, self.name)


Expand Down
Loading

0 comments on commit e5a7d65

Please sign in to comment.