Save embeddings with spatiotemporal metadata to GeoParquet #73
Changes from all commits: a767164, dd84c2c, f743b53, 384650c
@@ -6,8 +6,11 @@
 """
 import os

+import geopandas as gpd
 import lightning as L
 import numpy as np
+import pyarrow as pa
+import shapely
 import torch
 import transformers

@@ -84,7 +87,9 @@ def forward(self, x: torch.Tensor) -> dict:

         return outputs

-    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
+    def training_step(
+        self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
+    ) -> torch.Tensor:
         """
         Logic for the neural network's training loop.
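The widened `batch` annotation reflects that the dataloader yields a dict rather than a bare tensor. Judging from the keys accessed in `predict_step` further down, a batch looks roughly like this (a sketch; the values are illustrative, not from this PR):

```python
import torch

batch = {
    "image": torch.randn(2, 13, 256, 256),  # BCHW pixel tensor
    "bbox": torch.tensor(  # xmin, ymin, xmax, ymax per chip
        [[499980.0, 3990240.0, 609780.0, 4100040.0]] * 2
    ),
    "epsg": torch.tensor([32646, 32646]),  # EPSG code per chip
    "date": ["2022-12-12", "2022-12-12"],  # acquisition date strings
}
```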
@@ -123,18 +128,60 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:

         return loss

-    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
+    def validation_step(
+        self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
+    ) -> torch.Tensor:
         """
         Logic for the neural network's validation loop.
         """
         pass

-    def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
+    def predict_step(
+        self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
+    ) -> gpd.GeoDataFrame:
         """
         Logic for the neural network's prediction loop.
+
+        Takes batches of image inputs, generates the embeddings, and stores
+        them in a GeoParquet file with spatiotemporal metadata.
+
+        Steps:
+        1. Image inputs are passed through the encoder model to produce raw
+           embeddings of shape (B, 65, 768), where B is the batch size, 65 is
+           the dimension that consists of 1 cls_token + 64 patch embeddings
+           (that were flattened from the original 8x8 grid), and 768 is the
+           embedding length.
+        2. Taking only the (B, 64, 768) patch embeddings, we compute the mean
+           along the 64-dim, to obtain final embeddings of shape (B, 768).
+
+                                                    ______
+           cls_token        Patch                  /      /|
+           embeddings   +   embeddings            /______/ |  => (1+64, 768)
+           (1, 768)         (8x8, 768)            |      | |   = (65, 768)
+                            = (64, 768)           |      | /
+                                                  |______|/
+                                                      |
+                              --------> Final embedding
+                compute mean along spatial dim = (1, 768)
+
+        3. Embeddings are joined with spatiotemporal metadata (date and
+           bounding box polygon) in a geopandas.GeoDataFrame table. The
+           coordinates of the bounding box are in an OGC:CRS84 projection
+           (i.e. longitude/latitude).
+        4. The geodataframe table is saved out to a GeoParquet file.
+
+           | date       | embeddings           | geometry     |
+           |------------|----------------------|--------------|
+           | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) | ---> *.gpq
+           | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
+           | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
         """
-        x: torch.Tensor = batch["image"]
-        # x: torch.Tensor = torch.randn(32, 13, 256, 256) # BCHW
+        # Get image, bounding box, EPSG code, and date inputs
+        x: torch.Tensor = batch["image"]  # image of shape (1, 13, 256, 256) # BCHW
+        bboxes: np.ndarray = batch["bbox"].cpu().__array__()  # bounding boxes
+        epsgs: torch.Tensor = batch["epsg"]  # coordinate reference systems as EPSG code
+        dates: list[str] = batch["date"]  # dates, e.g. ['2022-12-12', '2022-12-12']

         # Forward encoder
         self.vit.config.mask_ratio = 0  # disable masking
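For reference, the cls_token split and mean pooling described in steps 1-2 of the docstring come down to a few tensor ops (a standalone sketch with a dummy tensor, not code from this diff):

```python
import torch

raw_embeddings = torch.randn(2, 65, 768)  # (B, 1 cls_token + 64 patches, 768)
patch_embeddings = raw_embeddings[:, 1:, :]  # drop the cls_token -> (B, 64, 768)
embeddings_mean = patch_embeddings.mean(dim=1)  # average over 64 patches -> (B, 768)
assert embeddings_mean.shape == (2, 768)
```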
@@ -154,14 +201,42 @@ def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
             [self.B, 768]  # (batch_size, hidden_size)
         )

-        # Save embeddings in npy format
+        # Create table to store the embeddings with spatiotemporal metadata
+        unique_epsg_codes = set(int(epsg) for epsg in epsgs)
+        if len(unique_epsg_codes) == 1:  # check that there's only 1 unique EPSG
+            epsg: int = batch["epsg"][0]
+        else:
+            raise NotImplementedError(
+                f"More than 1 EPSG code detected: {unique_epsg_codes}"
+            )
+
+        gdf = gpd.GeoDataFrame(
+            data={
+                "date": gpd.pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
+                    dtype="date32[day][pyarrow]"
+                ),
+                "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
+                    embeddings_mean.cpu().detach().__array__()
+                ),
+            },
+            geometry=shapely.box(
+                xmin=bboxes[:, 0],
+                ymin=bboxes[:, 1],
+                xmax=bboxes[:, 2],
+                ymax=bboxes[:, 3],
+            ),
+            crs=f"EPSG:{epsg}",
+        )
+        gdf = gdf.to_crs(crs="OGC:CRS84")  # reproject from UTM to lonlat coordinates
+
+        # Save embeddings in GeoParquet format
         outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
         os.makedirs(name=outfolder, exist_ok=True)
-        outfile = f"{outfolder}/embedding_{batch_idx}.npy"
-        np.save(file=outfile, arr=embeddings_mean.cpu())
-        print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outfile}")
+        outpath = f"{outfolder}/embeddings_{batch_idx}.gpq"
+        gdf.to_parquet(path=outpath, schema_version="1.0.0")
+        print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outpath}")
Comment on lines +235 to +237:

It is possible to save several rows' worth of embeddings to a single GeoParquet file now, so we can decide on how to lump embeddings together, e.g. save all the embeddings for one MGRS tile in one year together.

Reply:

New 512x512 image chips are being processed now-ish, see #76 (comment). Will use a new filename convention (with the MGRS code in it) in a follow-up PR, once we've got a new model trained on that new dataset.
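As a rough sketch of the lumping idea (hypothetical, not part of this PR): since `predict_step` now returns a GeoDataFrame per batch, the per-batch tables could be concatenated and written once, e.g. per MGRS tile and year. The `trainer.predict` source and the output filename below are placeholders.

```python
import geopandas as gpd
import pandas as pd

# Hypothetical: `gdfs` holds the GeoDataFrame returned by each predict_step
# call, e.g. from `trainer.predict(model=model, datamodule=datamodule)`
gdfs: list[gpd.GeoDataFrame] = ...

# Concatenate all rows and write a single GeoParquet file per MGRS tile/year
gdf = pd.concat(objs=gdfs, ignore_index=True)
gdf.to_parquet(path="32VLM_2021.gpq", schema_version="1.0.0")
```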

-        return embeddings_mean
+        return gdf

     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
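To sanity-check an output file, the GeoParquet can be read straight back with geopandas (a sketch; the path follows the naming convention in the diff above):

```python
import geopandas as gpd

gdf = gpd.read_parquet(path="data/embeddings/embeddings_0.gpq")
print(gdf.dtypes)  # date, embeddings, geometry columns
print(gdf.total_bounds)  # lon/lat extent, since the file is in OGC:CRS84
```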
Review comment:

Although we've converted the embedding into a FixedShapeTensorArray here, pandas/geopandas still interprets this column as an `object` dtype, and it is saved as an `object` dtype to the Parquet file too (see the unit test). Need to see if there's a way to preserve the dtype.
Reply:

Found a way to save this `embeddings` column as a FixedShapeTensorArray dtype instead of an `object` dtype, like so:
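(A sketch of what that could look like, assuming pandas ≥ 2.0 with `pd.arrays.ArrowExtensionArray`; the snippet and its placeholder data are illustrative, not the exact code from the comment.)

```python
import numpy as np
import pandas as pd
import pyarrow as pa

# Placeholder embeddings standing in for embeddings_mean
embeddings = np.random.rand(3, 768).astype(np.float32)
tensor_array = pa.FixedShapeTensorArray.from_numpy_ndarray(embeddings)

# Wrapping the Arrow array in a pandas extension array keeps the
# fixed_shape_tensor extension type instead of decaying to `object`
series = pd.Series(pd.arrays.ArrowExtensionArray(tensor_array))
print(series.dtype)  # an ArrowDtype, not `object`
```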
However, while we can save this FixedShapeTensorArray to GeoParquet, loading the `embeddings` column back as a FixedShapeTensorArray is challenging, and might involve code that looks like this:
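(Again a sketch, not the comment's original snippet; it assumes the file was written with the fixed_shape_tensor type preserved, and uses `pyarrow.parquet` to reach the Arrow column directly.)

```python
import geopandas as gpd
import pyarrow.parquet as pq

# Read the raw Arrow table to reach the fixed_shape_tensor column
table = pq.read_table("embeddings_0.gpq")
tensor_array = table.column("embeddings").combine_chunks()  # FixedShapeTensorArray

# Rebuild the GeoDataFrame, putting the tensors back as per-row arrays
gdf = gpd.read_parquet(path="embeddings_0.gpq")
gdf["embeddings"] = list(tensor_array.to_numpy_ndarray())
```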
But this technically still results in an `embeddings` column with `object` dtype... Also, QGIS can load this GeoParquet file with the FixedShapeTensorArray column, but crashes when you try to open the attribute table, because it can't handle FixedShapeTensorArray yet. So it's probably best to keep it as `object` dtype for now.
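(For downstream use, an `object`-dtype column of per-row arrays is still easy to turn back into a single matrix; a sketch, with the path following the diff's naming convention:)

```python
import geopandas as gpd
import numpy as np

gdf = gpd.read_parquet(path="data/embeddings/embeddings_0.gpq")
embeddings = np.stack(arrays=gdf["embeddings"])  # (N, 768) float array
```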