Skip to content

Commit

Permalink
Segmentation on Clay (#257)
Browse files Browse the repository at this point in the history
- Add segmentation support for Clay
- Add preprocessing script for Chesapeake Bay dataset
  • Loading branch information
srmsoumya committed Jun 5, 2024
1 parent 54ae397 commit 85e821c
Show file tree
Hide file tree
Showing 7 changed files with 936 additions and 0 deletions.
58 changes: 58 additions & 0 deletions configs/segment_chesapeake.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# lightning.pytorch==2.1.2
seed_everything: 42
data:
train_chip_dir: data/cvpr/ny/train/chips/
train_label_dir: data/cvpr/ny/train/labels/
val_chip_dir: data/cvpr/ny/val/chips/
val_label_dir: data/cvpr/ny/val/labels/
metadata_path: configs/metadata.yaml
batch_size: 40
num_workers: 8
platform: naip
model:
num_classes: 7
feature_maps:
- 3
- 5
- 7
- 11
ckpt_path: checkpoints/clay-v1-base.ckpt
lr: 1e-5
wd: 0.05
b1: 0.9
b2: 0.95
trainer:
accelerator: auto
strategy: ddp
devices: auto
num_nodes: 1
precision: bf16-mixed
log_every_n_steps: 5
max_epochs: 10
accumulate_grad_batches: 1
default_root_dir: checkpoints/segment
fast_dev_run: False
num_sanity_val_steps: 0
logger:
- class_path: lightning.pytorch.loggers.WandbLogger
init_args:
entity: developmentseed
project: clay-segment
log_model: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
dirpath: checkpoints/segment
auto_insert_metric_name: False
filename: chesapeake-7class-segment_epoch-{epoch:02d}_val-iou-{val/iou:.4f}
monitor: val/iou
mode: max
save_last: True
save_top_k: 2
save_weights_only: True
verbose: True
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
plugins:
- class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO
124 changes: 124 additions & 0 deletions finetune/segment/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Segmentor

The `Segmentor` class is designed for semantic segmentation tasks, extracting feature maps from intermediate layers of the Clay Encoder and adding a Feature Pyramid Network (FPN) on top of it.

Decoder is inspired by the Segformer paper.
Todo:
- Add neck & head for segmentation task from other papers like UperNet, PPANet, etc. to compare with other GeoAI models.


## Parameters

- `feature_maps (list)`: Indices of intermediate layers of the Clay Encoder used by FPN layers.
- `ckpt_path (str)`: Path to the Clay model checkpoint.

## Example

In this example, we will use the `Segmentor` class to segment Land Use Land Cover (LULC) classes for the Chesapeake Bay CVPR dataset. The implementation includes data preprocessing, data loading, and model training workflow using PyTorch Lightning.

## Dataset

### Citation

If you use this dataset, please cite the associated manuscript:

Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N.
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data.
Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition (CVPR 2019).

Dataset URL: [Chesapeake Bay Land Cover Dataset](https://lila.science/datasets/chesapeakelandcover)

## Setup

Follow the instructions in the [README](../../README.md) to install the required dependencies.

```bash
git clone <repo-url>
cd model
mamba env create --file environment.yml
mamba activate claymodel
```

## Usage

### Preparing the Dataset

Download the Chesapeake Bay Land Cover dataset and organize your dataset directory as recommended.

1. Copy `*_lc.tif` and `*_naip-new.tif` files for segmentation downstream tasks using s5cmd:
```bash
# train
s5cmd cp --include "*_lc.tif" --include "*_naip-new.tif" "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" data/cvpr/files/train/

# val
s5cmd cp --include "*_lc.tif" --include "*_naip-new.tif" "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" data/cvpr/files/val/
```

2. Create chips of size `224 x 224` to feed them to the model:
```bash
python finetune/segment/preprocess_data.py data/cvpr/files data/cvpr/ny 224
```

Directory structure:
```
data/
└── cvpr/
├── files/
│ ├── train/
│ └── val/
└── ny/
├── train/
│ ├── chips/
│ └── labels/
└── val/
├── chips/
└── labels/
```
### Training the Model
The model can be run via LightningCLI using configurations in `finetune/segment/configs/segment_chesapeake.yaml`.
1. Download the Clay model checkpoint from [Huggingface model hub](https://huggingface.co/made-with-clay/Clay/blob/main/clay-v1-base.ckpt) and save it in the `checkpoints/` directory.
2. Modify the batch size, learning rate, and other hyperparameters in the configuration file as needed:
```yaml
data:
batch_size: 40
num_workers: 8
model:
num_classes: 7
feature_maps:
- 3
- 5
- 7
- 11
ckpt_path: checkpoints/clay-v1-base.ckpt
lr: 1e-5
wd: 0.05
b1: 0.9
b2: 0.95
```

3. Update the [WandB logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger) configuration in the configuration file with your WandB details or use [CSV Logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.CSVLogger.html#lightning.pytorch.loggers.CSVLogger) if you don't want to log to WandB:
```yaml
logger:
- class_path: lightning.pytorch.loggers.WandbLogger
init_args:
entity: <wandb-entity>
project: <wandb-project>
log_model: false
```
4. Train the model:
```bash
python segment.py fit --config configs/segment_chesapeake.yaml
```

## Acknowledgments

Decoder implementation is inspired by the Segformer paper:
```
Segformer: Simple and Efficient Design for Semantic Segmentation with Transformers
Paper URL: https://arxiv.org/abs/2105.15203
```
187 changes: 187 additions & 0 deletions finetune/segment/chesapeake_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""
DataModule for the Chesapeake Bay dataset for segmentation tasks.
This implementation provides a structured way to handle the data loading and
preprocessing required for training and validating a segmentation model.
Dataset citation:
Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N.
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data.
Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition
(CVPR 2019).
Dataset URL: https://lila.science/datasets/chesapeakelandcover
"""

import re
from pathlib import Path

import lightning as L
import numpy as np
import torch
import yaml
from box import Box
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2


class ChesapeakeDataset(Dataset):
"""
Dataset class for the Chesapeake Bay segmentation dataset.
Args:
chip_dir (str): Directory containing the image chips.
label_dir (str): Directory containing the labels.
metadata (Box): Metadata for normalization and other dataset-specific details.
platform (str): Platform identifier used in metadata.
"""

def __init__(self, chip_dir, label_dir, metadata, platform):
self.chip_dir = Path(chip_dir)
self.label_dir = Path(label_dir)
self.metadata = metadata
self.transform = self.create_transforms(
mean=list(metadata[platform].bands.mean.values()),
std=list(metadata[platform].bands.std.values()),
)

# Load chip and label file names
self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")]
self.labels = [re.sub("_naip-new_", "_lc_", chip) for chip in self.chips]

def create_transforms(self, mean, std):
"""
Create normalization transforms.
Args:
mean (list): Mean values for normalization.
std (list): Standard deviation values for normalization.
Returns:
torchvision.transforms.Compose: A composition of transforms.
"""
return v2.Compose(
[
v2.Normalize(mean=mean, std=std),
],
)

def __len__(self):
return len(self.chips)

def __getitem__(self, idx):
"""
Get a sample from the dataset.
Args:
idx (int): Index of the sample.
Returns:
dict: A dictionary containing the image, label, and additional information.
"""
chip_name = self.chip_dir / self.chips[idx]
label_name = self.label_dir / self.labels[idx]

chip = np.load(chip_name).astype(np.float32)
label = np.load(label_name)

# Remap labels to match desired classes
label_mapping = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 15: 6}
remapped_label = np.vectorize(label_mapping.get)(label)

# Apply transformations
if self.transform:
chip = self.transform(torch.from_numpy(chip))

sample = {
"pixels": self.transform(torch.from_numpy(chip)),
"label": torch.from_numpy(remapped_label[0]),
"time": torch.zeros(4), # Placeholder for time information
"latlon": torch.zeros(4), # Placeholder for latlon information
}
return sample


class ChesapeakeDataModule(L.LightningDataModule):
"""
DataModule class for the Chesapeake Bay dataset.
Args:
train_chip_dir (str): Directory containing training image chips.
train_label_dir (str): Directory containing training labels.
val_chip_dir (str): Directory containing validation image chips.
val_label_dir (str): Directory containing validation labels.
metadata_path (str): Path to the metadata file.
batch_size (int): Batch size for data loading.
num_workers (int): Number of workers for data loading.
platform (str): Platform identifier used in metadata.
"""

def __init__( # noqa: PLR0913
self,
train_chip_dir,
train_label_dir,
val_chip_dir,
val_label_dir,
metadata_path,
batch_size,
num_workers,
platform,
):
super().__init__()
self.train_chip_dir = train_chip_dir
self.train_label_dir = train_label_dir
self.val_chip_dir = val_chip_dir
self.val_label_dir = val_label_dir
self.metadata = Box(yaml.safe_load(open(metadata_path)))
self.batch_size = batch_size
self.num_workers = num_workers
self.platform = platform

def setup(self, stage=None):
"""
Setup datasets for training and validation.
Args:
stage (str): Stage identifier ('fit' or 'test').
"""
if stage in {"fit", None}:
self.trn_ds = ChesapeakeDataset(
self.train_chip_dir,
self.train_label_dir,
self.metadata,
self.platform,
)
self.val_ds = ChesapeakeDataset(
self.val_chip_dir,
self.val_label_dir,
self.metadata,
self.platform,
)

def train_dataloader(self):
"""
Create DataLoader for training data.
Returns:
DataLoader: DataLoader for training dataset.
"""
return DataLoader(
self.trn_ds,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)

def val_dataloader(self):
"""
Create DataLoader for validation data.
Returns:
DataLoader: DataLoader for validation dataset.
"""
return DataLoader(
self.val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
)
Loading

0 comments on commit 85e821c

Please sign in to comment.