Skip to content

Commit

Permalink
Add functions for activations
Browse files Browse the repository at this point in the history
  • Loading branch information
EvgeniyS99 committed Jul 25, 2023
1 parent a1dfd78 commit f95ade3
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions batchflow/models/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from torch import nn
import torch.nn.functional as F

from sklearn.decomposition import PCA


def to_n_tuple(value, n):
return value if isinstance(value, (tuple, list)) else (value,) * n
Expand Down Expand Up @@ -118,3 +120,49 @@ def center_crop(inputs, target_shape, dims):
crops_ = [slice(crop_lefts[i], crop_lefts[i] + target_shape[i]) for i in range(dims)]
crops = [slice(dim) for dim in no_crop_shape] + crops_
return inputs[crops]


def get_blocks_and_activations(model):
""" Retrieve intermediate blocks of the neural netowork model
and corresponding activation names.
"""
encoder_blocks = list(filter(lambda x: 'block' in x, model.model.encoder))
decoder_blocks = list(filter(lambda x: 'block' in x, model.model.decoder))

blocks = [f'model.encoder["{block}"]' for block in encoder_blocks]
embedding = 'embedding' in model.config['order']
if embedding:
blocks += ['model.embedding']
blocks += [f'model.decoder["{block}"]' for block in decoder_blocks]

activation_names = [f'encoder_{i}' for i in range(len(encoder_blocks))]
if embedding:
activation_names += ['embedding']
activation_names += [f'decoder_{i}' for i in range(len(decoder_blocks))]

return blocks, activation_names

def compress_activations(batch, activation_names, **kwargs):
""" Apply PCA channel reduction to intermidiate activations of the neural network model
and assign compressed images to the batch's attributes
"""
for activation_name in activation_names:
activation_images = getattr(batch, activation_name).copy()
if not np.isnan(activation_images.min()):
compressed_images, var = reduce_channels(activation_images, **kwargs)
setattr(batch, activation_name, compressed_images)
else:
return None

return batch, var

def reduce_channels(images, n_components=3, **kwargs):
""" Convert multichannel 'b c h w' images from neural network model to RGB images """
_ = kwargs
images = images.transpose(0, 2, 3, 1)
pca_instance = PCA(n_components=n_components)
compressed_images = pca_instance.fit_transform(images.reshape(-1, images.shape[-1]))
compressed_images = compressed_images.reshape(*images.shape[:3], n_components)
compressed_images = (compressed_images - compressed_images.min()) / (compressed_images.max() - compressed_images.min())

return compressed_images, pca_instance.explained_variance_ratio_

0 comments on commit f95ade3

Please sign in to comment.