Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Activations #708

Merged
merged 14 commits into from
Sep 1, 2023
60 changes: 60 additions & 0 deletions batchflow/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from ..base import BaseModel
from ...config import Config

from collections.abc import Iterable
from collections import defaultdict



LOSSES = {
Expand Down Expand Up @@ -1847,3 +1850,60 @@ def _parse_profilers(self):
columns=['ncalls', 'CPU_tottime', 'CPU_cumtime', 'CUDA_cumtime'])
self.profile_info['CPU_tottime_avg'] = self.profile_info['CPU_tottime'] / self.profile_info['ncalls']
self.profile_info['CUDA_cumtime_avg'] = self.profile_info['CUDA_cumtime'] / self.profile_info['ncalls']


# Utilities for activations
@staticmethod
def get_blocks_and_activations(model, modules=None):
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
""" Retrieve intermediate blocks of the neural network model
and corresponding activation names.
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
model : Network
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
modules : str or list of str, default None
The main parts of the model for retrieving activations from.
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
If None, all of the parts from the model.config['order'] wil be used.
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
activation_blocks : list of str
List with activation blocks which will be passed into `outputs` parameter of `self.predict`.
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
activation_names : list of str
List with names of activations corresponding to activation blocks.
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved

Examples
--------
>>> get_blocks_and_activations(model, modules=['encoder', 'embedding', 'decoder'])
activation_blocks = ['model.encoder["block-0"]', 'model.encoder["block-1"]', 'model.embedding', 'model.decoder["block-0"]']
activation_names = ['encoder_0', 'encoder_1', 'embedding_0', 'decoder_0']
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
"""
extracted_blocks = defaultdict(list)
modules = [modules] if isinstance(modules, str) else modules or model.config['order']

for module in modules:
extracted_module = getattr(model, module)
if isinstance(extracted_module, Iterable):
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
for block in extracted_module:
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
if 'block' in block:
extracted_blocks[module].append(f'model.{module}["{block}"]')
else:
extracted_blocks[module].append(f'model.{module}')

activation_blocks = []
activation_names = []
for module, blocks in extracted_blocks.items():
activation_blocks.extend(blocks)
activation_names += [f'{module}_{i}' for i in range(len(blocks))]
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved

return activation_blocks, activation_names
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you even need activation_names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you even need activation_names?

I use them to assign activations to batch, as well as to provide titles for plots.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, that is how you use them. Why should something related to your code (batch and plots), which doesn't belong here, be there?


@staticmethod
def reduce_channels(pca_instance, images, normalize=True):
""" Convert multichannel 'b c h w' images to RGB images using PCA. """
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
images = images.transpose(0, 2, 3, 1)
compressed_images = pca_instance.fit_transform(images.reshape(-1, images.shape[-1]))
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
compressed_images = compressed_images.reshape(*images.shape[:3], pca_instance.n_components)
if normalize:
compressed_images = (compressed_images - compressed_images.min()) / (compressed_images.max() - compressed_images.min())
EvgeniyS99 marked this conversation as resolved.
Show resolved Hide resolved
return compressed_images, pca_instance.explained_variance_ratio_
Loading