Add option to output raw patch embeddings (#133)
* Add option to output raw patch embeddings

The patch embeddings are averages over the band groups.

* Fix test parametrization

* Fix test litmodel extra args only for clay

* Fix and improve einops conversion

* Document patch level embedding option

* Fix clay module check in test

* Add shuffle parameter to cli

* Document patch level embedding conversion

* Improve wording on patch level unravelling

* Fix argument construction in embeddings docs

* Add vscode conf dir to gitignore

* Change embedding handling to level based approach

Allow for 3 levels: mean, patch, group. Arrays are flattened when passed to pandas.
This could be improved in the future.

* Enforce contiguous array before passing to pyarrow

* Remove shuffle argument

No longer necessary after #135
yellowcap committed Jan 29, 2024
1 parent ae70345 commit 7a48658
Showing 4 changed files with 108 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -64,3 +64,6 @@ docs/_build/
# Mac OS-specific storage files
.DS_Store
datadisk/
+
+# vscode
+.vscode
49 changes: 41 additions & 8 deletions docs/model_embeddings.md
@@ -44,14 +44,27 @@ Step by step instructions to create embeddings for a single MGRS tile location
    --trainer.precision=bf16-mixed \
    --data.data_dir=s3://clay-tiles-02/02/27WXN \
    --data.batch_size=32 \
-   --data.num_workers=16
+   --data.num_workers=16 \
+   --model.embeddings_level=group
```

This should output a GeoParquet file containing the embeddings for MGRS tile
27WXN (recall that each 10000x10000 pixel MGRS tile contains hundreds of
smaller 512x512 chips), saved to the `data/embeddings/` folder. See the next
sub-section for details about the embeddings file.

The `embeddings_level` flag determines how the embeddings are calculated.
The default is `mean`, which produces one averaged embedding of size 768
per 512x512 chip. If set to `patch`, the embeddings are kept at the patch
level, giving an embedding array of size 16 * 16 * 768, one embedding per
patch. The third option, `group`, keeps the full dimensionality of the
encoder output, including the band group dimension; those embedding arrays
have size 6 * 16 * 16 * 768.
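
To make the three levels concrete, here is a minimal numpy sketch of how
their shapes relate for a single chip (the array is a placeholder for the
encoder output; only the shapes matter):

```{code}
import numpy as np

# Placeholder for the group level output of one chip:
# 6 band groups x 16x16 patches x 768 hidden dimensions
group = np.zeros((6, 16, 16, 768), dtype="float32")

# "patch" averages out the band group dimension
patch = group.mean(axis=0)  # shape (16, 16, 768)

# "mean" additionally averages over all patches
mean = patch.reshape(-1, 768).mean(axis=0)  # shape (768,)
```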

The embeddings are flattened into one-dimensional arrays because pandas
does not support multidimensional arrays. To access the patch level
embeddings, the flattened arrays therefore have to be reshaped, as shown
in the conversion section at the end of this page.

```{note}
For those interested in how the embeddings were computed, the predict step
above does the following:
@@ -62,8 +75,9 @@ Step by step instructions to create embeddings for a single MGRS tile location
dimension itself is a concatenation of 1536 (6 band groups x 16x16
spatial patches of size 32x32 pixels each in a 512x512 image) + 2 (latlon
embedding and time embedding) = 1538.
-2. The mean or average is taken across the 1536 patch dimension, yielding an
-   output embedding of shape (B, 768).
+2. By default, the mean or average is taken across the 1536 patch dimension,
+   yielding an output embedding of shape (B, 768). If patch embeddings are
+   requested, the shape is (B, 16 * 16 * 768), one embedding per patch.
More details of how this is implemented can be found by inspecting the
`predict_step` method in the `model_clay.py` file.
@@ -104,11 +118,15 @@ and contains a record of the embeddings, spatiotemporal metadata, and a link to
the GeoTIFF file used as the source image for the embedding. The table looks
something like this:
-| source_url                  | date       | embeddings           | geometry     |
-|-----------------------------|------------|----------------------|--------------|
-| s3://.../.../claytile_*.tif | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) |
-| s3://.../.../claytile_*.tif | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
-| s3://.../.../claytile_*.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |

+Embedding size is 768 by default, 16 * 16 * 768 for patch level embeddings, and
+6 * 16 * 16 * 768 for group level embeddings.
+
+| source_url                  | date       | embeddings       | geometry     |
+|-----------------------------|------------|------------------|--------------|
+| s3://.../.../claytile_*.tif | 2021-01-01 | [0.1, 0.4, ... ] | POLYGON(...) |
+| s3://.../.../claytile_*.tif | 2021-06-30 | [0.2, 0.5, ... ] | POLYGON(...) |
+| s3://.../.../claytile_*.tif | 2021-12-31 | [0.3, 0.6, ... ] | POLYGON(...) |
Details of each column are as follows:
@@ -142,3 +160,18 @@ Further reading:
- https://guide.cloudnativegeo.org/geoparquet
- https://cloudnativegeo.org/blog/2023/10/the-geoparquet-ecosystem-at-1.0.0
```
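
As an aside, one way to load such a file and inspect an embedding (the
filename here is purely illustrative):

```{code}
import geopandas as gpd

# Filename is illustrative; the predict step writes to data/embeddings/
geodataframe = gpd.read_parquet("data/embeddings/example.gpq")
print(geodataframe.embeddings[0].shape)  # (768,) for the default mean level
```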

## Converting to patch level embeddings

In the case where patch level embeddings are requested, the resulting array
will have all patch embeddings ravelled into one row. Each row represents a
512x512 pixel image and contains 16x16 = 256 patch embeddings.

To convert a row back into patch level embeddings, the embedding array has
to be unravelled into the 16x16 grid of patches like so:

```{code}
# This assumes embeddings_level was set to "patch"
ravelled_patch_embeddings = geodataframe.embeddings[0]
patch_embeddings = ravelled_patch_embeddings.reshape(16, 16, 768)
```
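
The same pattern applies to `group` level embeddings, with the band group
dimension restored as well (a sketch, assuming the 6 band groups described
above):

```{code}
# This assumes embeddings_level was set to "group"
ravelled_group_embeddings = geodataframe.embeddings[0]
group_embeddings = ravelled_group_embeddings.reshape(6, 16, 16, 768)
```
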
49 changes: 41 additions & 8 deletions src/model_clay.py
@@ -1,5 +1,6 @@
import os
import re
+from typing import Literal

import geopandas as gpd
import lightning as L
@@ -796,6 +797,7 @@ def __init__( # noqa: PLR0913
        wd=0.05,
        b1=0.9,
        b2=0.95,
+       embeddings_level: Literal["mean", "patch", "group"] = "mean",
    ):
        super().__init__()
        self.save_hyperparameters(logger=True)
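
For reference, a hypothetical way to pick the level when constructing the
module directly; the CLI flag `--model.embeddings_level=...` sets this same
parameter:

```python
# Hypothetical direct usage, assuming CLAYModule is imported from
# src/model_clay.py; all other constructor arguments keep their defaults.
model = CLAYModule(embeddings_level="patch")
```
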
@@ -887,13 +889,44 @@ def predict_step(
        )
        assert not torch.isnan(embeddings_raw).any()  # ensure no NaNs in embedding

-       # Take the mean of the embeddings along the sequence_length dimension
-       # excluding the last two latlon_ and time_ embeddings, i.e. compute
-       # mean over patch embeddings only
-       embeddings_mean: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1)
-       assert embeddings_mean.shape == torch.Size(
-           [self.model.encoder.B, 768]  # (batch_size, hidden_size)
-       )
+       if self.hparams.embeddings_level == "mean":
+           # Take the mean of the embeddings along the sequence_length dimension
+           # excluding the last two latlon_ and time_ embeddings, i.e. compute
+           # mean over patch embeddings only
+           embeddings_output: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1)
+           expected_size = [self.model.encoder.B, 768]  # (batch_size, hidden_size)
+       elif self.hparams.embeddings_level in ["patch", "group"]:
+           # Reshape the embeddings, excluding the last two latlon_ and time_
+           # embeddings, into separate band group and spatial patch dimensions
+           embeddings_output = rearrange(
+               embeddings_raw[:, :-2, :], "b (g h w) d -> b g h w d", w=16, h=16, g=6
+           )
+           if self.hparams.embeddings_level == "patch":
+               # Average over the band group dimension, leaving one
+               # embedding per patch
+               embeddings_output = reduce(
+                   embeddings_output, "b g h w d -> b h w d", "mean"
+               )
+               expected_size = [
+                   self.model.encoder.B,
+                   16,
+                   16,
+                   768,
+               ]
+           else:
+               expected_size = [
+                   self.model.encoder.B,
+                   6,
+                   16,
+                   16,
+                   768,
+               ]
+       else:
+           raise ValueError(
+               f"Value {self.hparams.embeddings_level} not allowed. "
+               "Choose one from mean, patch, or group."
+           )
+
+       assert embeddings_output.shape == torch.Size(expected_size)

        # Create table to store the embeddings with spatiotemporal metadata
        unique_epsg_codes = set(int(epsg) for epsg in epsgs)
@@ -911,7 +944,7 @@ def predict_step(
                    dtype="date32[day][pyarrow]"
                ),
                "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
-                   embeddings_mean.cpu().detach().__array__()
+                   np.ascontiguousarray(embeddings_output.cpu().detach().__array__())
                ),
            },
            geometry=shapely.box(
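
As a side note, a minimal sketch of what the einops calls above do, using
placeholder data (shapes only; the axis sizes g, h, w match the diff):

```python
import torch
from einops import rearrange, reduce

# Placeholder batch: 2 chips, 6 band groups x 16x16 patches, 768 dimensions
x = torch.zeros(2, 6 * 16 * 16, 768)

grouped = rearrange(x, "b (g h w) d -> b g h w d", g=6, h=16, w=16)
assert grouped.shape == (2, 6, 16, 16, 768)  # "group" level

patched = reduce(grouped, "b g h w d -> b h w d", "mean")
assert patched.shape == (2, 16, 16, 768)  # "patch" level
```
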
28 changes: 23 additions & 5 deletions src/tests/test_model.py
@@ -87,19 +87,27 @@ def test_model_vit_fit(datapipe):
@pytest.mark.parametrize(
    "litmodule,precision",
    [
-       (CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
-       (ViTLitModule, "bf16-mixed"),
+       (CLAYModule, "16-mixed" if torch.cuda.is_available() else "32-true"),
+       (ViTLitModule, "16-mixed"),
    ],
)
-def test_model_predict(datapipe, litmodule, precision):
+@pytest.mark.parametrize("embeddings_level", ["mean", "patch", "group"])
+def test_model_predict(datapipe, litmodule, precision, embeddings_level):
"""
Run a single prediction loop using 1 batch.
"""
# Get some random data
dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe)

# Initialize model
model: L.LightningModule = litmodule()
if litmodule == CLAYModule:
litargs = {
"embeddings_level": embeddings_level,
}
else:
litargs = {}

model: L.LightningModule = litmodule(**litargs)

# Run tests in a temporary folder
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -139,7 +147,17 @@ def test_model_predict(datapipe, litmodule, precision):
        assert geodataframe.embeddings.dtype == "object"
        assert geodataframe.geometry.dtype == gpd.array.GeometryDtype()

+       expected_shape_lookup = {
+           "mean": (768,),
+           "patch": (16 * 16 * 768,),
+           "group": (6 * 16 * 16 * 768,),
+       }
+
        for embeddings in geodataframe.embeddings:
-           assert embeddings.shape == (768,)
+           assert embeddings.shape == (
+               expected_shape_lookup[embeddings_level]
+               if litmodule == CLAYModule
+               else (768,)
+           )
            assert embeddings.dtype == "float32"
            assert not np.isnan(embeddings).any()
