Skip to content

Commit

Permalink
Save embeddings with spatiotemporal metadata to GeoParquet (#73)
Browse files Browse the repository at this point in the history
* ✨ Save embeddings with spatiotemporal metadata to GeoParquet

Storing the vector embeddings alongside some spatial bounding box and datetime information in a tabular GeoParquet format, instead of an npy file! Using geopandas to create a GeoDataFrame with three columns - date, embeddings, geometry. The date is stored in Arrow's date32 format, embeddings are in FixedShapedTensorArray, and geometry is in WKB. Have updated the unit test's sample fixture data with the extra spatiotemporal data, and tested that the saved GeoParquet file can be loaded back.

* 📝 Document how embeddings are generated and saved to geoparquet

Improve the docstring of predict_step in the LightningModule on how the embeddings are generated, and then saved to a GeoParquet file with the spatiotemporal metadata. Included some ASCII art and a markdown table of how the tabular data looks like.

* 📝 Mention in main README.md that embeddings are saved to geoparquet

Document that the embeddings are stored with spatiotemporal metadata as a GeoParquet file. Increased batch size from 1 to 1024.

* 🎨 Update type hint of batch inputs, and add some inline comments

Should have updated the type hints in #66, but might as well do it here. Also adding some more inline comments and fixed a typo.
  • Loading branch information
weiji14 committed Dec 8, 2023
1 parent c6a8365 commit decea30
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 19 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ To train the model for a hundred epochs:

python trainer.py fit --trainer.max_epochs=100

To generate embeddings from the pretrained model's encoder on one image:
To generate embeddings from the pretrained model's encoder on 1024 images
(stored as a GeoParquet file with spatiotemporal metadata):

python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1 --trainer.limit_predict_batches=1
python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1000 --trainer.limit_predict_batches=1

More options can be found using `python trainer.py fit --help`, or at the
[LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html).
Expand Down
95 changes: 85 additions & 10 deletions src/model_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
"""
import os

import geopandas as gpd
import lightning as L
import numpy as np
import pyarrow as pa
import shapely
import torch
import transformers

Expand Down Expand Up @@ -84,7 +87,9 @@ def forward(self, x: torch.Tensor) -> dict:

return outputs

def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
def training_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> torch.Tensor:
"""
Logic for the neural network's training loop.
Expand Down Expand Up @@ -123,18 +128,60 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:

return loss

def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
def validation_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> torch.Tensor:
"""
Logic for the neural network's validation loop.
"""
pass

def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
def predict_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> gpd.GeoDataFrame:
"""
Logic for the neural network's prediction loop.
Takes batches of image inputs, generate the embeddings, and store them
in a GeoParquet file with spatiotemporal metadata.
Steps:
1. Image inputs are passed through the encoder model to produce raw
embeddings of shape (B, 65, 768), where B is the batch size, 65 is
the dimension that consists of 1 cls_token + 64 patch embeddings
(that were flattened from the original 8x8 grid), and 768 is the
embedding length.
2. Taking only the (B, 64, 768) patch embeddings, we compute the mean
along the 64-dim, to obtain final embeddings of shape (B, 768).
______
cls_token / Patch / /|
embeddings / + embeddings /_____ / | => (1+64, 768)
(1, 768) / (8x8, 768) | | | = (65, 768)
/ = (64, 768) | | /
|______|/
| /
--------> Final embedding /
compute mean along spatial dim = (1, 768) /
/
3. Embeddings are joined with spatiotemporal metadata (date and
bounding box polygon) in a geopandas.GeoDataFrame table. The
coordinates of the bounding box are in an OGC:CRS84 projection (i.e.
longitude/latitude).
4. The geodataframe table is saved out to a GeoParquet file.
| date | embeddings | geometry |
|------------|----------------------|--------------|
| 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) | ---> *.gpq
| 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
| 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
"""
x: torch.Tensor = batch["image"]
# x: torch.Tensor = torch.randn(32, 13, 256, 256) # BCHW
# Get image, bounding box, EPSG code, and date inputs
x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']

# Forward encoder
self.vit.config.mask_ratio = 0 # disable masking
Expand All @@ -154,14 +201,42 @@ def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
[self.B, 768] # (batch_size, hidden_size)
)

# Save embeddings in npy format
# Create table to store the embeddings with spatiotemporal metadata
unique_epsg_codes = set(int(epsg) for epsg in epsgs)
if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG
epsg: int = batch["epsg"][0]
else:
raise NotImplementedError(
f"More than 1 EPSG code detected: {unique_epsg_codes}"
)

gdf = gpd.GeoDataFrame(
data={
"date": gpd.pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
dtype="date32[day][pyarrow]"
),
"embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
embeddings_mean.cpu().detach().__array__()
),
},
geometry=shapely.box(
xmin=bboxes[:, 0],
ymin=bboxes[:, 1],
xmax=bboxes[:, 2],
ymax=bboxes[:, 3],
),
crs=f"EPSG:{epsg}",
)
gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates

# Save embeddings in GeoParquet format
outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
os.makedirs(name=outfolder, exist_ok=True)
outfile = f"{outfolder}/embedding_{batch_idx}.npy"
np.save(file=outfile, arr=embeddings_mean.cpu())
print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outfile}")
outpath = f"{outfolder}/embeddings_{batch_idx}.gpq"
gdf.to_parquet(path=outpath, schema_version="1.0.0")
print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outpath}")

return embeddings_mean
return gdf

def configure_optimizers(self) -> torch.optim.Optimizer:
"""
Expand Down
32 changes: 25 additions & 7 deletions src/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import tempfile

import geopandas as gpd
import lightning as L
import numpy as np
import pytest
Expand All @@ -25,8 +26,17 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe:
"""
datapipe = torchdata.datapipes.iter.IterableWrapper(
iterable=[
{"image": torch.randn(2, 13, 256, 256).to(dtype=torch.float16)},
{"image": torch.randn(2, 13, 256, 256).to(dtype=torch.float16)},
{
"image": torch.randn(2, 13, 256, 256).to(dtype=torch.float16),
"bbox": torch.tensor(
data=[
[499975.0, 3397465.0, 502535.0, 3400025.0],
[530695.0, 3397465.0, 533255.0, 3400025.0],
]
),
"date": ["2020-01-01", "2020-12-31"],
"epsg": torch.tensor(data=[32646, 32646]),
},
]
)
return datapipe
Expand Down Expand Up @@ -57,8 +67,16 @@ def test_model_vit(datapipe):

# Prediction
trainer.predict(model=model, dataloaders=dataloader)
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embedding_0.npy")
embeddings: np.ndarray = np.load(file=path)
assert embeddings.shape == (2, 768)
assert embeddings.dtype == "float32"
assert not np.isnan(embeddings).any()
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embeddings_0.gpq")
geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path)

assert geodataframe.shape == (2, 3)
assert all(geodataframe.columns == ["date", "embeddings", "geometry"])
assert geodataframe.date.dtype == "date32[day][pyarrow]"
assert geodataframe.embeddings.dtype == "object"
assert geodataframe.geometry.dtype == gpd.array.GeometryDtype()

for embeddings in geodataframe.embeddings:
assert embeddings.shape == (768,)
assert embeddings.dtype == "float32"
assert not np.isnan(embeddings).any()

0 comments on commit decea30

Please sign in to comment.