Generate embeddings via prediction loop (#56)
* 🍻 Generate embeddings via prediction loop

Implement the embedding generator in the LightningModule's predict_step. The embeddings are tensor arrays that are saved to a .npy file in the data/embeddings/ folder. Input data is retrieved from the predict_dataloader, which is currently using the validation datapipe rather than a dedicated datapipe. Documented on the main README.md how to generate the embedding output file using LightningCLI. Also added a unit test to ensure that saving to and loading from an embedding_0.npy file works.

* 🐛 Disable masking of patches on predict_step

Previously, 75% of the patches, or 48 out of a total of 64, were masked out, leaving 16 patches plus 1 cls_token = 17 sequences. Disabling the mask gives 64 patches + 1 cls_token = 65 sequences. Moved the assert statements with a fixed sequence_length dim from the forward function to the training_step. Also updated the unit test to ensure output embeddings have a shape of (batch_size, 65, 768).
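
As a quick sanity check, the sequence-length arithmetic above can be sketched like so (values taken from the commit description; this is illustration, not repo code):

```python
# Sequence-length arithmetic from the commit description above.
num_patches = 64
mask_ratio = 0.75  # training default: 75% of patches are masked out
visible = int(num_patches * (1 - mask_ratio))  # 16 visible patches
assert visible + 1 == 17  # plus 1 cls_token -> 17 sequences during training

mask_ratio = 0  # masking disabled on predict_step
visible = int(num_patches * (1 - mask_ratio))  # all 64 patches
assert visible + 1 == 65  # plus 1 cls_token -> 65 sequences at prediction time
```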

* ♻️ Refactor LightningDataModule to not do random split on predict

Refactored the setup method in the LightningDataModule so that no random train/validation split happens on the predict stage, i.e. it just does the GeoTIFF to torch.Tensor conversion directly, followed by batching and collating; see the toy sketch below.
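
For illustration, the predict-stage pipeline (map, then batch, then collate) behaves roughly like this toy torchdata chain, where a lambda stands in for the repo's _array_to_torch GeoTIFF conversion:

```python
import torch
import torchdata.datapipes.iter as dpi

# Toy stand-in for the predict-stage datapipe: convert -> batch -> collate.
# The lambda plays the role of the repo's _array_to_torch GeoTIFF conversion.
dp = dpi.IterableWrapper(range(4))
dp = dp.map(fn=lambda x: torch.tensor([float(x)])).batch(batch_size=2).collate()
print(list(dp))  # [tensor([[0.], [1.]]), tensor([[2.], [3.]])]
```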

* ✅ Test predict stage in geotiffdatamodule

Explicitly pass a stage argument to setup in the test_geotiffdatapipemodule unit test, testing both the fit and predict stages.

* 👔 Ensure that embeddings have no NaN values

Added an assert to make sure that the generated embeddings do not have NaN values in them.

* 🗃️ Take mean of the embeddings along sequence_length dim

Instead of saving embeddings of shape (1, 65, 768), save out embeddings of shape (1, 768). This is done by taking the mean along the sequence_length dim, excluding the cls_token part (the first of the 65 positions).
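
In torch terms, the reduction is roughly the following (toy tensor, not actual model output):

```python
import torch

# Toy version of the reduction: drop the cls_token at index 0, then average
# the remaining 64 patch embeddings along the sequence_length dimension.
embeddings_raw = torch.randn(1, 65, 768)  # (batch_size, sequence_length, hidden_size)
embeddings_mean = embeddings_raw[:, 1:, :].mean(dim=1)
assert embeddings_mean.shape == (1, 768)  # (batch_size, hidden_size)
```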
weiji14 authored Dec 4, 2023
1 parent 6f50653 commit 69ce703
Showing 5 changed files with 110 additions and 23 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -90,5 +90,9 @@ To train the model for a hundred epochs:
 
     python trainer.py fit --trainer.max_epochs=100
 
+To generate embeddings from the pretrained model's encoder on one image:
+
+    python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1 --trainer.limit_predict_batches=1
+
 More options can be found using `python trainer.py fit --help`, or at the
 [LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html).
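
The saved file can then be loaded back for a quick check, e.g. (a sketch; the path assumes the default output location under data/embeddings/ and a single predict batch):

```python
import numpy as np

# Load the embedding written by the predict command above; the path assumes
# the default trainer root dir and --trainer.limit_predict_batches=1.
embeddings = np.load(file="data/embeddings/embedding_0.npy")
print(embeddings.shape, embeddings.dtype)  # e.g. (1, 768) float32
```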
46 changes: 34 additions & 12 deletions src/datamodule.py
@@ -70,19 +70,31 @@ def setup(self, stage: str | None = None):
             root=self.data_path, masks="*.tif", recursive=True, length=423
         )
 
-        # Step 2 - Split GeoTIFF chips into train/val sets (80%/20%)
-        # https://pytorch.org/data/0.7/generated/torchdata.datapipes.iter.RandomSplitter.html
-        dp_train, dp_val = dp_paths.random_split(
-            weights={"train": 0.8, "validation": 0.2}, total_length=423, seed=42
-        )
-
-        # Step 3 - Read GeoTIFF into numpy.ndarray, batch and convert to torch.Tensor
-        self.datapipe_train = (
-            dp_train.map(fn=_array_to_torch).batch(batch_size=self.batch_size).collate()
-        )
-        self.datapipe_val = (
-            dp_val.map(fn=_array_to_torch).batch(batch_size=self.batch_size).collate()
-        )
+        if stage == "fit":  # training/validation loop
+            # Step 2 - Split GeoTIFF chips into train/val sets (80%/20%)
+            # https://pytorch.org/data/0.7/generated/torchdata.datapipes.iter.RandomSplitter.html
+            dp_train, dp_val = dp_paths.random_split(
+                weights={"train": 0.8, "validation": 0.2}, total_length=423, seed=42
+            )
+
+            # Step 3 - Read GeoTIFF into numpy array, batch and convert to torch.Tensor
+            self.datapipe_train = (
+                dp_train.map(fn=_array_to_torch)
+                .batch(batch_size=self.batch_size)
+                .collate()
+            )
+            self.datapipe_val = (
+                dp_val.map(fn=_array_to_torch)
+                .batch(batch_size=self.batch_size)
+                .collate()
+            )
+
+        elif stage == "predict":  # prediction loop
+            self.datapipe_predict = (
+                dp_paths.map(fn=_array_to_torch)
+                .batch(batch_size=self.batch_size)
+                .collate()
+            )
 
     def train_dataloader(self) -> torch.utils.data.DataLoader:
         """
@@ -103,3 +115,13 @@ def val_dataloader(self) -> torch.utils.data.DataLoader:
             batch_size=None,  # handled in datapipe already
             num_workers=self.num_workers,
         )
+
+    def predict_dataloader(self) -> torch.utils.data.DataLoader:
+        """
+        Loads the data used in the prediction loop.
+        """
+        return torch.utils.data.DataLoader(
+            dataset=self.datapipe_predict,
+            batch_size=None,  # handled in datapipe already
+            num_workers=self.num_workers,
+        )
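
Taken together, the new predict path can be exercised roughly like this (a sketch; the import path and data_path value are assumptions, not part of the diff):

```python
import torch

from src.datamodule import GeoTIFFDataPipeModule  # import path is an assumption

# Sketch: build the datamodule, set up the predict stage (no random split),
# and pull one batch of GeoTIFF chips converted to torch.Tensor.
datamodule = GeoTIFFDataPipeModule(data_path="data/", batch_size=1)
datamodule.setup(stage="predict")
batch = next(iter(datamodule.predict_dataloader()))
assert isinstance(batch, torch.Tensor)  # shape (batch_size, channels, height, width)
```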
43 changes: 40 additions & 3 deletions src/model_vit.py
@@ -4,7 +4,10 @@
 Code structure adapted from Lightning project seed at
 https://github.com/Lightning-AI/deep-learning-project-template
 """
+import os
+
 import lightning as L
+import numpy as np
 import torch
 import transformers
 
@@ -78,9 +81,6 @@ def forward(self, x: torch.Tensor) -> dict:
         outputs: dict = self.vit.base_model(x)
 
         self.B = x.shape[0]
-        assert outputs.last_hidden_state.shape == torch.Size([self.B, 17, 768])
-        assert outputs.ids_restore.shape == torch.Size([self.B, 64])
-        assert outputs.mask.shape == torch.Size([self.B, 64])
 
         return outputs
 
@@ -96,6 +96,9 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
 
         # Forward encoder
         outputs_encoder: dict = self(x)
+        assert outputs_encoder.last_hidden_state.shape == torch.Size([self.B, 17, 768])
+        assert outputs_encoder.ids_restore.shape == torch.Size([self.B, 64])
+        assert outputs_encoder.mask.shape == torch.Size([self.B, 64])
 
         # Forward decoder
         outputs_decoder: dict = self.vit.decoder.forward(
@@ -126,6 +129,40 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
         """
         pass
 
+    def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
+        """
+        Logic for the neural network's prediction loop.
+        """
+        x: torch.Tensor = batch
+        # x: torch.Tensor = torch.randn(32, 13, 256, 256)  # BCHW
+
+        # Forward encoder
+        self.vit.config.mask_ratio = 0  # disable masking
+        outputs_encoder: dict = self(x)
+
+        # Get embeddings generated from encoder
+        embeddings_raw: torch.Tensor = outputs_encoder.last_hidden_state
+        assert embeddings_raw.shape == torch.Size(
+            [self.B, 65, 768]  # (batch_size, sequence_length, hidden_size)
+        )
+        assert not torch.isnan(embeddings_raw).any()  # ensure no NaNs in embedding
+
+        # Take the mean of the embeddings along the sequence_length dimension,
+        # excluding the first cls_token embedding, i.e. compute over patch embeddings
+        embeddings_mean: torch.Tensor = embeddings_raw[:, 1:, :].mean(dim=1)
+        assert embeddings_mean.shape == torch.Size(
+            [self.B, 768]  # (batch_size, hidden_size)
+        )
+
+        # Save embeddings in npy format
+        outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
+        os.makedirs(name=outfolder, exist_ok=True)
+        outfile = f"{outfolder}/embedding_{batch_idx}.npy"
+        np.save(file=outfile, arr=embeddings_mean.cpu())
+        print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outfile}")
+
+        return embeddings_mean
+
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
         Optimizing function used to reduce the loss, so that the predicted
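
The same loop the CLI command drives can also be run from Python, roughly as follows (a sketch under the same assumptions; import paths and the checkpoint path are examples, not repo-confirmed):

```python
import lightning as L

from src.datamodule import GeoTIFFDataPipeModule  # import paths are assumptions
from src.model_vit import ViTLitModule

# Sketch: run the prediction loop programmatically; predict_step writes
# data/embeddings/embedding_{batch_idx}.npy under trainer.default_root_dir.
datamodule = GeoTIFFDataPipeModule(data_path="data/", batch_size=1)
model = ViTLitModule()
trainer = L.Trainer(accelerator="auto", devices=1, limit_predict_batches=1)
trainer.predict(model=model, datamodule=datamodule, ckpt_path="checkpoints/last.ckpt")
```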
11 changes: 10 additions & 1 deletion src/tests/test_datamodule.py
@@ -47,10 +47,19 @@ def test_geotiffdatapipemodule(geotiff_folder):
     datamodule: L.LightningDataModule = GeoTIFFDataPipeModule(
         data_path=geotiff_folder, batch_size=2
     )
-    datamodule.setup()
 
+    # Train/validation stage
+    datamodule.setup(stage="fit")
     it = iter(datamodule.train_dataloader())
     image = next(it)
 
     assert image.shape == torch.Size([2, 3, 256, 256])
     assert image.dtype == torch.float16
+
+    # Predict stage
+    datamodule.setup(stage="predict")
+    it = iter(datamodule.predict_dataloader())
+    image = next(it)
+
+    assert image.shape == torch.Size([2, 3, 256, 256])
+    assert image.dtype == torch.float16
29 changes: 22 additions & 7 deletions src/tests/test_model.py
@@ -4,7 +4,11 @@
 Based loosely on Lightning's testing method described at
 https://github.com/Lightning-AI/lightning/blob/2.1.0/.github/CONTRIBUTING.md#how-to-add-new-tests
 """
+import os
+import tempfile
+
 import lightning as L
+import numpy as np
 import pytest
 import torch
 import torchdata
 
@@ -39,11 +43,22 @@ def test_model_vit(datapipe):
     # Initialize model
     model: L.LightningModule = ViTLitModule()
 
-    # Training
-    trainer: L.Trainer = L.Trainer(
-        accelerator="auto", devices=1, precision="16-mixed", fast_dev_run=True
-    )
-    trainer.fit(model=model, train_dataloaders=dataloader)
+    # Run tests in a temporary folder
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # Training
+        trainer: L.Trainer = L.Trainer(
+            accelerator="auto",
+            devices=1,
+            precision="16-mixed",
+            fast_dev_run=True,
+            default_root_dir=tmpdirname,
+        )
+        trainer.fit(model=model, train_dataloaders=dataloader)
 
-    # Test/Evaluation
-    # TODO
+        # Prediction
+        trainer.predict(model=model, dataloaders=dataloader)
+        assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embedding_0.npy")
+        embeddings: np.ndarray = np.load(file=path)
+        assert embeddings.shape == (2, 768)
+        assert embeddings.dtype == "float32"
+        assert not np.isnan(embeddings).any()
