Save embeddings with spatiotemporal metadata to GeoParquet #73

Merged (4 commits) on Dec 8, 2023
5 changes: 3 additions & 2 deletions README.md
@@ -90,9 +90,10 @@ To train the model for a hundred epochs:

python trainer.py fit --trainer.max_epochs=100

To generate embeddings from the pretrained model's encoder on one image:
To generate embeddings from the pretrained model's encoder on 1024 images
(stored as a GeoParquet file with spatiotemporal metadata):

python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1 --trainer.limit_predict_batches=1
python trainer.py predict --ckpt_path=checkpoints/last.ckpt --data.batch_size=1024 --trainer.limit_predict_batches=1

More options can be found using `python trainer.py fit --help`, or at the
[LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html).
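
For reference, the GeoParquet file written by the predict command can be read back with geopandas. A minimal sketch, assuming the default output location used in this PR (data/embeddings/embeddings_0.gpq under the trainer's root directory):

    import geopandas as gpd

    # Load the embeddings table written by `trainer.py predict`
    geodataframe = gpd.read_parquet(path="data/embeddings/embeddings_0.gpq")
    print(geodataframe.head())  # columns: date, embeddings, geometry
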
95 changes: 85 additions & 10 deletions src/model_vit.py
@@ -6,8 +6,11 @@
"""
import os

import geopandas as gpd
import lightning as L
import numpy as np
import pyarrow as pa
import shapely
import torch
import transformers

@@ -84,7 +87,9 @@ def forward(self, x: torch.Tensor) -> dict:

return outputs

def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
def training_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> torch.Tensor:
"""
Logic for the neural network's training loop.

@@ -123,18 +128,60 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:

return loss

def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
def validation_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> torch.Tensor:
"""
Logic for the neural network's validation loop.
"""
pass

def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
def predict_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> gpd.GeoDataFrame:
"""
Logic for the neural network's prediction loop.

Takes batches of image inputs, generates the embeddings, and stores them
in a GeoParquet file with spatiotemporal metadata.

Steps:
1. Image inputs are passed through the encoder model to produce raw
embeddings of shape (B, 65, 768), where B is the batch size, 65 is
the dimension that consists of 1 cls_token + 64 patch embeddings
(that were flattened from the original 8x8 grid), and 768 is the
embedding length.
2. Taking only the (B, 64, 768) patch embeddings, we compute the mean
along the 64-dim, to obtain final embeddings of shape (B, 768).

            cls_token (1, 768)  +  patch embeddings (8x8 = 64, 768)
                                =>  raw embedding (1 + 64, 768) = (65, 768)

            mean along the spatial dim of the 64 patch embeddings
                                =>  final embedding (1, 768) per image
3. Embeddings are joined with spatiotemporal metadata (date and
bounding box polygon) in a geopandas.GeoDataFrame table. The
coordinates of the bounding box are in an OGC:CRS84 projection (i.e.
longitude/latitude).
4. The geodataframe table is saved out to a GeoParquet file.

| date | embeddings | geometry |
|------------|----------------------|--------------|
| 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) | ---> *.gpq
| 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
| 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
"""
x: torch.Tensor = batch["image"]
# x: torch.Tensor = torch.randn(32, 13, 256, 256) # BCHW
# Get image, bounding box, EPSG code, and date inputs
x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']

# Forward encoder
self.vit.config.mask_ratio = 0 # disable masking
@@ -154,14 +201,42 @@ def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
[self.B, 768] # (batch_size, hidden_size)
)
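
# Illustrative sketch (not the collapsed code above): per steps 1-2 of the
# docstring, the encoder output's last_hidden_state is assumed to have shape
# (B, 65, 768); dropping the cls_token and averaging over the 64 patch
# embeddings would look roughly like:
#     patch_embeddings = outputs.last_hidden_state[:, 1:, :]  # (B, 64, 768)
#     embeddings_mean = patch_embeddings.mean(dim=1)  # (B, 768)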

# Save embeddings in npy format
# Create table to store the embeddings with spatiotemporal metadata
unique_epsg_codes = set(int(epsg) for epsg in epsgs)
if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG
epsg: int = batch["epsg"][0]
else:
raise NotImplementedError(
f"More than 1 EPSG code detected: {unique_epsg_codes}"
)

gdf = gpd.GeoDataFrame(
data={
"date": gpd.pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
dtype="date32[day][pyarrow]"
),
"embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
embeddings_mean.cpu().detach().__array__()
),
Comment on lines +218 to +220
@weiji14 (Contributor, Author) commented on Dec 7, 2023:
Although we've converted the embedding into a FixedShapeTensorArray here, pandas/geopandas still interprets this column as an object dtype, and this is saved as an object dtype to the parquet file too (see the unit test). Need to see if there's a way to preserve the dtype.

@weiji14 (Contributor, Author) commented on Dec 7, 2023:

Found a way to save this embeddings column as a FixedShapeTensorArray dtype instead of an object dtype like so:

Suggested change, from:

    "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
        embeddings_mean.cpu().detach().__array__()
    ),

to:

    "embeddings": gpd.pd.arrays.ArrowExtensionArray(
        values=pa.FixedShapeTensorArray.from_numpy_ndarray(embeddings)
    ),

However, while we can save this FixedShapeTensorArray to GeoParquet, loading this embeddings column as a FixedShapeTensorArray is challenging, and might involve code that looks like this:

geodataframe: gpd.GeoDataFrame = gpd.read_parquet(
    path="data/embeddings/embeddings_0.gpq",
    schema=pa.schema(
        fields=[
            pa.field(
                name="embeddings",
                type=pa.fixed_shape_tensor(
                    value_type=pa.float32(), shape=[768]
                ),
            ),
            pa.field(name="geometry", type=pa.binary()),
        ]
    ),
)

But this technically still results in an embeddings column with object dtype... Also, QGIS can load this geoparquet file with FixedShapeTensorArray, but would crash when you try to open the attribute table, because it can't handle FixedShapeTensorArray yet. So probably best to keep it in object dtype for now.
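
For downstream use, the object-dtype column can still be stacked back into a dense array. A minimal sketch (nothing here is part of this diff; the path matches the default output location):

    import geopandas as gpd
    import numpy as np

    geodataframe = gpd.read_parquet(path="data/embeddings/embeddings_0.gpq")
    # Each row of the object-dtype embeddings column holds a (768,) float32
    # array; stack them into a single (num_rows, 768) matrix.
    embedding_matrix = np.stack(geodataframe.embeddings.to_numpy())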

},
geometry=shapely.box(
xmin=bboxes[:, 0],
ymin=bboxes[:, 1],
xmax=bboxes[:, 2],
ymax=bboxes[:, 3],
),
crs=f"EPSG:{epsg}",
)
gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates

# Save embeddings in GeoParquet format
outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
os.makedirs(name=outfolder, exist_ok=True)
outfile = f"{outfolder}/embedding_{batch_idx}.npy"
np.save(file=outfile, arr=embeddings_mean.cpu())
print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outfile}")
outpath = f"{outfolder}/embeddings_{batch_idx}.gpq"
gdf.to_parquet(path=outpath, schema_version="1.0.0")
print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outpath}")
Comment on lines +235 to +237
@weiji14 (Contributor, Author):

It is possible to save several rows' worth of embeddings to a single GeoParquet file now. So, we can decide on how to lump embeddings together, e.g. save all the embeddings for one MGRS tile in one year together (see the sketch below).
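
One hypothetical way to lump them afterwards is to concatenate the per-batch GeoParquet files; a rough sketch, where the glob pattern and the combined filename are assumptions rather than anything decided in this PR:

    import glob

    import geopandas as gpd
    import pandas as pd

    # Gather every per-batch file and concatenate into a single GeoDataFrame
    files = sorted(glob.glob("data/embeddings/embeddings_*.gpq"))
    combined: gpd.GeoDataFrame = pd.concat(
        objs=[gpd.read_parquet(path=f) for f in files], ignore_index=True
    )
    # Write the lumped embeddings out as one GeoParquet file
    combined.to_parquet(path="data/embeddings/combined.gpq", schema_version="1.0.0")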

@weiji14 (Contributor, Author):

New 512x512 image chips are being processed now-ish, see #76 (comment). Will use a new filename convention in a follow-up PR (with the MGRS code in it) once we've got a new model trained on that new dataset.


return embeddings_mean
return gdf

def configure_optimizers(self) -> torch.optim.Optimizer:
"""
32 changes: 25 additions & 7 deletions src/tests/test_model.py
@@ -7,6 +7,7 @@
import os
import tempfile

import geopandas as gpd
import lightning as L
import numpy as np
import pytest
@@ -25,8 +26,17 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe:
"""
datapipe = torchdata.datapipes.iter.IterableWrapper(
iterable=[
{"image": torch.randn(2, 13, 256, 256).to(dtype=torch.float16)},
{"image": torch.randn(2, 13, 256, 256).to(dtype=torch.float16)},
{
"image": torch.randn(2, 13, 256, 256).to(dtype=torch.float16),
"bbox": torch.tensor(
data=[
[499975.0, 3397465.0, 502535.0, 3400025.0],
[530695.0, 3397465.0, 533255.0, 3400025.0],
]
),
"date": ["2020-01-01", "2020-12-31"],
"epsg": torch.tensor(data=[32646, 32646]),
},
]
)
return datapipe
@@ -57,8 +67,16 @@ def test_model_vit(datapipe):

# Prediction
trainer.predict(model=model, dataloaders=dataloader)
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embedding_0.npy")
embeddings: np.ndarray = np.load(file=path)
assert embeddings.shape == (2, 768)
assert embeddings.dtype == "float32"
assert not np.isnan(embeddings).any()
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embeddings_0.gpq")
geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path)

assert geodataframe.shape == (2, 3)
assert all(geodataframe.columns == ["date", "embeddings", "geometry"])
assert geodataframe.date.dtype == "date32[day][pyarrow]"
assert geodataframe.embeddings.dtype == "object"
assert geodataframe.geometry.dtype == gpd.array.GeometryDtype()

for embeddings in geodataframe.embeddings:
assert embeddings.shape == (768,)
assert embeddings.dtype == "float32"
assert not np.isnan(embeddings).any()