From 69ce7031cd1c196ab46fdc249e867286ec26b3ca Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:51:44 +1300 Subject: [PATCH] Generate embeddings via prediction loop (#56) * :beers: Generate embeddings via prediction loop Implement the embedding generator in the LightningModule's predict_step. The embeddings are tensor arrays that are saved to a .npy file in the data/embeddings/ folder. Input data is retrieved from the predict_dataloader, which is currently using the validation datapipe rather than a dedicated datapipe. Have documented how to generate the embedding output file using LightningCLI on the main README.md file. Also added a unit test to ensure that saving and loading from an embedding_0.npy file works. * :bug: Disable masking of patches on predict_step Previously, 75% of the patches, or 48 out of a total of 64 were masked out, leaving 16 patches plus 1 cls_token = 17 sequences. Disabling the mask gives 64 + 1 cls_token = 65 sequences. Moved some assert statements with a fixed sequence_length dim from the forward function to the training_step. Also updated the unit test to ensure output embeddings have a shape like (batch_size, 65, 768). * :recycle: Refactor LightningDataModule to not do random split on predict Refactoring the setup method in the LightningDataModule to not do a random split on the predict stage. I.e. just do the GeoTIFF to torch.Tensor conversion directly, followed by batching and collating. * :white_check_mark: Test predict stage in geotiffdatamodule Need to explicitly pass an argument to stage in the test_geotiffdatapipemodule unit test. Testing both the fit and predict stages. * :necktie: Ensure that embeddings have no NaN values Make sure that the generated embeddings do not have NaN values in them. * :card_file_box: Take mean of the embeddings along sequence_length dim Instead of saving embeddings of shape (1, 65, 768), save out embeddings of shape (1, 768) instead. Done by taking the mean along the sequence_length dim, except for the cls_token part (first index in the 65). --- README.md | 4 ++++ src/datamodule.py | 46 ++++++++++++++++++++++++++---------- src/model_vit.py | 43 ++++++++++++++++++++++++++++++--- src/tests/test_datamodule.py | 11 ++++++++- src/tests/test_model.py | 29 +++++++++++++++++------ 5 files changed, 110 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 3e051091..57c8c4fe 100644 --- a/README.md +++ b/README.md @@ -90,5 +90,9 @@ To train the model for a hundred epochs: python trainer.py fit --trainer.max_epochs=100 +To generate embeddings from the pretrained model's encoder on one image: + + python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1 --trainer.limit_predict_batches=1 + More options can be found using `python trainer.py fit --help`, or at the [LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html). diff --git a/src/datamodule.py b/src/datamodule.py index 582f28a6..ef7ecf59 100644 --- a/src/datamodule.py +++ b/src/datamodule.py @@ -70,19 +70,31 @@ def setup(self, stage: str | None = None): root=self.data_path, masks="*.tif", recursive=True, length=423 ) - # Step 2 - Split GeoTIFF chips into train/val sets (80%/20%) - # https://pytorch.org/data/0.7/generated/torchdata.datapipes.iter.RandomSplitter.html - dp_train, dp_val = dp_paths.random_split( - weights={"train": 0.8, "validation": 0.2}, total_length=423, seed=42 - ) + if stage == "fit": # training/validation loop + # Step 2 - Split GeoTIFF chips into train/val sets (80%/20%) + # https://pytorch.org/data/0.7/generated/torchdata.datapipes.iter.RandomSplitter.html + dp_train, dp_val = dp_paths.random_split( + weights={"train": 0.8, "validation": 0.2}, total_length=423, seed=42 + ) - # Step 3 - Read GeoTIFF into numpy.ndarray, batch and convert to torch.Tensor - self.datapipe_train = ( - dp_train.map(fn=_array_to_torch).batch(batch_size=self.batch_size).collate() - ) - self.datapipe_val = ( - dp_val.map(fn=_array_to_torch).batch(batch_size=self.batch_size).collate() - ) + # Step 3 - Read GeoTIFF into numpy array, batch and convert to torch.Tensor + self.datapipe_train = ( + dp_train.map(fn=_array_to_torch) + .batch(batch_size=self.batch_size) + .collate() + ) + self.datapipe_val = ( + dp_val.map(fn=_array_to_torch) + .batch(batch_size=self.batch_size) + .collate() + ) + + elif stage == "predict": # prediction loop + self.datapipe_predict = ( + dp_paths.map(fn=_array_to_torch) + .batch(batch_size=self.batch_size) + .collate() + ) def train_dataloader(self) -> torch.utils.data.DataLoader: """ @@ -103,3 +115,13 @@ def val_dataloader(self) -> torch.utils.data.DataLoader: batch_size=None, # handled in datapipe already num_workers=self.num_workers, ) + + def predict_dataloader(self) -> torch.utils.data.DataLoader: + """ + Loads the data used in the prediction loop. + """ + return torch.utils.data.DataLoader( + dataset=self.datapipe_predict, + batch_size=None, # handled in datapipe already + num_workers=self.num_workers, + ) diff --git a/src/model_vit.py b/src/model_vit.py index 664551a6..7047843e 100644 --- a/src/model_vit.py +++ b/src/model_vit.py @@ -4,7 +4,10 @@ Code structure adapted from Lightning project seed at https://github.com/Lightning-AI/deep-learning-project-template """ +import os + import lightning as L +import numpy as np import torch import transformers @@ -78,9 +81,6 @@ def forward(self, x: torch.Tensor) -> dict: outputs: dict = self.vit.base_model(x) self.B = x.shape[0] - assert outputs.last_hidden_state.shape == torch.Size([self.B, 17, 768]) - assert outputs.ids_restore.shape == torch.Size([self.B, 64]) - assert outputs.mask.shape == torch.Size([self.B, 64]) return outputs @@ -96,6 +96,9 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: # Forward encoder outputs_encoder: dict = self(x) + assert outputs_encoder.last_hidden_state.shape == torch.Size([self.B, 17, 768]) + assert outputs_encoder.ids_restore.shape == torch.Size([self.B, 64]) + assert outputs_encoder.mask.shape == torch.Size([self.B, 64]) # Forward decoder outputs_decoder: dict = self.vit.decoder.forward( @@ -126,6 +129,40 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: """ pass + def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: + """ + Logic for the neural network's prediction loop. + """ + x: torch.Tensor = batch + # x: torch.Tensor = torch.randn(32, 13, 256, 256) # BCHW + + # Forward encoder + self.vit.config.mask_ratio = 0 # disable masking + outputs_encoder: dict = self(x) + + # Get embeddings generated from encoder + embeddings_raw: torch.Tensor = outputs_encoder.last_hidden_state + assert embeddings_raw.shape == torch.Size( + [self.B, 65, 768] # (batch_size, sequence_length, hidden_size) + ) + assert not torch.isnan(embeddings_raw).any() # ensure no NaNs in embedding + + # Take the mean of the embeddings along the sequence_length dimension + # excluding the first cls token embedding, compute over patch embeddings + embeddings_mean: torch.Tensor = embeddings_raw[:, 1:, :].mean(dim=1) + assert embeddings_mean.shape == torch.Size( + [self.B, 768] # (batch_size, hidden_size) + ) + + # Save embeddings in npy format + outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings" + os.makedirs(name=outfolder, exist_ok=True) + outfile = f"{outfolder}/embedding_{batch_idx}.npy" + np.save(file=outfile, arr=embeddings_mean.cpu()) + print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outfile}") + + return embeddings_mean + def configure_optimizers(self) -> torch.optim.Optimizer: """ Optimizing function used to reduce the loss, so that the predicted diff --git a/src/tests/test_datamodule.py b/src/tests/test_datamodule.py index 49ed0db2..d5c83d9a 100644 --- a/src/tests/test_datamodule.py +++ b/src/tests/test_datamodule.py @@ -47,10 +47,19 @@ def test_geotiffdatapipemodule(geotiff_folder): datamodule: L.LightningDataModule = GeoTIFFDataPipeModule( data_path=geotiff_folder, batch_size=2 ) - datamodule.setup() + # Train/validation stage + datamodule.setup(stage="fit") it = iter(datamodule.train_dataloader()) image = next(it) assert image.shape == torch.Size([2, 3, 256, 256]) assert image.dtype == torch.float16 + + # Predict stage + datamodule.setup(stage="predict") + it = iter(datamodule.predict_dataloader()) + image = next(it) + + assert image.shape == torch.Size([2, 3, 256, 256]) + assert image.dtype == torch.float16 diff --git a/src/tests/test_model.py b/src/tests/test_model.py index ba4813c8..fb2ec652 100644 --- a/src/tests/test_model.py +++ b/src/tests/test_model.py @@ -4,7 +4,11 @@ Based loosely on Lightning's testing method described at https://github.com/Lightning-AI/lightning/blob/2.1.0/.github/CONTRIBUTING.md#how-to-add-new-tests """ +import os +import tempfile + import lightning as L +import numpy as np import pytest import torch import torchdata @@ -39,11 +43,22 @@ def test_model_vit(datapipe): # Initialize model model: L.LightningModule = ViTLitModule() - # Training - trainer: L.Trainer = L.Trainer( - accelerator="auto", devices=1, precision="16-mixed", fast_dev_run=True - ) - trainer.fit(model=model, train_dataloaders=dataloader) + # Run tests in a temporary folder + with tempfile.TemporaryDirectory() as tmpdirname: + # Training + trainer: L.Trainer = L.Trainer( + accelerator="auto", + devices=1, + precision="16-mixed", + fast_dev_run=True, + default_root_dir=tmpdirname, + ) + trainer.fit(model=model, train_dataloaders=dataloader) - # Test/Evaluation - # TODO + # Prediction + trainer.predict(model=model, dataloaders=dataloader) + assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embedding_0.npy") + embeddings: np.ndarray = np.load(file=path) + assert embeddings.shape == (2, 768) + assert embeddings.dtype == "float32" + assert not np.isnan(embeddings).any()