nasaharvest/presto

The Pretrained Remote Sensing Transformer (Presto)

This code accompanies our paper, Lightweight, Pre-trained Transformers for Remote Sensing Timeseries.

Environment Setup

python -m venv venv
source venv/bin/activate
pip install -e .

wandb can additionally be installed for full functionality of the train.py script.

Entrypoints

Three entrypoints to the code are available: train.py, eval.py and mosaiks.py.

In addition, a jupyter notebook is available demonstrating how Presto can be finetuned on different downstream tasks.

Finally, Presto can also be loaded directly from the Python package. We have also included Presto in a single file (i.e. with no imports from elsewhere in the package) at single_file_presto.py, in case you want to integrate it into a different application. We test that the two models are equivalent. Either one can be used as follows:

# either import works. The single_file_presto has no load_pretrained function, since this
# requires knowing where the pretrained file is. The state dict can be loaded directly
# from data/default_models.pt
from single_file_presto import Presto
from presto import Presto

# to make a randomly initialized encoder-decoder model
encoder_decoder = Presto.construct()
# alternatively, the pre-trained model can also be loaded
encoder_decoder = Presto.load_pretrained()

# to isolate the encoder
encoder_only = encoder_decoder.encoder
# to add a linear transformation to the encoder's output for finetuning
finetuning_model = encoder_decoder.construct_finetuning_model(num_outputs=1, regression=True)

The default arguments to construct are the same as the default parameters described in default.json.

Presto expects the following values as input, and returns the following outputs:

reconstructed_x, reconstructed_dynamic_world = encoder_decoder(x, dynamic_world, latlons, mask, month)

globally_pooled_tokens = encoder_only(x, dynamic_world, latlons, mask, month, eval_task=True)

predictions = finetuning_model(x, dynamic_world, latlons, mask, month)
  • x: torch.Tensor of shape [batch_size, num_timesteps, bands] where bands is described by NORMED_BANDS.
  • dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.
  • latlons: torch.Tensor of shape [batch_size, 2] describing the latitude and longitude of each input instance.
  • mask: An optional torch.Tensor of shape [batch_size, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.
  • month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.

The number of timesteps is flexible, and can be any value between 1 and 24 (2 years of monthly data).
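To make the relationship between the starting month and the timesteps concrete, here is a minimal sketch that derives a month index for every timestep. It assumes one timestep per month and 0-indexed months (0 = January); both the helper name timestep_months and the 0-indexing are assumptions for illustration, not part of the Presto API.

```python
def timestep_months(start_month: int, num_timesteps: int) -> list:
    """Return a month index (0-11) for each monthly timestep.

    Hypothetical helper: assumes 0-indexed months and one timestep per month.
    """
    if not 1 <= num_timesteps <= 24:
        raise ValueError("num_timesteps must be between 1 and 24")
    if not 0 <= start_month <= 11:
        raise ValueError("start_month must be between 0 and 11 (assumed 0-indexed)")
    # Wrap around the end of the year so a series can span two calendar years.
    return [(start_month + t) % 12 for t in range(num_timesteps)]


months = timestep_months(start_month=10, num_timesteps=4)
print(months)  # [10, 11, 0, 1]
```

A series starting in November therefore wraps into January of the following year, which is why only the first month needs to be passed.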

Three of the input tensors (x, mask, dynamic_world) can be generated using presto.construct_single_presto_input. An example of this is in the downstream task jupyter notebook. For example, RGB imagery can be turned into Presto-compatible inputs:

import presto
x, mask, dynamic_world = presto.construct_single_presto_input(
    s2=rgb_imagery,  # of shape [num_timesteps, 3]
    s2_bands=["B2", "B3", "B4"]
)

Here, x will contain only the (normalized) RGB values in the correct indices, and mask will communicate to Presto to ignore every other input. Similarly, dynamic_world will contain only DynamicWorld2020_2021.class_amount, so Presto will ignore it.
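The masking convention described above can be sketched with plain Python lists. This is an illustration of the semantics only, not the implementation of construct_single_presto_input: EXAMPLE_BANDS is a made-up band order (the real order is defined by NORMED_BANDS), and build_mask and missing_dynamic_world are hypothetical helpers.

```python
# Illustrative band order only; the real order is defined by NORMED_BANDS in the package.
EXAMPLE_BANDS = ["VV", "VH", "B2", "B3", "B4", "B8"]
# Filler value for missing Dynamic World classes
# (DynamicWorld2020_2021.class_amount, i.e. 9, per the text above).
DW_MISSING = 9


def build_mask(num_timesteps, available_bands, all_bands=EXAMPLE_BANDS):
    """Return a [num_timesteps, bands] mask where 1 means "ignore this value"."""
    unknown = set(available_bands) - set(all_bands)
    if unknown:
        raise ValueError(f"unknown bands: {sorted(unknown)}")
    row = [0 if band in available_bands else 1 for band in all_bands]
    # Copy the row per timestep so rows can be edited independently later.
    return [list(row) for _ in range(num_timesteps)]


def missing_dynamic_world(num_timesteps):
    """Return a [num_timesteps] list filled with the "no class available" value."""
    return [DW_MISSING] * num_timesteps


mask = build_mask(num_timesteps=2, available_bands=["B2", "B3", "B4"])
dynamic_world = missing_dynamic_world(num_timesteps=2)
print(mask[0])        # [1, 1, 0, 0, 0, 1]
print(dynamic_world)  # [9, 9]
```

In practice these would be torch.Tensor objects with a leading batch dimension; the lists here only show which positions are masked.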

Training

The train.py script contains code for self-supervised training. This can be run locally on a small subset of the data with:

# Barebones local run
python train.py \
    --train_url "data/dw_144_mini_shard_44.tar" \
    --val_url "data/dw_144_mini_shard_44.tar" \
    --val_per_n_steps 1 \
    --cropharvest_per_n_validations 0 \
    --skip_finetuning

Evaluation

A trained model (or a randomly initialized model) can be run against the evaluation tasks using eval.py. If both --id and --epoch are passed to the script, the model is loaded from models/{id}/{epoch}.pt. Otherwise, a randomly initialized model is evaluated.
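The checkpoint lookup described above can be sketched as follows. The helper checkpoint_path is hypothetical and only mirrors the models/{id}/{epoch}.pt pattern from the text. Treating a missing id or epoch as "use a randomly initialized model" is an assumption about how eval.py behaves when only one of the two is given.

```python
from pathlib import Path
from typing import Optional


def checkpoint_path(
    model_id: Optional[str], epoch: Optional[int], root: str = "models"
) -> Optional[Path]:
    """Return models/{id}/{epoch}.pt, or None if no checkpoint was requested.

    Hypothetical helper: None stands for "evaluate a randomly initialized model".
    """
    if model_id is None or epoch is None:
        return None
    return Path(root) / str(model_id) / f"{epoch}.pt"


print(checkpoint_path("my_run", 5))  # models/my_run/5.pt on POSIX systems
print(checkpoint_path(None, None))   # None
```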

Mosaiks

The MOSAIKS1D benchmark can be run against evaluation tasks using the mosaiks.py script.

Generating new data

Prerequisites:

  • Account with Google Cloud access and Earth Engine access
    curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-387.0.0-linux-x86_64.tar.gz
    tar -xf google-cloud-cli-387.0.0-linux-x86_64.tar.gz
    ./google-cloud-sdk/install.sh
    exec bash  # restart the shell so the updated PATH is picked up
    gcloud init
    earthengine authenticate
  • Create buckets for processing
    gcloud storage buckets create $(python -c "from dataops import EE_BUCKET; print(EE_BUCKET)") --location=us-central1
    gcloud storage buckets create $(python -c "from dataops import NPY_BUCKET; print(NPY_BUCKET)") --location=us-central1
    gcloud storage buckets create $(python -c "from dataops import TAR_BUCKET; print(TAR_BUCKET)") --location=us-central1
  • Deploy tif-to-np Cloud Function
    sh scripts/deploy_tif_to_np.sh

Once prerequisites are satisfied, data can be generated by running:

python scripts/generate_data.py

⚠️ This script assumes you have a Google Cloud project named presto - you need to change this in the script if the name of the project is different. ⚠️

The script will generate:

  • data/tile_processing.txt A summary of tiles being processed
  • data/tile_stats.yaml Stats for all tiles available for training

Behind the scenes for each tile the script will:

  1. Begin Earth Engine exports to get the tile's data from each data pipeline. Examples:
    • gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_1>/*.tif
    • gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_2>/*.tif
    • gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_3>/*.tif
  2. Process each tif file to npy. Examples:
    • gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_1>/*.npy
    • gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_2>/*.npy
    • gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_3>/*.npy
  3. Combine all npy files into a tar file accessible through webdataset. Example:
    • gs://<TAR_BUCKET>/<DATASET_NAME>/<SHARD_1>.tar
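The three stages above follow a regular naming pattern, which can be sketched as a small helper. The function stage_uris and the bucket, dataset, shard and pipeline names used in the example are placeholders for illustration; the real values come from dataops and from the script itself.

```python
def stage_uris(ee_bucket, npy_bucket, tar_bucket, dataset_name, shard, pipelines):
    """Build the gs:// locations used by each processing stage for one shard.

    Hypothetical helper that only mirrors the path patterns listed above.
    """
    return {
        # Stage 1: Earth Engine exports, one folder of tif files per pipeline.
        "tif": [f"gs://{ee_bucket}/{shard}/{p}/*.tif" for p in pipelines],
        # Stage 2: each tif converted to npy, keeping the same folder layout.
        "npy": [f"gs://{npy_bucket}/{shard}/{p}/*.npy" for p in pipelines],
        # Stage 3: all npy files for the shard combined into a single tar.
        "tar": f"gs://{tar_bucket}/{dataset_name}/{shard}.tar",
    }


uris = stage_uris(
    ee_bucket="example-ee",        # placeholder bucket names
    npy_bucket="example-npy",
    tar_bucket="example-tar",
    dataset_name="example_dataset",
    shard="shard_1",
    pipelines=["pipeline_1", "pipeline_2"],
)
print(uris["tif"][0])  # gs://example-ee/shard_1/pipeline_1/*.tif
print(uris["tar"])     # gs://example-tar/example_dataset/shard_1.tar
```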

Accessing new data

In [0]:
import webdataset as wds
import pandas as pd
uri = "gs://lem-assets2/S1_S2_ERA5_SRTM_2020_2021_DynamicWorld2020_2021_tars/dw_144_shard_0.tar"
dataset = wds.WebDataset(f"pipe:gcloud storage cat {uri}").decode()
for sample in dataset:
    break

In [1]: list(sample.keys())
Out[1]: ['__key__', '__url__', 'dynamicworld2020_2021.npy', 'latlon.npy', 's1_s2_era5_srtm_2020_2021.npy', 'worldcover2020.npy']

In [2]: sample["s1_s2_era5_srtm_2020_2021.npy"].shape
Out[2]: (625, 24, 18)

In [3]: sample["latlon.npy"].shape
Out[3]: (625, 2)

In [4]: sample["worldcover2020.npy"].shape
Out[4]: (625, 1)

In [5]: sample["dynamicworld2020_2021.npy"].shape
Out[5]: (625, 24)

In [6]: pd.Series(sample["dynamicworld2020_2021.npy"].flatten()).value_counts()
Out[6]:
0    14978
7       22
dtype: int64
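As a sanity check on the shapes above, the class counts should sum to the number of pixels times the number of timesteps (625 × 24 = 15000). The sketch below reproduces the counts on synthetic data, using collections.Counter in place of pandas. The data is made up to mirror the output shown, and is not read from the shard.

```python
from collections import Counter

num_pixels, num_timesteps = 625, 24

# Synthetic stand-in for sample["dynamicworld2020_2021.npy"]: all class 0 ...
dynamic_world = [[0] * num_timesteps for _ in range(num_pixels)]
# ... except 22 entries set to class 7, mirroring the counts shown above.
for pixel in range(22):
    dynamic_world[pixel][0] = 7

# Equivalent of .flatten() followed by value_counts().
flat = [value for row in dynamic_world for value in row]
counts = Counter(flat)

print(len(flat))             # 15000
print(counts.most_common())  # [(0, 14978), (7, 22)]
```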

Reference

If you find this code useful, please cite the following paper:

@misc{tseng2023lightweight,
      title={Lightweight, Pre-trained Transformers for Remote Sensing Timeseries},
      author={Gabriel Tseng and Ruben Cartuyvels and Ivan Zvonkov and Mirali Purohit and David Rolnick and Hannah Kerner},
      year={2023},
      eprint={2304.14065},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}