Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jgrss/test dataset #33

Merged
merged 21 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/cultionet/data/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,11 @@ def close_edge_ends(labels_array: np.ndarray) -> np.ndarray:


def is_grid_processed(
process_path: Path, transforms: T.List[str], group_id: str, grid: T.Union[str, int], n_ts: int
process_path: Path,
transforms: T.List[str],
group_id: str,
grid: T.Union[str, int],
n_ts: int
) -> bool:
"""Checks if a grid is already processed
"""
Expand Down Expand Up @@ -538,7 +542,9 @@ def create_dataset(

# Open the projected land cover
with gw.config.update(
ref_bounds=df_latlon.total_bounds.tolist(), ref_crs=ref_crs, ref_res=ref_res
ref_bounds=df_latlon.total_bounds.tolist(),
ref_crs=ref_crs,
ref_res=ref_res
):
with gw.open(lc_path, chunks=2048) as src:
lc_labels = src.squeeze()[:labels_array.shape[0], :labels_array.shape[1]].data.compute()
Expand All @@ -560,9 +566,12 @@ def create_dataset(
props = regionprops(segments)

ldata = LabeledData(
x=xvars, y=labels_array, bdist=bdist, segments=segments, props=props
x=xvars,
y=labels_array,
bdist=bdist,
segments=segments,
props=props
)

def save_and_update(train_data: Data) -> None:
train_path = process_path / f'data_{train_data.train_id}.pt'
torch.save(train_data, train_path)
Expand Down
78 changes: 69 additions & 9 deletions src/cultionet/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,25 @@
from pathlib import Path
import random
import logging
from functools import partial

from ..errors import TensorShapeError
from ..utils.logging import set_color_logger

import numpy as np
import attr
import torch
from torch_geometric.data import Data, Dataset
import psutil
from tqdm.auto import tqdm
from joblib import Parallel, delayed, parallel_backend

ATTRVINSTANCE = attr.validators.instance_of
ATTRVIN = attr.validators.in_
ATTRVOPTIONAL = attr.validators.optional

logger = set_color_logger(__name__)


def add_dims(d: torch.Tensor) -> torch.Tensor:
    """Prepends a singleton dimension to a tensor

    Args:
        d (torch.Tensor): The tensor to expand.

    Returns:
        torch.Tensor: ``d`` with a new axis of length 1 inserted at position 0.
    """
    return torch.unsqueeze(d, dim=0)
Expand All @@ -32,14 +40,36 @@ def zscores(

z = (x - μ) / σ
"""
x = torch.cat([
(batch.x[:, :-1] - add_dims(data_means)) / add_dims(data_stds),
batch.x[:, -1][:, None]
], dim=1)
x = ((batch.x - add_dims(data_means)) / add_dims(data_stds))

return Data(x=x, **{k: getattr(batch, k) for k in batch.keys if k != 'x'})


def _check_shape(d1: tuple, d2: tuple, index: int, uid: str) -> T.Tuple[bool, int, str]:
if d1 != d2:
return False, index, uid
return True, index, uid


class TqdmParallel(Parallel):
    """A tqdm progress bar for joblib Parallel tasks

    Args:
        tqdm_kwargs (Optional[dict]): Keyword arguments passed to ``tqdm`` when the
            pool is called. Default is None, which uses the ``tqdm`` defaults.
        kwargs: Keyword arguments forwarded to ``joblib.Parallel`` (e.g., ``n_jobs``).
            When omitted, ``joblib.Parallel`` is created with its defaults.

    Reference:
        https://stackoverflow.com/questions/37804279/how-can-we-use-tqdm-in-a-parallel-execution-with-joblib
    """
    def __init__(self, tqdm_kwargs: T.Optional[dict] = None, **kwargs):
        self.tqdm_kwargs = {} if tqdm_kwargs is None else tqdm_kwargs
        # The progress bar only exists while the pool is being called
        self._pbar = None
        super().__init__(**kwargs)

    def __call__(self, *args, **kwargs):
        """Runs the parallel tasks inside a tqdm progress bar context"""
        with tqdm(**self.tqdm_kwargs) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        """Syncs the progress bar with the number of completed tasks"""
        if self._pbar is None:
            # Nothing to update if called outside of __call__
            return
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()


@attr.s
class EdgeDataset(Dataset):
"""An edge dataset
Expand All @@ -50,6 +80,8 @@ class EdgeDataset(Dataset):
data_means: T.Optional[torch.Tensor] = attr.ib(validator=ATTRVOPTIONAL(ATTRVINSTANCE(torch.Tensor)), default=None)
data_stds: T.Optional[torch.Tensor] = attr.ib(validator=ATTRVOPTIONAL(ATTRVINSTANCE(torch.Tensor)), default=None)
pattern: T.Optional[str] = attr.ib(validator=ATTRVOPTIONAL(ATTRVINSTANCE(str)), default='data*.pt')
processes: T.Optional[int] = attr.ib(validator=ATTRVOPTIONAL(ATTRVINSTANCE(int)), default=psutil.cpu_count())
threads_per_worker: T.Optional[int] = attr.ib(validator=ATTRVOPTIONAL(ATTRVINSTANCE(int)), default=1)

data_list_ = None

Expand Down Expand Up @@ -93,13 +125,41 @@ def processed_file_names(self):
"""Get a list of processed files"""
return self.data_list_

def check_dims(self, delete_mismatches: bool = False, tqdm_color: str = 'ffffff'):
    """Checks if all tensors in the dataset match in shape dimensions

    The shape of ``x`` in the first sample is the reference. Every other sample
    is compared against it.

    Args:
        delete_mismatches (Optional[bool]): Whether to delete the files of samples
            that do not match the reference instead of raising an error.
            Default is False.
        tqdm_color (Optional[str]): The progress bar color, passed to ``tqdm`` as
            ``colour``. Default is 'ffffff'.

    Raises:
        TensorShapeError: If any sample does not match the reference and
            ``delete_mismatches`` is False.
    """
    n_samples = len(self)
    if n_samples < 2:
        # With fewer than two samples there is nothing to compare
        return

    ref_dim = tuple(self[0].x.shape)
    check_partial = partial(_check_shape, ref_dim)

    def _tasks():
        for i in range(1, n_samples):
            # Fetch each sample once and reuse it for both the shape and the id
            sample = self[i]
            yield delayed(check_partial)(
                tuple(sample.x.shape), i, sample.train_id
            )

    with parallel_backend(
        backend='loky',
        n_jobs=self.processes,
        inner_max_num_threads=self.threads_per_worker
    ):
        with TqdmParallel(
            tqdm_kwargs={
                # The reference sample is not submitted as a task
                'total': n_samples - 1,
                'desc': 'Checking dimensions',
                'colour': tqdm_color
            }
        ) as pool:
            results = pool(_tasks())

    # Each result carries its own dataset index and id, so mismatches are
    # collected from the result itself. The results list starts at dataset
    # index 1 and must not be indexed by dataset position.
    mismatches = [
        (index, uid) for matched, index, uid in results if not matched
    ]
    if mismatches:
        null_indices = [index for index, __ in mismatches]
        null_ids = [str(uid) for __, uid in mismatches]
        logger.warning(','.join(null_ids))
        logger.warning(f'{len(null_indices):,d} ids did not match the reference dimensions.')

        if delete_mismatches:
            logger.warning(f'Removing {len(null_indices):,d} .pt files.')
            for index in null_indices:
                # Entries may be strings or Path objects
                Path(self.data_list_[index]).unlink()
            # NOTE(review): data_list_ still lists the removed files after this
            # call -- confirm whether the caller is expected to refresh it.
        else:
            raise TensorShapeError

def len(self):
"""Returns the dataset length"""
Expand Down
10 changes: 8 additions & 2 deletions src/cultionet/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,11 @@ class TopologyClipError(Exception):


class TensorShapeError(Exception):
    """Error signalling that tensor shapes disagree

    Args:
        message (Optional[str]): The error message. It is stored on the
            instance as ``message`` and returned by ``str()``.
    """
    def __init__(self, message: str = 'The tensor shapes do not match.') -> None:
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message
77 changes: 53 additions & 24 deletions src/cultionet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from torch_geometric import seed_everything
from torch_geometric.data import Data
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks import (
ModelCheckpoint,
LearningRateMonitor,
StochasticWeightAveraging,
ModelPruning
)
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

logging.getLogger('lightning').addHandler(logging.NullHandler())
Expand All @@ -26,6 +31,7 @@
def fit(
dataset: EdgeDataset,
ckpt_file: T.Union[str, Path],
test_dataset: T.Optional[EdgeDataset] = None,
val_frac: T.Optional[float] = 0.2,
batch_size: T.Optional[int] = 4,
accumulate_grad_batches: T.Optional[int] = 1,
Expand All @@ -40,14 +46,18 @@ def fit(
reset_model: T.Optional[bool] = False,
auto_lr_find: T.Optional[bool] = False,
device: T.Optional[str] = 'gpu',
stochastic_weight_avg: T.Optional[bool] = False,
weight_decay: T.Optional[float] = 1e-5,
precision: T.Optional[int] = 32,
stochastic_weight_averaging: T.Optional[bool] = False,
model_pruning: T.Optional[bool] = False
):
"""Fits a model

Args:
dataset (EdgeDataset): The dataset to fit on.
ckpt_file (str | Path): The checkpoint file path.
test_dataset (Optional[EdgeDataset]): A test dataset to evaluate on. If given, early stopping
will switch from the validation dataset to the test dataset.
val_frac (Optional[float]): The fraction of data to use for model validation.
batch_size (Optional[int]): The data batch size.
filters (Optional[int]): The number of initial model filters.
Expand All @@ -62,6 +72,11 @@ def fit(
an existing model.
auto_lr_find (Optional[bool]): Whether to search for an optimized learning rate.
device (Optional[str]): The device to train on. Choices are ['cpu', 'gpu'].
weight_decay (Optional[float]): The weight decay passed to the optimizer. Default is 1e-5.
precision (Optional[int]): The data precision. Default is 32.
stochastic_weight_averaging (Optional[bool]): Whether to use stochastic weight averaging.
Default is False.
model_pruning (Optional[bool]): Whether to prune the model. Default is False.
"""
ckpt_file = Path(ckpt_file)

Expand All @@ -73,6 +88,7 @@ def fit(
data_module = EdgeDataModule(
train_ds=train_ds,
val_ds=val_ds,
test_ds=test_dataset,
batch_size=batch_size,
num_workers=0,
shuffle=True
Expand All @@ -96,7 +112,7 @@ def fit(
else:
ckpt_path = None

# Callbacks
# Checkpoint
cb_train_loss = ModelCheckpoint(
dirpath=ckpt_file.parent,
filename=ckpt_file.name,
Expand All @@ -107,27 +123,34 @@ def fit(
every_n_train_steps=0,
every_n_epochs=1
)

# Validation and test loss
cb_val_loss = ModelCheckpoint(monitor='val_loss')

# Early stopping
early_stop_callback = EarlyStopping(
monitor='val_loss',
min_delta=early_stopping_min_delta,
patience=early_stopping_patience,
mode='min',
check_on_train_epoch_end=False
)

# Learning rate
lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks = [
lr_monitor,
cb_train_loss,
cb_val_loss,
early_stop_callback
]
if stochastic_weight_averaging:
callbacks.append(StochasticWeightAveraging(swa_lrs=learning_rate))
if 0 < model_pruning <= 1:
callbacks.append(
ModelPruning('l1_unstructured', amount=model_pruning)
)

trainer = pl.Trainer(
default_root_dir=str(ckpt_file.parent),
callbacks=[
lr_monitor,
cb_train_loss,
cb_val_loss,
early_stop_callback
],
callbacks=callbacks,
enable_checkpointing=True,
auto_lr_find=auto_lr_find,
auto_scale_batch_size=False,
Expand All @@ -137,25 +160,33 @@ def fit(
check_val_every_n_epoch=1,
min_epochs=5 if epochs >= 5 else epochs,
max_epochs=epochs,
precision=32,
devices=1 if device == 'gpu' else 0,
gpus=1 if device == 'gpu' else 0,
precision=precision,
devices=1 if device == 'gpu' else None,
gpus=1 if device == 'gpu' else None,
num_processes=0,
accelerator=device,
log_every_n_steps=10
log_every_n_steps=10,
profiler=None
)

if auto_lr_find:
trainer.tune(model=lit_model, datamodule=data_module)
else:
trainer.fit(model=lit_model, datamodule=data_module, ckpt_path=ckpt_path)
trainer.fit(
model=lit_model,
datamodule=data_module,
ckpt_path=ckpt_path
)
if test_dataset is not None:
trainer.test(
model=lit_model,
dataloaders=data_module.test_dataloader(),
ckpt_path='last'
)


def load_model(
num_features: int,
num_time_features: int,
ckpt_file: T.Union[str, Path],
filters: T.Optional[int] = 32,
device: T.Union[str, bytes] = 'gpu',
lit_model: T.Optional[CultioLitModel] = None,
enable_progress_bar: T.Optional[bool] = True
Expand All @@ -164,9 +195,7 @@ def load_model(

Args:
ckpt_file (str | Path): The model checkpoint file.
filters (int): The model input filters.
device (str): The device to apply inference on.
trainer (pl.Trainer): The `pytorch_lightning` trainer.
lit_model (CultioLitModel): A model to predict with. If `None`, the model
is loaded from file.
enable_progress_bar (Optional[bool]): Whether to use the progress bar.
Expand All @@ -176,8 +205,8 @@ def load_model(
trainer_kwargs = dict(
default_root_dir=str(ckpt_file.parent),
precision=32,
devices=1 if device == 'gpu' else 0,
gpus=1 if device == 'gpu' else 0,
devices=1 if device == 'gpu' else None,
gpus=1 if device == 'gpu' else None,
accelerator=device,
num_processes=0,
log_every_n_steps=0,
Expand Down
6 changes: 3 additions & 3 deletions src/cultionet/models/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def training_step(self, batch: Data, batch_idx: int = None):
def _shared_eval_step(self, batch: Data, batch_idx: int = None) -> dict:
loss = self.calc_loss(batch)

__, edge_ypred, class_ypred, class_ypred_r = self.predict_labels(batch)
__, edge_ypred, __, class_ypred_r = self.predict_labels(batch)

# F1-score
edge_score = self.scorer(edge_ypred, batch.y.eq(self.edge_value).long())
Expand Down Expand Up @@ -224,8 +224,8 @@ def test_step(self, batch: Data, batch_idx: int = None) -> dict:

metrics = {
'test_loss': eval_metrics['loss'],
'crop_r_loss': eval_metrics['crop_r_loss'],
'tf1': eval_metrics['class_score'],
'tef1': eval_metrics['edge_score'],
'tcf1': eval_metrics['class_score'],
'temcc': eval_metrics['emcc'],
'tcmcc': eval_metrics['cmcc']
}
Expand Down
Loading