Skip to content

Commit

Permalink
Model mode and inits
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyTsimfer committed Aug 1, 2024
1 parent e961d31 commit 9bb89eb
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions batchflow/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,13 @@ class TorchModel(BaseModel, ExtractionMixin, OptimalBatchSizeMixin, Visualizatio
trainable : sequence, optional
Names of model parts to train. Should be a subset of names in `order` and can be used to freeze parameters.
init_weights : callable, 'best_practice_resnet', or None
init_weights : callable, 'best_practice_resnet', tuple, sequence of them or None
Model weights initialization.
If None, then default initialization is used.
If 'best_practice_resnet', then a commonly used non-default initialization is applied.
If callable, then callable applied to each layer.
If tuple, then the first element should be of the types above, and the second is the name of the model part to apply it to.
If sequence (must be a `list`; a 2-tuple is treated as the `(init, model part)` pair above), then each element should be of the types defined above: all init functions are applied sequentially.
Examples:
Expand All @@ -212,6 +214,7 @@ def callable_init(module): # example of a callable for init
nn.kaiming_normal_(module.weight)
config = {'init_weights': callable_init}
- ``{'init_weights': ('best_practice_resnet', 'body')}`` # applies only at `body` module
# Shapes: optional
Expand Down Expand Up @@ -854,7 +857,8 @@ def build_model(self, inputs=None):
inputs = self.make_placeholder_data(to_device=True)

if 'model' not in self.config:
self.model = Network(inputs=inputs, config=self.config, device=self.device)
with torch.no_grad():
self.model = Network(inputs=inputs, config=self.config, device=self.device)
else:
self.model = self.config['model']

Expand Down Expand Up @@ -902,12 +906,20 @@ def initialize_weights(self):
# Parse model weights initialization
init_weights = init_weights if isinstance(init_weights, list) else [init_weights]

for init_weights_function in init_weights:
for init_weights_ in init_weights:
if isinstance(init_weights_, tuple) and len(init_weights_) == 2:
init_weights_function, init_weights_module = init_weights_
else:
init_weights_function, init_weights_module = init_weights_, None

if init_weights_function in {'resnet', 'classic'}:
init_weights_function = best_practice_resnet_init

# Actual weights initialization
self.model.apply(init_weights_function)
if init_weights_module is None:
self.model.apply(init_weights_function)
else:
getattr(self.model, init_weights_module).apply(init_weights_function)


# Transfer to/from device(s)
Expand Down Expand Up @@ -1010,6 +1022,10 @@ def train(self, inputs, targets, outputs=None, mode='train', lock=True, profile=
with the same keys and requested tensors as values.
lock : bool
If True, then model, loss and gradient update operations are locked, thus allowing for multithreading.
mode : None, str or callable
If None, then does nothing.
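If str, then identifies mode to put the model in: ``'train'`` (alias ``'training'``) or ``'eval'`` (aliases ``'predict'``, ``'inference'``).
If callable, then passed to ``model.apply``, i.e. called on the model and on each of its submodules.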
sync_frequency : int, bool or None
If int, then how often to apply accumulated gradients to the weights.
If True, then value from config is used.
Expand Down Expand Up @@ -1336,6 +1352,10 @@ def predict(self, inputs, targets=None, outputs=None, lock=True, microbatch_size
amp : None or bool
If None, then use amp setting from config.
If bool, then overrides the amp setting for prediction.
mode : None, str or callable
If None, then does nothing.
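If str, then identifies mode to put the model in: ``'train'`` (alias ``'training'``) or ``'eval'`` (aliases ``'predict'``, ``'inference'``).
If callable, then passed to ``model.apply``, i.e. called on the model and on each of its submodules.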
no_grad : bool
Whether to disable gradient computation during model evaluation.
transfer_from_device : bool
Expand Down Expand Up @@ -1476,11 +1496,15 @@ def __call__(self, inputs, targets=None, outputs='predictions', lock=True,

# Common utilities for train and predict
def set_model_mode(self, mode):
    """ Put the model into the requested mode.

    Parameters
    ----------
    mode : None, str or callable
        If None, then does nothing.
        If ``'train'`` or ``'training'``, then calls ``self.model.train()``.
        If ``'eval'``, ``'predict'`` or ``'inference'``, then calls ``self.model.eval()``.
        If callable, then passed to ``self.model.apply``.

    Raises
    ------
    ValueError
        If `mode` is none of the above.
    """
    if mode is None:
        return

    # Strings are matched against the known aliases; the type check comes first so that
    # unhashable values (e.g. callables without `__hash__`) never reach the set lookup.
    if isinstance(mode, str):
        if mode in {'train', 'training'}:
            self.model.train()
            return
        if mode in {'eval', 'predict', 'inference'}:
            self.model.eval()
            return
    elif callable(mode):
        self.model.apply(mode)
        return

    raise ValueError(f'Unknown model mode={mode}')

Expand Down

0 comments on commit 9bb89eb

Please sign in to comment.