Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 1, 2024
1 parent be3cc33 commit 49bf34a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 83 deletions.
4 changes: 2 additions & 2 deletions scripts/worldcover/embeddings_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
db.table_names()

# Drop existing table if exists
#db.drop_table("worldcover-2020-v001")
# db.drop_table("worldcover-2020-v001")

# Create embeddings table and insert the vector data
tbl = db.create_table("worldcover-2020-v001", data=data, mode="overwrite")
Expand All @@ -55,4 +55,4 @@ def plot(df, cols=10):
# Select a vector by index, search for its 5 most similar vectors, and plot
v = tbl.to_pandas()["vector"].values[5]
result = tbl.search(query=v).limit(5).to_pandas()
plot(result, 5)
plot(result, 5)
89 changes: 46 additions & 43 deletions scripts/worldcover/run.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
#!/usr/bin/env python3

import sys

sys.path.append("../../")

import os
import tempfile
from math import floor
from pathlib import Path
import requests

import boto3
import einops
import geopandas as gpd
import pandas as pd
import numpy
import pyarrow as pa
import pandas as pd
import rasterio
import requests
import shapely
import torch
import xarray as xr
from rasterio.windows import Window
from shapely import box
from torchvision.transforms import v2

from src.datamodule import ClayDataset
Expand Down Expand Up @@ -141,6 +140,7 @@ def tiles_and_windows(input: Window):

return result


def download_image(url):
# Download an image from a URL
response = requests.get(url)
Expand All @@ -150,52 +150,54 @@ def download_image(url):
else:
raise Exception("Failed to download the image")


def patch_bounds_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
    """Compute the bounds of every full patch of a raster fetched from a URL.

    The raster is downloaded with ``download_image`` and opened in memory.
    Only its dimensions, affine transform and CRS are read from the header;
    the pixel data is not needed to derive patch bounds, so it is not decoded.

    Args:
        url: Location of the raster to download.
        chunk_size: Patch size in pixels, ordered ``(y, x)``.

    Returns:
        A ``(chunk_bounds, crs)`` tuple. ``chunk_bounds`` is a dict keyed by
        the patch index ``(x, y)``; each value is a dict with the keys
        ``lon_start``, ``lat_start``, ``lon_end`` and ``lat_end`` holding the
        result of applying the raster's transform to the patch's start and
        end pixel corners. ``crs`` is the CRS reported by the opened dataset.
    """
    # Download an image from a URL
    image_data = download_image(url)

    # Open the image using rasterio from memory and keep only the metadata
    with rasterio.io.MemoryFile(image_data) as memfile:
        with memfile.open() as src:
            width = src.width
            height = src.height
            transform = src.transform
            img_crs = src.crs

    chunk_height, chunk_width = chunk_size

    # Floor division: only patches that fit entirely inside the raster are
    # reported, so a partial patch at the right/bottom edge is skipped and
    # the end coordinates never need clamping to the raster size.
    chunks_x = width // chunk_width
    chunks_y = height // chunk_height

    chunk_bounds = {}
    for x in range(chunks_x):
        for y in range(chunks_y):
            # Pixel corners of this patch
            x_start = x * chunk_width
            y_start = y * chunk_height
            x_end = x_start + chunk_width
            y_end = y_start + chunk_height

            # Map pixel corners through the raster's affine transform
            lon_start, lat_start = transform * (x_start, y_start)
            lon_end, lat_end = transform * (x_end, y_end)

            chunk_bounds[(x, y)] = {
                "lon_start": lon_start,
                "lat_start": lat_start,
                "lon_end": lon_end,
                "lat_end": lat_end,
            }

    return chunk_bounds, img_crs


def make_batch(result):
pixels = []
for url, win in result:
Expand Down Expand Up @@ -230,10 +232,10 @@ def make_batch(result):
"timestep": torch.as_tensor(data=[ds.normalize_timestamp(f"{YEAR}-06-01")]).to(
rgb_model.device
),
"date": f"{YEAR}-06-01"
,
"date": f"{YEAR}-06-01",
}


def get_pixels(result):
pixels = []
for url, win in result:
Expand Down Expand Up @@ -319,42 +321,41 @@ def get_pixels(result):
)

yoff += CHIP_SIZE



print(len(embeddings), len(results))
embeddings_ = numpy.vstack(embeddings)
#embeddings_ = embeddings[0]
# embeddings_ = embeddings[0]
print("Embeddings shape: ", embeddings_.shape)

# remove date and lat/lon
embeddings_ = embeddings_[:, :-2, :].mean(axis=0)

print(f"Embeddings have shape {embeddings_.shape}")

# reshape to disaggregated patches
embeddings_patch = embeddings_.reshape([2, 16, 16, 768])

# average over the band groups
embeddings_mean = embeddings_patch.mean(axis=0)

print(f"Average patch embeddings have shape {embeddings_mean.shape}")

print(f"Average patch embeddings have shape {embeddings_mean.shape}")

if result is not None:
print("result: ", result[0][0])
pix = get_pixels(result)
chunk_bounds, epsg = patch_bounds_from_url(result[0][0])
#print("chunk_bounds: ", chunk_bounds)
# print("chunk_bounds: ", chunk_bounds)
print("chunk bounds length:", len(chunk_bounds))

# Iterate through each patch
for i in range(embeddings_mean.shape[0]):
for j in range(embeddings_mean.shape[1]):
embeddings_output_patch = embeddings_mean[i, j]

item_ = [
element for element in list(chunk_bounds.items()) if element[0] == (i, j)
element
for element in list(chunk_bounds.items())
if element[0] == (i, j)
]
box_ = [
item_[0][1]["lon_start"],
Expand All @@ -364,42 +365,44 @@ def get_pixels(result):
]

data = {
#"source_url": batch["source_url"][0],
#"date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
# "source_url": batch["source_url"][0],
# "date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
# dtype="date32[day][pyarrow]"
#),
#"date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
# ),
# "date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
"date": pd.to_datetime(batch["date"], format="%Y-%m-%d"),
"embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
}

# Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
# The box_ list is encoded as
# [bottom left x, bottom left y, top right x, top right y]
box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])

print(str(epsg)[-4:])

# Create the GeoDataFrame
gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}")

gdf = gpd.GeoDataFrame(
data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}"
)

# Reproject to WGS84 (lon/lat coordinates)
gdf = gdf.to_crs(epsg=4326)

with tempfile.TemporaryDirectory() as tmp:
# tmp = "/home/tam/Desktop/wcctmp"

outpath = f"{tmp}/worldcover_patch_embeddings_{YEAR}_{index}_{i}_{j}_v{VERSION}.gpq"
print(f"Uploading embeddings to {outpath}")
#print(gdf)

gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")

# print(gdf)

gdf.to_parquet(
path=outpath, compression="ZSTD", schema_version="1.0.0"
)

s3_client = boto3.client("s3")
s3_client.upload_file(
outpath,
BUCKET,
f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
)


Loading

0 comments on commit 49bf34a

Please sign in to comment.