-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- 🚧 add code for tinycd model - 📝 Document TinyCD modules - Add docstrings referencing the original TinyCD code and ArXiV paper. Updated the chabud/README.md file to mention what the layers.py, tinycd_model.py are for, and changed some emojis too. - 🔊 Log loss and metrics to the terminal & wandb properly using `CSVLogger` & `WandBLogger` - Ignore the california_*.hdf5 files while training as they don't have any burned areas - Add a Unet model for reference to compare with TinyCD - Add `batchnorm` as first layer to normalize & change `pos_weight` to 5.0 - Move the trainer outside of CLI for quick experiments - Add callback to log intermediate predictions --------- Co-authored-by: SRM <soumya@developmentseed.org> Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>
- Loading branch information
1 parent
0f9f876
commit 9e19af1
Showing
13 changed files
with
761 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from lightning.pytorch.callbacks import Callback | ||
import torch | ||
import torch.nn.functional as F | ||
import wandb | ||
|
||
|
||
class LogIntermediatePredictions(Callback): | ||
"""Visualize the model results at the end of every epoch.""" | ||
|
||
def __init__(self, logger): | ||
"""Instantiates with wandb-logger. | ||
Args: | ||
logger : wandb-logger instance. | ||
""" | ||
super().__init__() | ||
self.logger = logger | ||
|
||
def on_validation_batch_end( | ||
self, | ||
trainer, | ||
pl_module, | ||
outputs, | ||
batch, | ||
batch_idx, | ||
dataloader_idx=0, | ||
): | ||
"""Called when the validation batch ends. | ||
At the end of each epoch, takes a sample from validation dataset & logs | ||
the image with model predictions to wandb-logger for humans to interpret | ||
how model evolves over time. | ||
""" | ||
if batch_idx == 0: | ||
# Take a small sample size for logging | ||
id2label = {0: "ok", 1: "burn"} | ||
log_list = [] | ||
|
||
with torch.no_grad(): | ||
pre_img, post_img, mask, metadata = batch | ||
batch_size = mask.shape[0] | ||
|
||
# Pass the image through neural network model to get predicted images | ||
logits: torch.Tensor = pl_module(x1=pre_img, x2=post_img).squeeze() | ||
y_pred: torch.Tensor = F.sigmoid(logits) | ||
y_pred = (y_pred > 0.5).int().detach().cpu().numpy() | ||
|
||
for i in range(batch_size): | ||
log_image = wandb.Image( | ||
post_img[i].permute(1, 2, 0).detach().cpu().numpy() / 6000, | ||
masks={ | ||
"prediction": { | ||
"mask_data": mask[i].detach().cpu().numpy(), | ||
"class_labels": id2label, | ||
}, | ||
"ground_truth": { | ||
"mask_data": y_pred[i], | ||
"class_labels": id2label, | ||
}, | ||
}, | ||
) | ||
log_list.append(log_image) | ||
|
||
wandb.log({"predictions": log_list}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
""" | ||
Modular block layers of the TinyCD model. | ||
Reference: | ||
- https://github.com/AndreaCodegoni/Tiny_model_4_CD/blob/main/models/layers.py | ||
- Codegoni, A., Lombardi, G., & Ferrari, A. (2022). TINYCD: A (Not So) Deep | ||
Learning Model For Change Detection (arXiv:2207.13159). arXiv. | ||
https://doi.org/10.48550/arXiv.2207.13159 | ||
""" | ||
from typing import List, Optional | ||
|
||
from torch import Tensor, reshape, stack | ||
from torch.nn import Conv2d, InstanceNorm2d, Module, PReLU, Sequential, Upsample | ||
|
||
|
||
class PixelwiseLinear(Module): | ||
def __init__( | ||
self, | ||
fin: List[int], | ||
fout: List[int], | ||
last_activation: Module = None, | ||
) -> None: | ||
assert len(fout) == len(fin) | ||
super().__init__() | ||
|
||
n = len(fin) | ||
self._linears = Sequential( | ||
*[ | ||
Sequential( | ||
Conv2d(fin[i], fout[i], kernel_size=1, bias=True), | ||
PReLU() | ||
if i < n - 1 or last_activation is None | ||
else last_activation, | ||
) | ||
for i in range(n) | ||
] | ||
) | ||
|
||
def forward(self, x: Tensor) -> Tensor: | ||
# Processing the tensor: | ||
return self._linears(x) | ||
|
||
|
||
class MixingBlock(Module): | ||
def __init__( | ||
self, | ||
ch_in: int, | ||
ch_out: int, | ||
): | ||
super().__init__() | ||
self._convmix = Sequential( | ||
Conv2d(ch_in, ch_out, 3, groups=ch_out, padding=1), | ||
PReLU(), | ||
InstanceNorm2d(ch_out), | ||
) | ||
|
||
def forward(self, x: Tensor, y: Tensor) -> Tensor: | ||
# Packing the tensors and interleaving the channels: | ||
mixed = stack((x, y), dim=2) | ||
mixed = reshape(mixed, (x.shape[0], -1, x.shape[2], x.shape[3])) | ||
|
||
# Mixing: | ||
return self._convmix(mixed) | ||
|
||
|
||
class MixingMaskAttentionBlock(Module): | ||
"""use the grouped convolution to make a sort of attention""" | ||
|
||
def __init__( | ||
self, | ||
ch_in: int, | ||
ch_out: int, | ||
fin: List[int], | ||
fout: List[int], | ||
generate_masked: bool = False, | ||
): | ||
super().__init__() | ||
self._mixing = MixingBlock(ch_in, ch_out) | ||
self._linear = PixelwiseLinear(fin, fout) | ||
self._final_normalization = InstanceNorm2d(ch_out) if generate_masked else None | ||
self._mixing_out = MixingBlock(ch_in, ch_out) if generate_masked else None | ||
|
||
def forward(self, x: Tensor, y: Tensor) -> Tensor: | ||
z_mix = self._mixing(x, y) | ||
z = self._linear(z_mix) | ||
z_mix_out = 0 if self._mixing_out is None else self._mixing_out(x, y) | ||
|
||
return ( | ||
z | ||
if self._final_normalization is None | ||
else self._final_normalization(z_mix_out * z) | ||
) | ||
|
||
|
||
class UpMask(Module): | ||
def __init__( | ||
self, | ||
scale_factor: float, | ||
nin: int, | ||
nout: int, | ||
): | ||
super().__init__() | ||
self._upsample = Upsample( | ||
scale_factor=scale_factor, mode="bilinear", align_corners=True | ||
) | ||
self._convolution = Sequential( | ||
Conv2d(nin, nin, 3, 1, groups=nin, padding=1), | ||
PReLU(), | ||
InstanceNorm2d(nin), | ||
Conv2d(nin, nout, kernel_size=1, stride=1), | ||
PReLU(), | ||
InstanceNorm2d(nout), | ||
) | ||
|
||
def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor: | ||
x = self._upsample(x) | ||
if y is not None: | ||
x = x * y | ||
return self._convolution(x) |
Oops, something went wrong.