Merge pull request #18 from luigibonati/refactoring
Prune the tree: move general functions to base class NNCV
luigibonati authored Nov 16, 2022
2 parents adf73af + 82202c5 commit 053dedf
Showing 12 changed files with 821 additions and 812 deletions.
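In short: the generic training machinery (optimizer setup, training loop, dataset evaluation) moves from DeepLDA_CV into the NeuralNetworkCV (NNCV) base class, so CV subclasses mostly define their own loss. Below is a minimal illustrative sketch of that pattern; it is not the mlcvs code, and every name in it (BaseCV, ToyCV, the toy loss) is made up for illustration.

```python
import torch

# Illustrative sketch of the refactoring idea (NOT the actual mlcvs code):
# the base class owns the generic training loop, each CV subclass only
# supplies its own loss.
class BaseCV(torch.nn.Module):
    def __init__(self, n_in: int, n_out: int = 1):
        super().__init__()
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(n_in, 16), torch.nn.ReLU(), torch.nn.Linear(16, n_out)
        )

    def forward_nn(self, x: torch.Tensor) -> torch.Tensor:
        return self.nn(x)

    def loss_function(self, H: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # provided by each CV subclass

    def fit(self, train_loader, nepochs: int = 10):
        # generic training loop shared by all CVs
        opt = torch.optim.Adam(self.parameters())
        for _ in range(nepochs):
            for X, y in train_loader:
                loss = self.loss_function(self.forward_nn(X), y)
                opt.zero_grad()
                loss.backward()
                opt.step()

class ToyCV(BaseCV):
    def loss_function(self, H, y):
        # toy stand-in for the LDA objective: separate the two class means
        return -(H[y == 1].mean() - H[y == 0].mean()).pow(2)
```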
93 changes: 61 additions & 32 deletions docs/notebooks/2d-model_discriminant.ipynb

Large diffs are not rendered by default.

161 changes: 122 additions & 39 deletions docs/notebooks/2d-model_tica.ipynb

Large diffs are not rendered by default.

56 changes: 36 additions & 20 deletions docs/notebooks/ala2_deeplda.ipynb

Large diffs are not rendered by default.

174 changes: 131 additions & 43 deletions docs/notebooks/ala2_deeptica_multithermal.ipynb

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions mlcvs/TODO.md
@@ -34,9 +34,24 @@

### DeepLDA_CV
- [ ] add multiclass loss function (loss function)
- [ ] add possibility to pass custom loss function (add manual someth)
- [ ] change names to private members
- [ ] add dataloader option to valid_data

### earlystopping
- [X] save model.state_dict and then load

### REFACTORING
- [X] remove custom loss
- [X] prepare_dataset function
- [X] move params (e.g. train) outside
- [X] move fit to nn base
- [X] add custom_train_epoch to fit
- [x] add eval_dataset
- [x] change log to dictionary
- [x] add tests for custom_train
- [x] changed .to into set_device

### MISCELLANEA
- [ ] create dataloader from file
- [ ] evaluate dataloader function
- [ ] add option to script rather than trace jit model (see the sketch below)
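On the last item, a hedged sketch of the difference between scripting and tracing, using only standard PyTorch API; the `export_model` helper and the output file name are hypothetical, not part of mlcvs.

```python
import torch

def export_model(model: torch.nn.Module, example_input: torch.Tensor, use_script: bool = True):
    """Hypothetical helper: export either a scripted or a traced TorchScript model."""
    if use_script:
        # scripting compiles the module from its Python source and preserves control flow
        compiled = torch.jit.script(model)
    else:
        # tracing records the operations executed on one example input;
        # data-dependent branches are frozen to the traced path
        compiled = torch.jit.trace(model, example_input)
    compiled.save("model.ptc")
    return compiled
```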
223 changes: 22 additions & 201 deletions mlcvs/lda/deep_lda.py
@@ -3,11 +3,8 @@
__all__ = ["DeepLDA_CV"]

import torch
from torch.utils.data import DataLoader,TensorDataset,random_split

from .lda import LDA
from ..models import NeuralNetworkCV
from ..utils.data import FastTensorDataLoader

class DeepLDA_CV(NeuralNetworkCV):
"""
@@ -29,7 +26,7 @@ class DeepLDA_CV(NeuralNetworkCV):
"""

def __init__(self, layers, activation="relu", device = None, **kwargs):
def __init__(self, layers, device=None, activation="relu", **kwargs):
"""
Initialize a DeepLDA_CV object
@@ -47,16 +44,12 @@ def __init__(self, layers, activation="relu", device = None, **kwargs):
self.name_ = "deeplda_cv"
self.lda = LDA()

# set device
self.device_ = device

# custom loss function
self.custom_loss = None

# lorentzian regularization
self.lorentzian_reg = 0

self.lorentzian_reg = 0
self.set_regularization(0.05)

# send model to device
self.set_device(device)

def set_regularization(self, sw_reg=0.05, lorentzian_reg=None):
"""
@@ -124,198 +117,26 @@ def loss_function(self, H, y, save_params=False):
loss : torch.tensor
loss function
"""
if self.custom_loss is None:

eigvals, eigvecs = self.lda.compute_LDA(H, y, save_params)
if save_params:
self.w = eigvecs

# TODO add sum option for multiclass

# if two classes loss is equal to the single eigenvalue
if self.lda.n_classes == 2:
loss = -eigvals
# if more than two classes loss equal to the smallest of the C-1 eigenvalues
elif self.lda.n_classes > 2:
loss = -eigvals[self.lda.n_classes - 2]
else:
raise ValueError("The number of classes for LDA must be greater than 1")

if self.lorentzian_reg > 0:
loss += self.regularization_lorentzian(H)
eigvals, eigvecs = self.lda.compute_LDA(H, y, save_params)
if save_params:
self.w = eigvecs

# TODO add sum option for multiclass

# if two classes loss is equal to the single eigenvalue
if self.lda.n_classes == 2:
loss = -eigvals
# if more than two classes loss equal to the smallest of the C-1 eigenvalues
elif self.lda.n_classes > 2:
loss = -eigvals[self.lda.n_classes - 2]
else:
raise ValueError("The number of classes for LDA must be greater than 1")

else:
loss = self.custom_loss(self,H,y,save_params)
if self.lorentzian_reg > 0:
loss += self.regularization_lorentzian(H)

return loss

def set_loss_function(self, func):
"""Set custom loss function
TODO document with an example
Parameters
----------
func : function
custom loss function
"""
self.custom_loss = func
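The docstring above asks for an example; here is a hedged sketch of this (pre-refactoring) hook. It assumes only what the surrounding code shows, namely that the callable is invoked as custom_loss(model, H, y, save_params); the function name, layer sizes and toy objective are made up for illustration.

```python
def my_margin_loss(model, H, y, save_params=False):
    # toy objective on the first NN output: push the two class means apart
    # (a made-up stand-in, not the mlcvs LDA objective)
    h = H[:, 0]
    return -(h[y == 1].mean() - h[y == 0].mean()).pow(2)

cv = DeepLDA_CV(layers=[2, 10, 10, 1])  # hypothetical layer sizes
cv.set_loss_function(my_margin_loss)
```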

def train_epoch(self, loader):
"""
Auxiliary function for training an epoch.
Parameters
----------
loader: DataLoader
training set
"""
for data in loader:
# =================get data===================
X = data[0].to(self.device_)
y = data[1].to(self.device_)
# =================forward====================
H = self.forward_nn(X)
# =================lda loss===================
loss = self.loss_function(H, y, save_params=False)
# =================backprop===================
self.opt_.zero_grad()
loss.backward()
self.opt_.step()
# ===================log======================
self.epochs += 1

def fit(
self,
train_loader=None,
valid_loader=None,
X = None,
y = None,
standardize_inputs=True,
standardize_outputs=True,
batch_size=0,
nepochs=1000,
log_every=1,
info=False,
):
"""
Train Deep-LDA CVs. Takes as input a FastTensorDataLoader/standard DataLoader constructed from a TensorDataset, or a tuple of (colvar, labels) data.
Parameters
----------
train_loader: FastTensorDataLoader/DataLoader, or tuple of torch.Tensors (X: input, y: labels)
training set
valid_loader: tuple of torch.Tensors (X: input, y: labels) #TODO add dataloader option?
validation set
X: np.array or torch.Tensor, optional
input data, alternative to train_loader (default = None)
y: np.array or torch.Tensor, optional
labels (default = None)
standardize_inputs: bool
whether to standardize input data
standardize_outputs: bool
whether to standardize CVs
batch_size: int, optional
number of points per batch (default = 0, i.e. a single batch)
nepochs: int, optional
number of epochs (default = 1000)
log_every: int, optional
frequency of log (default = 1)
info: bool, optional
print debug info (default = False)
See Also
--------
loss_function
Loss functions for training Deep-LDA CVs
"""

# check optimizer
if self.opt_ is None:
self._set_default_optimizer()

# check device
if self.device_ is None:
self.device_ = next(self.nn.parameters()).device

# assert to avoid redundancy
if (train_loader is not None) and (X is not None):
raise KeyError('Only one between train_loader and X can be used.')

# create dataloader if not given
if X is not None:
if y is None:
raise KeyError('labels (y) must be given.')

if type(X) != torch.Tensor:
X = torch.Tensor(X)
if type(y) != torch.Tensor:
y = torch.Tensor(y)

dataset = TensorDataset(X,y)
train_size = int(0.9 * len(dataset))
valid_size = len(dataset) - train_size

train_data, valid_data = random_split(dataset,[train_size,valid_size])
train_loader = FastTensorDataLoader(train_data,batch_size)
valid_loader = FastTensorDataLoader(valid_data)
print('Training set:' ,len(train_data))
print('Validation set:' ,len(valid_data))

if self.lda.sw_reg == 1e-6: # default value
self.set_regularization(0.05)
print('Sw regularization:' ,self.lda.sw_reg)
print('Lorentzian reg. :' ,self.lorentzian_reg)
print('')

# standardize inputs (unravel dataset to compute average)
x_train = torch.cat([batch[0] for batch in train_loader])
if standardize_inputs:
self.standardize_inputs( x_train )

# print info
if info:
self.print_info()

# train
for ep in range(nepochs):
self.train_epoch(train_loader)

loss_train = self.evaluate_dataset(train_loader, save_params=True)
loss_valid = self.evaluate_dataset(valid_loader)
self.loss_train.append(loss_train)
self.loss_valid.append(loss_valid)

#standardize output
if standardize_outputs:
self.standardize_outputs(x_train)

# earlystopping
if self.earlystopping_ is not None:
if valid_loader is None:
raise ValueError('EarlyStopping requires validation data')
self.earlystopping_(loss_valid, model=self.state_dict() )
else:
self.set_earlystopping(patience=1e30)

# log
if ((ep + 1) % log_every == 0) or (self.earlystopping_.early_stop):
self.print_log(
{
"Epoch": ep + 1,
"Train Loss": loss_train,
"Valid Loss": loss_valid,
},
spacing=[6, 12, 12],
decimals=2,
)

# check whether to stop
if (self.earlystopping_ is not None) and (self.earlystopping_.early_stop):
self.load_state_dict( self.earlystopping_.best_model )
break
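For orientation, a hedged usage sketch of the fit interface removed here (and, per the commit, moved to the base class). It assumes that `layers` lists the widths from the number of input descriptors to the CV output and that calling the model evaluates the CV; the data shapes are invented. The 90/10 train/validation split is the one built into the code above.

```python
import numpy as np
import torch

# invented data: 1000 configurations, 10 descriptors each, two labelled states
X = np.random.rand(1000, 10)
y = np.concatenate([np.zeros(500), np.ones(500)])

cv = DeepLDA_CV(layers=[10, 30, 30, 1])       # assumed meaning of `layers`
cv.fit(X=X, y=y, nepochs=1000, log_every=50)  # 90/10 train/valid split is built in
s = cv(torch.Tensor(X))                       # assumed: calling the model evaluates the CV
```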


def evaluate_dataset(self, dataset, save_params=False, unravel_dataset = False):
"""
Evaluate loss function on dataset.
@@ -327,7 +148,7 @@ def evaluate_dataset(self, dataset, save_params=False, unravel_dataset = False):
save_params: bool
save the eigenvalues/vectors of LDA into the model
unravel_dataset: bool, optional
unravel dataset to calculate LDA loss on all dataset instead of averaging over batches
unravel dataset to calculate loss on all dataset instead of averaging over batches
Returns
-------