diff --git a/docs/patch_level_cloud_cover.ipynb b/docs/patch_level_cloud_cover.ipynb
index 5c229958..89b98404 100644
--- a/docs/patch_level_cloud_cover.ipynb
+++ b/docs/patch_level_cloud_cover.ipynb
@@ -23,29 +23,27 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import glob\n",
+    "from pathlib import Path\n",
+    "\n",
     "import geopandas as gpd\n",
-    "import pystac_client\n",
-    "import shapely\n",
-    "import stackstac\n",
-    "import torch\n",
+    "import lancedb\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy\n",
     "import pandas as pd\n",
-    "import xarray as xr\n",
+    "import pystac_client\n",
     "import rasterio\n",
     "import rioxarray  # noqa: F401\n",
-    "import pyarrow as pa\n",
-    "import pickle\n",
-    "import lancedb\n",
-    "import glob\n",
-    "from pathlib import Path\n",
-    "from shapely.geometry import Point, Polygon, box\n",
-    "from rasterio.enums import Resampling\n",
+    "import shapely\n",
+    "import stackstac\n",
+    "import torch\n",
     "from rasterio.enums import Resampling\n",
+    "from shapely.geometry import Polygon, box\n",
+    "\n",
     "from src.datamodule import ClayDataModule\n",
     "from src.model_clay import CLAYModule\n",
     "\n",
-    "pd.set_option('display.max_colwidth', None)\n",
+    "pd.set_option(\"display.max_colwidth\", None)\n",
     "\n",
     "BAND_GROUPS_L2A = {\n",
     "    \"rgb\": [\"red\", \"green\", \"blue\"],\n",
@@ -89,10 +87,10 @@
    "outputs": [],
    "source": [
     "# sample cluster\n",
-    "bbox_bl = (177.4199,-17.8579)\n",
-    "bbox_tl = (177.4156,-17.6812)\n",
-    "bbox_br = (177.5657,-17.8572)\n",
-    "bbox_tr = (177.5657,-17.6812)"
+    "bbox_bl = (177.4199, -17.8579)\n",
+    "bbox_tl = (177.4156, -17.6812)\n",
+    "bbox_br = (177.5657, -17.8572)\n",
+    "bbox_tr = (177.5657, -17.6812)"
    ]
   },
   {
@@ -112,7 +110,8 @@
    "source": [
     "# Define area of interest\n",
     "area_of_interest = shapely.box(\n",
-    "    xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1])\n",
+    "    xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1]\n",
+    ")\n",
     "\n",
     "# Define temporal range\n",
     "daterange: dict = [\"2021-01-01T00:00:00Z\", \"2021-12-31T23:59:59Z\"]"
    ]
   },
   {
@@ -160,8 +159,9 @@
     "epsg = items_L2A[0].properties[\"proj:epsg\"]\n",
     "\n",
     "# Convert point from lon/lat to UTM projection\n",
-    "poidf = gpd.GeoDataFrame(crs=\"OGC:CRS84\", \n",
-    "                         geometry=[area_of_interest.centroid]).to_crs(epsg)\n",
+    "poidf = gpd.GeoDataFrame(crs=\"OGC:CRS84\", geometry=[area_of_interest.centroid]).to_crs(\n",
+    "    epsg\n",
+    ")\n",
     "geom = poidf.iloc[0].geometry\n",
     "\n",
     "# Create bounds of the correct size, the model\n",
@@ -182,7 +182,7 @@
     "    fill_value=0,\n",
     "    assets=BAND_GROUPS_L2A[\"rgb\"] + BAND_GROUPS_L2A[\"scl\"],\n",
     "    resampling=Resampling.nearest,\n",
-    "    xy_coords='center',\n",
+    "    xy_coords=\"center\",\n",
     ")\n",
     "\n",
     "stack_L2A = stack_L2A.compute()\n",
@@ -217,13 +217,13 @@
     "    # Write tile to output dir, whilst dropping the SCL band in the process\n",
     "    for tile in stack_L2A.sel(band=[\"red\", \"green\", \"blue\"]):\n",
     "        date = str(tile.time.values)[:10]\n",
-    "        \n",
+    "\n",
     "        name = \"{dir}/claytile_{date}.tif\".format(\n",
     "            dir=outdir,\n",
     "            date=date.replace(\"-\", \"\"),\n",
     "        )\n",
     "        tile.rio.to_raster(name, compress=\"deflate\")\n",
-    "        \n",
+    "\n",
     "        with rasterio.open(name, \"r+\") as rst:\n",
     "            rst.update_tags(date=date)"
    ]
   },
   {
@@ -246,6 +246,7 @@
    "source": [
     "# Function to count cloud pixels in a subset\n",
     "\n",
+    "\n",
     "def count_cloud_pixels(subset_scl, cloud_labels):\n",
     "    cloud_pixels = 0\n",
     "    for label in cloud_labels:\n",
@@ -261,7 +262,7 @@
    "outputs": [],
    "source": [
     "# Define the chunk size for tiling\n",
-    "chunk_size = {'x': 32, 'y': 32} # Adjust the chunk size as needed\n",
+    "chunk_size = {\"x\": 32, \"y\": 32}  # Adjust the chunk size as needed\n",
     "\n",
     "# Tile the data\n",
     "ds_chunked_L2A = stack_L2A.chunk(chunk_size)\n",
@@ -279,47 +280,48 @@
     "cloud_pcts = {}\n",
     "\n",
     "# Get the geospatial transform and CRS\n",
-    "transform = ds_chunked_L2A.attrs['transform']\n",
-    "crs = ds_chunked_L2A.attrs['crs']\n",
+    "transform = ds_chunked_L2A.attrs[\"transform\"]\n",
+    "crs = ds_chunked_L2A.attrs[\"crs\"]\n",
     "\n",
-    "for x in range((ds_chunked_L2A.sizes['x'] // chunk_size['x'])): # + 1):\n",
-    "    for y in range((ds_chunked_L2A.sizes['y'] // chunk_size['y'])): # + 1):\n",
+    "for x in range(ds_chunked_L2A.sizes[\"x\"] // chunk_size[\"x\"]):  # + 1):\n",
+    "    for y in range(ds_chunked_L2A.sizes[\"y\"] // chunk_size[\"y\"]):  # + 1):\n",
     "        # Compute chunk coordinates\n",
-    "        x_start = x * chunk_size['x']\n",
-    "        y_start = y * chunk_size['y']\n",
-    "        x_end = min(x_start + chunk_size['x'], ds_chunked_L2A.sizes['x'])\n",
-    "        y_end = min(y_start + chunk_size['y'], ds_chunked_L2A.sizes['y'])\n",
-    "        \n",
+    "        x_start = x * chunk_size[\"x\"]\n",
+    "        y_start = y * chunk_size[\"y\"]\n",
+    "        x_end = min(x_start + chunk_size[\"x\"], ds_chunked_L2A.sizes[\"x\"])\n",
+    "        y_end = min(y_start + chunk_size[\"y\"], ds_chunked_L2A.sizes[\"y\"])\n",
+    "\n",
     "        # Compute chunk geospatial bounds\n",
     "        lon_start, lat_start = transform * (x_start, y_start)\n",
     "        lon_end, lat_end = transform * (x_end, y_end)\n",
-    "        #print(lon_start, lat_start, lon_end, lat_end, x, y)\n",
+    "        # print(lon_start, lat_start, lon_end, lat_end, x, y)\n",
     "\n",
     "        # Store chunk bounds\n",
     "        chunk_bounds[(x, y)] = {\n",
-    "            'lon_start': lon_start, 'lat_start': lat_start,\n",
-    "            'lon_end': lon_end, 'lat_end': lat_end\n",
+    "            \"lon_start\": lon_start,\n",
+    "            \"lat_start\": lat_start,\n",
+    "            \"lon_end\": lon_end,\n",
+    "            \"lat_end\": lat_end,\n",
     "        }\n",
     "\n",
     "        # Extract the subset of the SCL band\n",
     "        subset_scl = ds_chunked_L2A.sel(band=\"scl\")[:, y_start:y_end, x_start:x_end]\n",
-    "        \n",
+    "\n",
     "        # Count the cloud pixels in the subset\n",
     "        cloud_pct = count_cloud_pixels(subset_scl, SCL_CLOUD_LABELS)\n",
-    "        \n",
+    "\n",
     "        # Store the cloud percent for this chunk\n",
     "        cloud_pcts[(x, y)] = int(100 * (cloud_pct / 1024))\n",
     "\n",
     "\n",
     "# Print chunk bounds\n",
-    "#for key, value in chunk_bounds.items():\n",
-    "    #print(f\"Chunk {key}: {value}\")\n",
+    "# for key, value in chunk_bounds.items():\n",
+    "#     print(f\"Chunk {key}: {value}\")\n",
     "\n",
     "# Print indices where cloud percentages exceed some interesting threshold\n",
     "for key, value in cloud_pcts.items():\n",
     "    if value > 50:\n",
-    "        print(f\"Chunk {key}: Cloud percentage = {value}\")\n",
-    "\n"
+    "        print(f\"Chunk {key}: Cloud percentage = {value}\")"
    ]
   },
   {
@@ -330,8 +332,10 @@
    "outputs": [],
    "source": [
     "DATA_DIR = \"data/minicubes_cloud\"\n",
-    "CKPT_PATH = (\"https://huggingface.co/made-with-clay/Clay/resolve/main/\"\n",
-    "             \"Clay_v0.1_epoch-24_val-loss-0.46.ckpt\")\n",
+    "CKPT_PATH = (\n",
+    "    \"https://huggingface.co/made-with-clay/Clay/resolve/main/\"\n",
+    "    \"Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n",
+    ")\n",
     "# Load model\n",
     "multi_model = CLAYModule.load_from_checkpoint(\n",
     "    CKPT_PATH,\n",
@@ -401,9 +405,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "print(len(embeddings[0])) # embeddings is a list\n",
-    "print(embeddings[0].shape) # with date and lat/lon\n",
-    "print(embeddings[0][:, :-2, :].shape) # remove date and lat/lon"
+    "print(len(embeddings[0]))  # embeddings is a list\n",
+    "print(embeddings[0].shape)  # with date and lat/lon\n",
+    "print(embeddings[0][:, :-2, :].shape)  # remove date and lat/lon"
    ]
   },
   {
@@ -414,7 +418,7 @@
    "outputs": [],
    "source": [
     "# remove date and lat/lon and reshape to disaggregated patches\n",
-    "embeddings_patch = embeddings[0][:, :-2, :].reshape([1,16,16,768]) "
+    "embeddings_patch = embeddings[0][:, :-2, :].reshape([1, 16, 16, 768])"
    ]
   },
   {
@@ -483,36 +487,43 @@
     "        embeddings_output_patch = embeddings_patch_avg_group[i, j]\n",
     "\n",
     "        item_ = [\n",
-    "            element for element in list(chunk_bounds.items()) if element[0] == (i,j)\n",
+    "            element for element in list(chunk_bounds.items()) if element[0] == (i, j)\n",
+    "        ]\n",
+    "        box_ = [\n",
+    "            item_[0][1][\"lon_start\"],\n",
+    "            item_[0][1][\"lat_start\"],\n",
+    "            item_[0][1][\"lon_end\"],\n",
+    "            item_[0][1][\"lat_end\"],\n",
     "        ]\n",
-    "        box_ = [item_[0][1]['lon_start'], item_[0][1]['lat_start'],\n",
-    "                item_[0][1]['lon_end'], item_[0][1]['lat_end']]\n",
     "        cloud_pct_ = [\n",
-    "            element for element in list(cloud_pcts.items()) if element[0] == (i,j)\n",
+    "            element for element in list(cloud_pcts.items()) if element[0] == (i, j)\n",
    "        ]\n",
     "        source_url = batch[\"source_url\"]\n",
     "        date = batch[\"date\"]\n",
     "        data = {\n",
     "            \"source_url\": batch[\"source_url\"][0],\n",
-    "            \"date\": pd.to_datetime(arg=date, \n",
-    "                                   format=\"%Y-%m-%d\").astype(dtype=\"date32[day][pyarrow]\"),\n",
+    "            \"date\": pd.to_datetime(arg=date, format=\"%Y-%m-%d\").astype(\n",
+    "                dtype=\"date32[day][pyarrow]\"\n",
+    "            ),\n",
     "            \"embeddings\": [numpy.ascontiguousarray(embeddings_output_patch)],\n",
-    "            \"cloud_cover\": cloud_pct_[0][1]\n",
+    "            \"cloud_cover\": cloud_pct_[0][1],\n",
     "        }\n",
-    "        \n",
+    "\n",
     "        # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)\n",
-    "        # The box_ list is encoded as \n",
+    "        # The box_ list is encoded as\n",
     "        # [bottom left x, bottom left y, top right x, top right y]\n",
     "        box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
-    "        \n",
+    "\n",
     "        # Create the GeoDataFrame\n",
     "        gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{epsg}\")\n",
-    "        \n",
+    "\n",
     "        # Reproject to WGS84 (lon/lat coordinates)\n",
     "        gdf = gdf.to_crs(epsg=4326)\n",
     "\n",
-    "        outpath = (f\"{outdir_embeddings}/\"\n",
-    "                   f\"{batch['source_url'][0].split('/')[-1][:-4]}_{i}_{j}.gpq\")\n",
+    "        outpath = (\n",
+    "            f\"{outdir_embeddings}/\"\n",
+    "            f\"{batch['source_url'][0].split('/')[-1][:-4]}_{i}_{j}.gpq\"\n",
+    "        )\n",
     "        gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
     "        print(\n",
     "            f\"Saved {len(gdf)} rows of embeddings of \"\n",
@@ -553,23 +564,26 @@
     "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
     "    gdf = gpd.read_parquet(emb)\n",
     "    gdf[\"year\"] = gdf.date.dt.year\n",
-    "    gdf[\"tile\"] = gdf[\"source_url\"].apply(lambda x: Path(x)\n",
-    "                                          .stem.rsplit(\"/\")[-1].rsplit(\"_\")[0])\n",
-    "    gdf[\"idx\"] = '_'.join(emb.split(\"/\")[-1].split(\"_\")[2:]).replace('.gpq', '')\n",
+    "    gdf[\"tile\"] = gdf[\"source_url\"].apply(\n",
+    "        lambda x: Path(x).stem.rsplit(\"/\")[-1].rsplit(\"_\")[0]\n",
+    "    )\n",
+    "    gdf[\"idx\"] = \"_\".join(emb.split(\"/\")[-1].split(\"_\")[2:]).replace(\".gpq\", \"\")\n",
     "    gdf[\"box\"] = [box(*geom.bounds) for geom in gdf.geometry]\n",
     "    gdfs.append(gdf)\n",
-    "    \n",
-    "    for _,row in gdf.iterrows():\n",
-    "        data.append({\n",
-    "            \"vector\": row[\"embeddings\"],\n",
-    "            \"path\": row[\"source_url\"],\n",
-    "            \"tile\": row[\"tile\"],\n",
-    "            \"date\": row[\"date\"],\n",
-    "            \"year\": int(row[\"year\"]),\n",
-    "            \"cloud_cover\": row[\"cloud_cover\"],\n",
-    "            \"idx\": row[\"idx\"],\n",
-    "            \"box\": row[\"box\"].bounds,\n",
-    "        })"
+    "\n",
+    "    for _, row in gdf.iterrows():\n",
+    "        data.append(\n",
+    "            {\n",
+    "                \"vector\": row[\"embeddings\"],\n",
+    "                \"path\": row[\"source_url\"],\n",
+    "                \"tile\": row[\"tile\"],\n",
+    "                \"date\": row[\"date\"],\n",
+    "                \"year\": int(row[\"year\"]),\n",
+    "                \"cloud_cover\": row[\"cloud_cover\"],\n",
+    "                \"idx\": row[\"idx\"],\n",
+    "                \"box\": row[\"box\"].bounds,\n",
+    "            }\n",
+    "        )"
    ]
   },
   {
@@ -616,8 +630,9 @@
     "epsg = items_L2A[0].properties[\"proj:epsg\"]\n",
     "\n",
     "# Convert point from lon/lat to UTM projection\n",
-    "box_embedding = gpd.GeoDataFrame(crs=\"OGC:CRS84\", \n",
-    "                                 geometry=[area_of_interest_embedding]).to_crs(epsg)\n",
+    "box_embedding = gpd.GeoDataFrame(\n",
+    "    crs=\"OGC:CRS84\", geometry=[area_of_interest_embedding]\n",
+    ").to_crs(epsg)\n",
     "geom_embedding = box_embedding.iloc[0].geometry\n",
     "\n",
     "# Create bounds of the correct size, the model\n",
@@ -637,7 +652,7 @@
     "    fill_value=0,\n",
     "    assets=BAND_GROUPS_L2A[\"rgb\"],\n",
     "    resampling=Resampling.nearest,\n",
-    "    xy_coords='center',\n",
+    "    xy_coords=\"center\",\n",
     ")\n",
     "\n",
     "stack_embedding = stack_embedding.compute()\n",
@@ -694,7 +709,7 @@
    "source": [
     "# Function to get the average of some list of reference vectors\n",
     "def get_average_vector(idxs):\n",
-    "    reformatted_idxs = ['_'.join(map(str, idx)) for idx in idxs]\n",
+    "    reformatted_idxs = [\"_\".join(map(str, idx)) for idx in idxs]\n",
     "    matching_rows = [\n",
     "        tbl.to_pandas().query(f\"idx == '{idx}'\") for idx in reformatted_idxs\n",
     "    ]\n",
@@ -718,7 +733,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Get indices for patches where cloud percentages \n",
+    "# Get indices for patches where cloud percentages\n",
     "# exceed some interesting threshold\n",
     "cloudy_indices = []\n",
     "for key, value in cloud_pcts.items():\n",
@@ -764,12 +779,12 @@
    },
    "outputs": [],
    "source": [
-    "# Get indices for patches where cloud percentages \n",
+    "# Get indices for patches where cloud percentages\n",
     "# do not exceed some interesting threshold\n",
     "non_cloudy_indices = []\n",
     "for key, value in cloud_pcts.items():\n",
     "    if value < 10:\n",
-    "        #print(f\"Chunk {key}: Cloud percentage = {value}\")\n",
+    "        # print(f\"Chunk {key}: Cloud percentage = {value}\")\n",
     "        non_cloudy_indices.append(key)"
    ]
   },
   {
@@ -822,38 +837,38 @@
     "    # Define the window size\n",
     "    window_size = (32, 32)\n",
     "\n",
-    "    idxs_windows = {'idx': [], 'window': []}\n",
+    "    idxs_windows = {\"idx\": [], \"window\": []}\n",
     "\n",
     "    # Iterate over the image in 32x32 windows\n",
     "    for col in range(0, width, window_size[0]):\n",
     "        for row in range(0, height, window_size[1]):\n",
     "            # Define the window\n",
     "            window = ((row, row + window_size[1]), (col, col + window_size[0]))\n",
-    "            \n",
+    "\n",
     "            # Read the data within the window\n",
     "            data = chip.read(window=window)\n",
-    "            \n",
+    "\n",
     "            # Get the index of the window\n",
     "            index = (col // window_size[0], row // window_size[1])\n",
-    "            \n",
+    "\n",
     "            # Process the window data here\n",
     "            # For example, print the index and the shape of the window data\n",
-    "            #print(\"Index:\", index)\n",
-    "            #print(\"Window Shape:\", data.shape)\n",
-    "\n",
-    "            idxs_windows['idx'].append('_'.join(map(str, index)))\n",
-    "            idxs_windows['window'].append(data)\n",
-    "            \n",
-    "    #print(idxs_windows)\n",
-    "    \n",
+    "            # print(\"Index:\", index)\n",
+    "            # print(\"Window Shape:\", data.shape)\n",
+    "\n",
+    "            idxs_windows[\"idx\"].append(\"_\".join(map(str, index)))\n",
+    "            idxs_windows[\"window\"].append(data)\n",
+    "\n",
+    "    # print(idxs_windows)\n",
+    "\n",
     "    for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
     "        idx = row[\"idx\"]\n",
     "        # Find the corresponding window based on the idx\n",
-    "        window_index = idxs_windows['idx'].index(idx)\n",
-    "        window_data = idxs_windows['window'][window_index]\n",
-    "        #print(window_data.shape)\n",
+    "        window_index = idxs_windows[\"idx\"].index(idx)\n",
+    "        window_data = idxs_windows[\"window\"][window_index]\n",
+    "        # print(window_data.shape)\n",
     "        subset_img = numpy.clip(\n",
-    "            (window_data.transpose(1,2,0)[:, :, :3]/10_000) * 3, 0,1\n",
+    "            (window_data.transpose(1, 2, 0)[:, :, :3] / 10_000) * 3, 0, 1\n",
     "        )\n",
     "        ax.imshow(subset_img)\n",
     "        ax.set_title(f\"{tile}/{idx}/{row.cloud_cover}\")\n",
@@ -916,18 +931,23 @@
     "# Make geodataframe of the search results\n",
     "# cloudy\n",
     "result_cloudy_boxes = [\n",
-    "    Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), \n",
-    "             (bbox[2], bbox[3]), (bbox[0], bbox[3])]) \n",
-    "    for bbox in result_cloudy['box']]\n",
+    "    Polygon(\n",
+    "        [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]\n",
+    "    )\n",
+    "    for bbox in result_cloudy[\"box\"]\n",
+    "]\n",
     "result_cloudy_gdf = gpd.GeoDataFrame(result_cloudy, geometry=result_cloudy_boxes)\n",
     "result_cloudy_gdf.crs = \"EPSG:4326\"\n",
     "# non-cloudy\n",
     "result_non_cloudy_boxes = [\n",
-    "    Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), \n",
-    "             (bbox[2], bbox[3]), (bbox[0], bbox[3])]) \n",
-    "    for bbox in result_non_cloudy['box']]\n",
+    "    Polygon(\n",
+    "        [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]\n",
+    "    )\n",
+    "    for bbox in result_non_cloudy[\"box\"]\n",
+    "]\n",
     "result_non_cloudy_gdf = gpd.GeoDataFrame(\n",
-    "    result_non_cloudy, geometry=result_non_cloudy_boxes)\n",
+    "    result_non_cloudy, geometry=result_non_cloudy_boxes\n",
+    ")\n",
     "result_non_cloudy_gdf.crs = \"EPSG:4326\"\n",
     "\n",
     "# Plot the AOI in RGB\n",
@@ -935,14 +955,14 @@
     "plot.imshow(row=\"time\", rgb=\"band\", vmin=0, vmax=2000)\n",
     "\n",
     "# Overlay the bounding boxes of the patches identified from the similarity search\n",
-    "result_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color='red', alpha=0.5)\n",
-    "result_non_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color='blue', alpha=0.5)\n",
+    "result_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color=\"red\", alpha=0.5)\n",
+    "result_non_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color=\"blue\", alpha=0.5)\n",
     "\n",
     "\n",
     "# Set plot title and labels\n",
-    "plt.title('Sentinel-2 with cloudy and non-cloudy embeddings')\n",
-    "plt.xlabel('Longitude')\n",
-    "plt.ylabel('Latitude')\n",
+    "plt.title(\"Sentinel-2 with cloudy and non-cloudy embeddings\")\n",
+    "plt.xlabel(\"Longitude\")\n",
+    "plt.ylabel(\"Latitude\")\n",
     "\n",
     "# Show the plot\n",
     "plt.show()"
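For reference, the technique this notebook reformats above — per-patch cloud cover derived from the Sentinel-2 SCL band over 32x32 windows — can be sketched standalone. This is a minimal illustration, not the notebook's exact code: the notebook's SCL_CLOUD_LABELS definition sits outside this diff, so the cloud class values below (8 = cloud medium probability, 9 = cloud high probability) are assumptions.

import numpy as np

# Assumed cloud classes; the notebook defines its own SCL_CLOUD_LABELS elsewhere.
SCL_CLOUD_LABELS = [8, 9]


def patch_cloud_percentages(scl: np.ndarray, patch: int = 32) -> dict:
    """Return {(x, y): cloud percent} for each full (patch x patch) window.

    Mirrors the notebook's chunk loop: a 32x32 patch has 1024 pixels,
    so the percentage is 100 * cloudy_pixels / 1024.
    """
    pcts = {}
    for x in range(scl.shape[1] // patch):
        for y in range(scl.shape[0] // patch):
            window = scl[y * patch : (y + 1) * patch, x * patch : (x + 1) * patch]
            cloudy = int(np.isin(window, SCL_CLOUD_LABELS).sum())
            pcts[(x, y)] = int(100 * cloudy / (patch * patch))
    return pcts


# Usage with synthetic SCL data: report patches above the notebook's 50% threshold.
demo_scl = np.random.randint(0, 12, size=(512, 512))
for key, value in patch_cloud_percentages(demo_scl).items():
    if value > 50:
        print(f"Chunk {key}: Cloud percentage = {value}")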