Support for multiple datasets #45

Merged · 20 commits · Apr 12, 2023
4 changes: 2 additions & 2 deletions docs/notebooks/cvs_DeepTICA.ipynb
@@ -141,7 +141,7 @@
 "dataset = create_timelagged_dataset(X,lag_time=1)\n",
 "\n",
 "# create datamodule\n",
-"datamodule = DictionaryDataModule(dataset,lengths=[0.8,0.2])#,random_splits=False,shuffle=False)"
+"datamodule = DictionaryDataModule(dataset,lengths=[0.8,0.2])#,random_split=False,shuffle=False)"
 ]
 },
 {
@@ -866,7 +866,7 @@
 "dataset = create_timelagged_dataset(X,t,lag_time=lag_time,logweights=logweights,progress_bar=True)\n",
 "\n",
 "# create datamodule\n",
-"datamodule = DictionaryDataModule(dataset,lengths=[0.8,0.2],random_splits=False,shuffle=False)"
+"datamodule = DictionaryDataModule(dataset,lengths=[0.8,0.2],random_split=False,shuffle=False)"
 ]
 },
 {
2 changes: 1 addition & 1 deletion docs/notebooks/cvs_TAE.ipynb
@@ -112,7 +112,7 @@
 "dataset['target'] = dataset['data_lag']\n",
 "\n",
 "# create datamodule\n",
-"datamodule = DictionaryDataModule(dataset,lengths=[0.8,0.2]) #,random_splits=False,shuffle=False)"
+"datamodule = DictionaryDataModule(dataset,lengths=[0.8,0.2]) #,random_split=False,shuffle=False)"
 ]
 },
 {
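Both notebook diffs above are the same one-keyword rename: `random_splits` becomes `random_split`. For readers updating their own notebooks, a minimal sketch of the corrected call (the `DictionaryDataModule` import path is an assumption; `dataset` is the output of `create_timelagged_dataset(...)` as in the notebooks above):

```python
from mlcvs.data import DictionaryDataModule  # import path assumed

datamodule = DictionaryDataModule(
    dataset,                 # e.g. a time-lagged DictionaryDataset
    lengths=[0.8, 0.2],      # 80/20 train/validation split
    random_split=False,      # corrected keyword (was: random_splits)
    shuffle=False,           # keep time ordering for time-lagged data
)
```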
2 changes: 1 addition & 1 deletion mlcvs/core/loss/autocorrelation.py
@@ -79,7 +79,7 @@ def autocorrelation_loss(
     weights: Optional[torch.Tensor] = None,
     invert_sign: bool = True,
 ) -> torch.Tensor:
-    """(Weighted) autocorrelation loss.
+    r"""(Weighted) autocorrelation loss.
 
     .. math::
 
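The only code change here is the `r` prefix on the docstring. It matters because the `.. math::` block contains LaTeX commands: in a regular string, Python interprets the `\f` of `\frac` as a form-feed escape before Sphinx ever sees the text. A quick self-contained illustration:

```python
plain = "L = - \frac{S_b(X)}{S_w(X)}"  # '\f' silently becomes a form-feed character
raw = r"L = - \frac{S_b(X)}{S_w(X)}"   # raw string keeps the backslash for Sphinx

print(plain == raw)          # False
print(len(plain), len(raw))  # 26 27 -- the plain string lost a character to '\f'
```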
5 changes: 3 additions & 2 deletions mlcvs/core/loss/fisher.py
@@ -24,7 +24,7 @@
 # =============================================================================
 
 class FisherDiscriminantLoss(torch.nn.Module):
-    """ Fisher's discriminant ratio.
+    r""" Fisher's discriminant ratio.
 
     .. math::
         L = - \frac{S_b(X)}{S_w(X)}
@@ -61,9 +61,10 @@ def fisher_discriminant_loss(
     labels: torch.Tensor,
     invert_sign: bool = True
 ) -> torch.Tensor:
-    """ Fisher's discriminant ratio.
+    r""" Fisher's discriminant ratio.
 
     .. math::
+
         L = - \frac{S_b(X)}{S_w(X)}
 
     Parameters
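For reference, the ratio in the docstring, L = -S_b(X)/S_w(X), is the usual between-class over within-class scatter. A generic sketch of that definition for a one-dimensional CV (an illustration of the formula only, not mlcvs's actual implementation):

```python
import torch

def fisher_ratio_sketch(x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Textbook Fisher ratio for a 1D variable: L = -S_b / S_w.
    mean = x.mean()
    s_b = x.new_zeros(())
    s_w = x.new_zeros(())
    for c in labels.unique():
        x_c = x[labels == c]
        s_b += len(x_c) * (x_c.mean() - mean) ** 2  # between-class scatter
        s_w += ((x_c - x_c.mean()) ** 2).sum()      # within-class scatter
    return -s_b / s_w

loss = fisher_ratio_sketch(torch.randn(100), torch.randint(0, 2, (100,)))
```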
233 changes: 165 additions & 68 deletions mlcvs/data/dataloader.py
@@ -1,65 +1,155 @@
-import torch
+#!/usr/bin/env python
+
+# =============================================================================
+# MODULE DOCSTRING
+# =============================================================================
+
+"""
+PyTorch Lightning DataModule object for DictionaryDatasets.
+"""
+
+__all__ = ["FastDictionaryLoader"]
+
+
+# =============================================================================
+# GLOBAL IMPORTS
+# =============================================================================
+
+from typing import Union, Sequence
+import torch
 from torch.utils.data import Subset
 from mlcvs.data import DictionaryDataset
 from mlcvs.core.transform.utils import Statistics
 
-__all__ = ["FastDictionaryLoader"]
-
+# =============================================================================
+# FAST DICTIONARY LOADER CLASS
+# =============================================================================
+
 class FastDictionaryLoader:
-    """
-    A DataLoader-like object for a set of tensors.
+    """PyTorch DataLoader for :class:`~mlcvs.data.dataset.DictionaryDataset`s.
 
-    It is much faster than TensorDataset + DataLoader because dataloader grabs individual indices of the dataset and calls cat (slow).
+    It is much faster than ``TensorDataset`` + ``DataLoader`` because ``DataLoader``
+    grabs individual indices of the dataset and calls cat (slow).
 
-    Adapted to work with dictionaries (incl. Dictionary Dataloader).
+    The class can also merge multiple :class:`~mlcvs.data.dataset.DictionaryDataset`s
+    that have different keys (see example below). The datasets must all have the
+    same number of samples.
 
     Notes
-    =====
-
-    Adapted from https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6.
+    -----
+
+    Adapted from https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6.
+
+    Examples
+    --------
+
+    >>> x = torch.arange(1,11)
+
+    A ``FastDictionaryLoader`` can be initialized from a ``dict``, a :class:`~mlcvs.data.dataset.DictionaryDataset`,
+    or a ``Subset`` wrapping a :class:`~mlcvs.data.dataset.DictionaryDataset`.
+
+    >>> # Initialize from a dictionary.
+    >>> d = {'data': x.unsqueeze(1), 'labels': x**2}
+    >>> dataloader = FastDictionaryLoader(d, batch_size=1, shuffle=False)
+    >>> dataloader.dataset_len  # number of samples
+    10
+    >>> # Print first batch.
+    >>> for batch in dataloader:
+    ...     print(batch)
+    ...     break
+    {'data': tensor([[1]]), 'labels': tensor([1])}
+
+    >>> # Initialize from a DictionaryDataset.
+    >>> dict_dataset = DictionaryDataset(d)
+    >>> dataloader = FastDictionaryLoader(dict_dataset, batch_size=2, shuffle=False)
+    >>> len(dataloader)  # Number of batches
+    5
+
+    >>> # Initialize from a PyTorch Subset object.
+    >>> train, _ = torch.utils.data.random_split(dict_dataset, [0.5, 0.5])
+    >>> dataloader = FastDictionaryLoader(train, batch_size=1, shuffle=False)
+
+    It is also possible to iterate over multiple dictionary datasets having
+    different keys for multi-task learning
+
+    >>> dataloader = FastDictionaryLoader(
+    ...     dataset=[dict_dataset, {'some_unlabeled_data': torch.arange(10)+11}],
+    ...     batch_size=1, shuffle=False,
+    ... )
+    >>> dataloader.dataset_len  # This is the number of samples in one dataset.
+    10
+    >>> # Print first batch.
+    >>> from pprint import pprint
+    >>> for batch in dataloader:
+    ...     pprint(batch)
+    ...     break
+    {'dataset0': {'data': tensor([[1]]), 'labels': tensor([1])},
+     'dataset1': {'some_unlabeled_data': tensor([11])}}
 
     """
-    def __init__(self, dataset : DictionaryDataset or dict, batch_size : int = 0, shuffle : bool = True):
-        """Initialize a FastDictionaryLoader.
+    def __init__(
+            self,
+            dataset: Union[dict, DictionaryDataset, Subset, Sequence],
+            batch_size: int = 0,
+            shuffle: bool = True,
+    ):
+        """Initialize a ``FastDictionaryLoader``.
 
         Parameters
         ----------
-        dataset : DictionaryDataset or dict
+        dataset : dict or DictionaryDataset or Subset of DictionaryDataset or list-like.
+            The dataset or a list of datasets. If a list, the datasets can have
+            different keys but they must all have the same number of samples.
         batch_size : int, optional
-            batch size, by default 0 (==single batch)
+            Batch size, by default 0 (==single batch).
         shuffle : bool, optional
-            if True, shuffle the data *in-place* whenever an
-            iterator is created out of this object, by default True
+            If ``True``, shuffle the data *in-place* whenever an
+            iterator is created out of this object, by default ``True``.
+        """
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.shuffle = shuffle
 
-        Returns
-        -------
-        FastDictionaryLoader
-            dataloader-like object
+    @property
+    def dataset(self):
+        """DictionaryDataset or list[DictionaryDataset]: The dictionary dataset(s)."""
+        return self._dataset
 
-        """
+    @dataset.setter
+    def dataset(self, dataset):
+        try:
+            self._dataset = _to_dict_dataset(dataset)
+        except ValueError:
+            # This is a sequence of datasets.
+            datasets = [_to_dict_dataset(d) for d in dataset]
 
-        # Convert to DictionaryDataset if a dict is given
-        if isinstance(dataset,dict):
-            dataset = DictionaryDataset(dataset)
-
-        # Retrieve selection if it a subset
-        if isinstance(dataset,Subset):
-            if isinstance(dataset.dataset,DictionaryDataset):
-                dataset = DictionaryDataset(dataset.dataset[dataset.indices])
+            # Check that all datasets have the same number of samples.
+            if len(set([len(d) for d in datasets])) != 1:
+                raise ValueError('All the datasets must have the same number of samples.')
 
-        # Save parameters
-        self.dataset = dataset
-        self.dataset_len = len(self.dataset)
-        self.batch_size = batch_size if batch_size > 0 else self.dataset_len
-        self.shuffle = shuffle
+            self._dataset = datasets
 
-        # Calculate # batches
-        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
-        if remainder > 0:
-            n_batches += 1
-        self.n_batches = n_batches
+    @property
+    def dataset_len(self):
+        """int: Number of samples in the dataset(s)."""
+        if isinstance(self._dataset, DictionaryDataset):
+            return len(self.dataset)
+        # List of datasets.
+        return len(self.dataset[0])
+
+    @property
+    def batch_size(self):
+        """int: Batch size."""
+        return self._batch_size if self._batch_size > 0 else self.dataset_len
+
+    @batch_size.setter
+    def batch_size(self, batch_size):
+        self._batch_size = batch_size
 
     def __iter__(self):
+        # Even with multiple datasets (of the same length), we generate a single
+        # indices permutation since these datasets are normally uncorrelated.
         if self.shuffle:
             self.indices = torch.randperm(self.dataset_len)
         else:
@@ -70,18 +160,20 @@ def __iter__(self):
     def __next__(self):
         if self.i >= self.dataset_len:
             raise StopIteration
 
-        if self.indices is not None:
-            indices = self.indices[self.i:self.i+self.batch_size]
-            batch = self.dataset[indices]
-        else:
-            batch = self.dataset[self.i:self.i+self.batch_size]
+        if isinstance(self.dataset, DictionaryDataset):
+            batch = self._get_batch(self.dataset)
+        else:  # List of dict datasets.
+            batch = {}
+            for dataset_idx, dataset in enumerate(self.dataset):
+                batch[f'dataset{dataset_idx}'] = self._get_batch(dataset)
 
         self.i += self.batch_size
         return batch
 
     def __len__(self):
-        return self.n_batches
+        # Number of batches.
+        return (self.dataset_len + self.batch_size - 1) // self.batch_size
 
     @property
     def keys(self):
@@ -92,7 +184,7 @@ def __repr__(self) -> str:
         return string
 
     def get_stats(self):
-        """Compute statistics ('mean','Std','Min','Max') of the dataloader.
+        """Compute statistics ``('mean','std','min','max')`` of the dataloader.
 
         Returns
        -------
@@ -115,28 +207,33 @@ def get_stats(self):
 
         return stats
 
+    def _get_batch(self, dataset):
+        """Return the current batch from the dataset."""
+        if self.indices is not None:
+            indices = self.indices[self.i:self.i+self.batch_size]
+            batch = dataset[indices]
+        else:
+            batch = dataset[self.i:self.i+self.batch_size]
+        return batch
+
-def test_FastDictionaryLoader():
-    X = torch.arange(1,11).unsqueeze(1)
-    y = X**2
-
-    # Start from dictionary
-    d = {'data': X, 'labels': y}
-    dataloader = FastDictionaryLoader(d,batch_size=1,shuffle=False)
-    print(len(dataloader))
-    print(next(iter(dataloader)))
-
-    # or from dict dataset
-    dict_dataset = DictionaryDataset(d)
-    dataloader = FastDictionaryLoader(dict_dataset,batch_size=1,shuffle=False)
-    print(len(dataloader))
-    print(next(iter(dataloader)))
-
-    # or from subset
-    train, _ = torch.utils.data.random_split(dict_dataset, [0.5,0.5])
-    dataloader = FastDictionaryLoader(train,batch_size=1,shuffle=False)
-    print(len(dataloader))
-    print(next(iter(dataloader)))
-
-if __name__ == "__main__":
-    test_FastDictionaryLoader()
+def _to_dict_dataset(d):
+    """Convert Dict[Tensor] and Subset[DictionaryDataset] to DictionaryDataset.
+
+    An error is raised if ``d`` is of any other type.
+    """
+    # Convert to DictionaryDataset if a dict is given.
+    if isinstance(d, dict):
+        d = DictionaryDataset(d)
+    elif isinstance(d, Subset) and isinstance(d.dataset, DictionaryDataset):
+        # TODO: This might not be safe for classes that inherit from Subset or DictionaryDataset.
+        # Retrieve the selection if it is a subset.
+        d = d.dataset.__class__(d.dataset[d.indices])
+    elif not isinstance(d, DictionaryDataset):
+        raise ValueError('The data must be of type dict, DictionaryDataset or Subset[DictionaryDataset].')
+    return d
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
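To close the loop on the multi-dataset feature, a sketch of how the `'dataset{i}'`-keyed batches produced by `__next__` might be consumed in a multi-task training loop (the `FastDictionaryLoader` import path is assumed; the loss functions are hypothetical placeholders):

```python
import torch
from mlcvs.data import DictionaryDataset
from mlcvs.data.dataloader import FastDictionaryLoader  # import path assumed

x = torch.arange(1, 11)
labeled = DictionaryDataset({'data': x.unsqueeze(1), 'labels': x**2})
unlabeled = {'some_unlabeled_data': torch.arange(10) + 11}

loader = FastDictionaryLoader([labeled, unlabeled], batch_size=2, shuffle=True)
for batch in loader:
    sup = batch['dataset0']    # {'data': ..., 'labels': ...}
    unsup = batch['dataset1']  # {'some_unlabeled_data': ...}
    # loss = supervised_loss(sup) + unsupervised_loss(unsup)  # hypothetical
```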