
Big update #14

Merged
merged 10 commits into main from update
Sep 21, 2023

Conversation

@gabrieltseng (Collaborator) commented Sep 21, 2023

This PR introduces a number of significant updates:

  • Model updates: Presto previously considered SRTM as dynamic in time (num_timesteps SRTM tokens were passed through the model), even though it was only collected at a single timestep. A single SRTM token is now passed through the model. In addition, a bug in the decoder which shuffled some bands during reconstruction has been fixed.
  • Masking updates: The masking functions were updated to work with both dynamic-in-time and static-in-time inputs. In addition, they were updated to accept arbitrary masking ratios (previously, the masking ratio had to yield a number of masked tokens divisible by the number of timesteps and the number of band groups).
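
To make the masking update concrete, here is a minimal sketch of sampling a mask with an arbitrary ratio over a token grid that mixes dynamic-in-time band groups with a single static token (e.g. SRTM). The function name, shapes, and the "+1" static token are illustrative assumptions, not Presto's actual implementation.

import numpy as np

def sample_mask(num_timesteps: int, num_band_groups: int, mask_ratio: float, seed: int = 0) -> np.ndarray:
    # Illustrative only: dynamic tokens form a (timesteps x band-groups) grid,
    # plus one static token (e.g. SRTM). The number of masked tokens is simply
    # round(ratio * total), so the ratio no longer has to divide evenly by the
    # number of timesteps or band groups.
    rng = np.random.default_rng(seed)
    num_tokens = num_timesteps * num_band_groups + 1
    num_masked = int(round(mask_ratio * num_tokens))
    mask = np.zeros(num_tokens, dtype=bool)
    mask[rng.choice(num_tokens, size=num_masked, replace=False)] = True
    return mask

# e.g. 12 timesteps, 5 dynamic band groups, 40% masking
print(sample_mask(12, 5, 0.4).sum())  # 24 of 61 tokens masked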

The default_model.pt weights have now been updated with a model trained after these changes.

  • Evaluation updates (images): For evaluation tasks which make a single prediction for an image, we now pass the mean and standard deviation of the pixel-outputs for that image to a downstream classifier, instead of passing each individual pixel to the classifier and taking the mode of the classifier's outputs (see the sketch after this list).
  • Evaluation updates (seeds): All downstream classifiers with randomness (i.e. everything except the K-nearest-neighbours classifier and the linear regression) are run for 3 seeds to measure Presto's sensitivity to seeds.
  • Evaluation updates (EuroSat): Our EuroSat evaluation now better mirrors the approach taken by previous works; we downsample images to different resolutions and process all the pixels in the image. In addition, we use the splits described in #8 (Update EuroSat eval task to use splits from https://arxiv.org/pdf/1911.06721.pdf).
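
The per-image aggregation mentioned above can be sketched as follows; the feature layout and the choice of downstream classifier are assumptions for illustration, not the exact code in this PR.

import numpy as np
from sklearn.linear_model import LogisticRegression

def image_features(pixel_embeddings: np.ndarray) -> np.ndarray:
    # pixel_embeddings: (num_pixels, embedding_dim) Presto outputs for one image.
    # Returns a single feature vector: per-dimension mean and std over the pixels.
    return np.concatenate([pixel_embeddings.mean(axis=0), pixel_embeddings.std(axis=0)])

# toy data standing in for per-pixel embeddings of 8 training images
rng = np.random.default_rng(0)
train_embeddings = [rng.normal(size=(16, 128)) for _ in range(8)]
train_labels = np.array([0, 1, 0, 1, 0, 1, 0, 1])

X_train = np.stack([image_features(e) for e in train_embeddings])
clf = LogisticRegression(max_iter=1000).fit(X_train, train_labels)
# previously: each pixel was classified separately and the mode of the
# per-pixel predictions was taken as the image-level prediction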

The latest arXiv version of the paper reflects these changes.

  • Many more tests: We add many more tests to ensure our masking functions are doing what we expect, and to ensure tokens are being handled in the right way by the Presto encoder and decoder.
  • Other infrastructure updates: Lots of changes here: (i) we add logging to the training and evaluation code for clarity, (ii) speed up evaluation tasks by using multiple workers for the eval dataloaders, (iii) replace the processing code for TreeSat and EuroSat with a dataloader which reads the raw files (so there is no need to process the data before running the evaluation code), (iv) run CropHarvest validation less often to speed up pre-training, (v) make the data_dir configurable in train.py, (vi) dump the evaluation results in a json file locally in addition to (optionally) storing them on wandb.
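
As an illustration of point (vi), a minimal sketch of writing results to a local JSON file alongside the optional wandb logging; the function and path names are assumptions, not the code in this PR.

import json
from pathlib import Path

def save_results(results: dict, logging_dir: str, use_wandb: bool = False) -> None:
    # Always keep a local copy of the metrics, regardless of wandb usage.
    output_path = Path(logging_dir) / "results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w") as f:
        json.dump(results, f, indent=2)
    if use_wandb:
        import wandb  # imported lazily so wandb stays optional
        wandb.log(results)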

The repository before this PR is tagged v0.1.

name: str
num_outputs: int
regression: bool
multilabel: bool

def __init__(self, seed: int = DEFAULT_SEED):
    self.seed = seed
Collaborator

I don't really like that this class has a seed: int argument and its subclass EvalTaskWithAggregatedOutput has a seeds: List[int] argument, so you basically have to use isinstance(EvalTaskWithAggregatedOutput, ...) to know which field to use. I think it would be nicer to have this class use a seeds: List[int] field too and include a check that its length equals 1 (or something along these lines).
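
A minimal sketch of the suggested interface (the seeds field and the length check follow the comment; DEFAULT_SEED's value and the remaining details are assumptions):

from typing import List, Optional

DEFAULT_SEED = 42  # placeholder value for illustration

class EvalTask:
    name: str
    num_outputs: int
    regression: bool
    multilabel: bool

    def __init__(self, seeds: Optional[List[int]] = None):
        self.seeds = seeds if seeds is not None else [DEFAULT_SEED]
        # the base class expects a single seed; a subclass that aggregates
        # over multiple seeds can relax or override this check
        assert len(self.seeds) == 1, "EvalTask expects exactly one seed"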

Collaborator Author

updated!

"encoder": model.encoder.__class__,
"decoder": model.decoder.__class__,
"device": device,
"model_parameters": "random" if fully_supervised else path_to_state_dict,
Collaborator

I think I should have added "logging_dir": logging_dir, here but I forgot; the same applies in train.py. Right now output_dir gets logged, but it's basically always something like /network/scratch/<u>/<user>/presto/, so you don't know what the actual subdirectory for this run was.

Collaborator Author

added!

@@ -194,7 +244,7 @@ def load_dataset(url, shuffle_on_load):
}
Collaborator

Add "logging_dir": logging_dir,

Collaborator Author

updated

argparser.add_argument("--fully_supervised", dest="fully_supervised", action="store_true")
argparser.add_argument("--wandb", dest="wandb", action="store_true")
Collaborator

Could you add an argument argparser.add_argument("--nb_seeds", type=int, default=1) so it's actually also still possible to run the eval script with just 1 seed?
Or do we not want to support that?
Then below: seeds = list(range(0, DEFAULT_SEED * nb_seeds, DEFAULT_SEED)) if nb_seeds > 1 else [DEFAULT_SEED]

Collaborator Author

added an eval_seeds argument
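
(For context, a hypothetical sketch of how such a flag could be wired up, following the suggestion above; the flag name matches the reply, but the rest is assumed rather than taken from the actual diff.)

import argparse

DEFAULT_SEED = 42  # placeholder value for illustration

argparser = argparse.ArgumentParser()
argparser.add_argument("--eval_seeds", type=int, default=1)
args = argparser.parse_args()

# one seed by default; N evenly spaced seeds when --eval_seeds N is passed
seeds = (
    list(range(0, DEFAULT_SEED * args.eval_seeds, DEFAULT_SEED))
    if args.eval_seeds > 1
    else [DEFAULT_SEED]
)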

eval.py Outdated
CropHarvestEval("Brazil", seed=0),
CropHarvestEval("Kenya", seed=84),
CropHarvestEval("Togo", seed=84),
CropHarvestEval("Brazil", seed=84),
Collaborator

Then here

eval_task_list: List[EvalTask] = [
    *[CropHarvestEval("Kenya", seed=seed) for seed in seeds],
    *[CropHarvestEval("Kenya", ignore_dynamic_world=True, seed=seed) for seed in seeds],
    *[CropHarvestEval("Brazil", seed=seed) for seed in seeds],
    *[CropHarvestEval("Brazil", ignore_dynamic_world=True, seed=seed) for seed in seeds],
    *[CropHarvestEval("Togo", seed=seed) for seed in seeds],
    *[CropHarvestEval("Togo", ignore_dynamic_world=True, seed=seed) for seed in seeds],
    *[FuelMoistureEval(seed=seed) for seed in seeds],
    *[AlgaeBloomsEval(seed=seed) for seed in seeds],
    *[TreeSatEval("S1", input_patch_size=ps, seeds=seeds) for ps in [1, 2, 3, 6]],
    *[TreeSatEval("S2", input_patch_size=ps, seeds=seeds) for ps in [1, 2, 3, 6]],
    ... # EuroSat
]

Collaborator Author

nice this is great. updated.

@rubencart (Collaborator) left a comment

Great, looks good to me!

@gabrieltseng merged commit e884f59 into main Sep 21, 2023
1 check passed
@gabrieltseng deleted the update branch September 21, 2023 22:54