Skip to content

Commit

Permalink
Merge pull request #32 from nasaharvest/update-2
Browse files Browse the repository at this point in the history
Many updates
  • Loading branch information
gabrieltseng committed Jan 31, 2024
2 parents 632cd39 + fe97099 commit 572be39
Show file tree
Hide file tree
Showing 15 changed files with 1,500 additions and 738 deletions.
113 changes: 58 additions & 55 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from presto.eval import (
AlgaeBloomsEval,
CropHarvestEval,
CroptypeFranceEval,
EuroSatEval,
EvalTask,
FuelMoistureEval,
Expand Down Expand Up @@ -52,13 +53,6 @@
help="Output is stored in <data_dir>/output. "
"Leave empty to use the directory you are running this file from.",
)
argparser.add_argument(
"--eval_seeds",
type=int,
default=[0, DEFAULT_SEED, 48],
nargs="+",
help="seeds to use for eval tasks",
)
argparser.add_argument("--fully_supervised", dest="fully_supervised", action="store_true")
argparser.add_argument("--wandb", dest="wandb", action="store_true")
argparser.set_defaults(wandb=False)
Expand All @@ -68,7 +62,6 @@
path_to_state_dict = args["path_to_state_dict"]
path_to_config = args["path_to_config"]
fully_supervised = args["fully_supervised"]
eval_seeds = args["eval_seeds"]
wandb_enabled = args["wandb"]
data_dir = args["data_dir"]
if data_dir != "":
Expand All @@ -93,62 +86,66 @@

if path_to_config == "":
path_to_config = config_dir / "default.json"
logger.info("Loading config from %s" % path_to_config)
model_kwargs = json.load(Path(path_to_config).open("r"))
model = Presto.construct(**model_kwargs)

if not fully_supervised:
if path_to_state_dict == "":
path_to_state_dict = default_model_path
logger.info("Loading params from %s" % path_to_state_dict)
model.load_state_dict(torch.load(path_to_state_dict, map_location=device))
model.to(device)

logger.info("Loading evaluation tasks")
seeds = [0, DEFAULT_SEED, 84]
eval_task_list: List[EvalTask] = [
*[CropHarvestEval("Kenya", ignore_dynamic_world=True, seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Togo", ignore_dynamic_world=True, seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Brazil", ignore_dynamic_world=True, seeds=[s]) for s in eval_seeds],
*[FuelMoistureEval(seeds=[s]) for s in eval_seeds],
*[AlgaeBloomsEval(seeds=[s]) for s in eval_seeds],
# no seeds for EuroSat, which we evaluate using
# a KNN classifier
EuroSatEval(rgb=True, input_patch_size=32),
EuroSatEval(rgb=True, input_patch_size=16),
EuroSatEval(rgb=True, input_patch_size=8),
EuroSatEval(rgb=True, input_patch_size=4),
EuroSatEval(rgb=True, input_patch_size=2),
EuroSatEval(rgb=True, input_patch_size=1),
EuroSatEval(rgb=False, input_patch_size=32),
EuroSatEval(rgb=False, input_patch_size=16),
EuroSatEval(rgb=False, input_patch_size=8),
EuroSatEval(rgb=False, input_patch_size=4),
EuroSatEval(rgb=False, input_patch_size=2),
EuroSatEval(rgb=False, input_patch_size=1),
TreeSatEval("S1", input_patch_size=1, seeds=eval_seeds),
TreeSatEval("S2", input_patch_size=1, seeds=eval_seeds),
TreeSatEval("S1", input_patch_size=2, seeds=eval_seeds),
TreeSatEval("S2", input_patch_size=2, seeds=eval_seeds),
TreeSatEval("S1", input_patch_size=3, seeds=eval_seeds),
TreeSatEval("S2", input_patch_size=3, seeds=eval_seeds),
TreeSatEval("S1", input_patch_size=6, seeds=eval_seeds),
TreeSatEval("S2", input_patch_size=6, seeds=eval_seeds),
*[CropHarvestEval("Kenya", seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Togo", seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Brazil", seeds=[s]) for s in eval_seeds],
*[
CropHarvestEval(country="Brazil", ignore_dynamic_world=idw, seed=seed)
for idw in [True, False]
for seed in seeds
],
*[
CropHarvestEval(country="Kenya", ignore_dynamic_world=idw, seed=seed, sample_size=s)
for idw in [True, False]
for seed in seeds
for s in CropHarvestEval.country_to_sizes["Kenya"]
],
*[
CropHarvestEval(country="Togo", ignore_dynamic_world=idw, seed=seed, sample_size=s)
for idw in [True, False]
for seed in seeds
for s in CropHarvestEval.country_to_sizes["Togo"]
],
*[FuelMoistureEval(seed=seed) for seed in seeds],
*[AlgaeBloomsEval(seed=seed) for seed in seeds],
*[
EuroSatEval(rgb=rgb, input_patch_size=ps, seed=seed, aggregates=["mean"])
for rgb in [True, False]
for ps in [1, 2, 4, 8, 16, 32, 64]
for seed in seeds
],
*[
TreeSatEval(subset=subset, seed=seed, aggregates=["mean"])
for subset in ["S1", "S2"]
for seed in seeds
],
*[
CropHarvestEval("Togo", ignore_dynamic_world=True, num_timesteps=x, seed=seed)
for x in range(1, 12)
for seed in seeds
],
*[
CropHarvestEval("Kenya", ignore_dynamic_world=True, num_timesteps=x, seed=seed)
for x in range(1, 12)
for seed in seeds
],
*[
CroptypeFranceEval(input_patch_size=patch_size, aggregates=["mean"], seed=seed)
for patch_size in [1, 5]
for seed in seeds
],
]
# add CropHarvest over time
for seed in eval_seeds:
eval_task_list.extend(
[
CropHarvestEval("Togo", ignore_dynamic_world=True, num_timesteps=x, seeds=[seed])
for x in range(1, 12)
]
)
eval_task_list.extend(
[
CropHarvestEval("Kenya", ignore_dynamic_world=True, num_timesteps=x, seeds=[seed])
for x in range(1, 12)
]
)

if wandb_enabled:
eval_config = {
Expand All @@ -157,7 +154,6 @@
"decoder": model.decoder.__class__,
"device": device,
"model_parameters": "random" if fully_supervised else path_to_state_dict,
"logging_dir": logging_dir,
**args,
**model_kwargs,
}
Expand All @@ -167,9 +163,16 @@
for eval_task in tqdm(eval_task_list, desc="Full Evaluation"):
model_modes = ["finetune", "Regression", "Random Forest"]
if "EuroSat" in eval_task.name:
model_modes = ["Regression", "Random Forest", "KNNat5", "KNNat20", "KNNat100"]
model_modes = [
"Regression",
"Random Forest",
"KNNat5",
"KNNat20",
"KNNat100",
"finetune",
]
if "TreeSat" in eval_task.name:
model_modes = ["Random Forest"]
model_modes = ["finetune", "Random Forest"]
logger.info(eval_task.name)

results = eval_task.finetuning_results(model, model_modes=model_modes)
Expand Down
19 changes: 6 additions & 13 deletions mosaiks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@
argparser.add_argument("--k", type=int, default=8192)
argparser.add_argument("--kernel_size", type=int, default=3)
argparser.add_argument("--wandb", dest="wandb", action="store_true")
argparser.add_argument(
"--eval_seeds",
type=int,
default=[0, DEFAULT_SEED, 48],
nargs="+",
help="seeds to use for eval tasks",
)
argparser.add_argument("--wandb_plots", type=int, default=3)

argparser.add_argument(
"--train_url",
Expand All @@ -46,7 +40,6 @@
k = args["k"]
kernel_size = args["kernel_size"]
wandb_enabled: bool = args["wandb"]
eval_seeds = args["eval_seeds"]

train_url: str = args["train_url"]

Expand Down Expand Up @@ -101,11 +94,11 @@ def load_dataset(url, shuffle_on_load):

logger.info("Loading evaluation tasks")
eval_task_list: List[EvalTask] = [
*[FuelMoistureEval(seeds=[s]) for s in eval_seeds],
*[AlgaeBloomsEval(seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Kenya", seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Togo", seeds=[s]) for s in eval_seeds],
*[CropHarvestEval("Brazil", seeds=[s]) for s in eval_seeds],
FuelMoistureEval(seed=seed),
AlgaeBloomsEval(seed=seed),
CropHarvestEval("Kenya", seed=seed),
CropHarvestEval("Togo", seed=seed),
CropHarvestEval("Brazil", seed=seed),
]

for eval_task in tqdm(eval_task_list, desc="Full Evaluation"):
Expand Down
2 changes: 2 additions & 0 deletions presto/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .algae_blooms_eval import AlgaeBloomsEval
from .cropharvest_eval import CropHarvestEval, CropHarvestMultiClassValidation
from .croptype_france_eval import CroptypeFranceEval
from .eurosat_eval import EuroSatEval
from .eval import EvalTask
from .fuel_moisture_eval import FuelMoistureEval
Expand All @@ -8,6 +9,7 @@
__all__ = [
"CropHarvestEval",
"CropHarvestMultiClassValidation",
"CroptypeFranceEval",
"EvalTask",
"EuroSatEval",
"FuelMoistureEval",
Expand Down
13 changes: 7 additions & 6 deletions presto/eval/algae_blooms_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from .. import utils
from ..dataops import S1_S2_ERA5_SRTM, TAR_BUCKET
from ..model import FineTuningModel, Mosaiks1d, Seq2Seq
from ..presto import param_groups_lrd
from ..utils import DEFAULT_SEED, device
from .cropharvest_extensions import DynamicWorldExporter, Engineer
from .eval import EvalTask, Hyperparams
Expand All @@ -40,9 +39,9 @@ class AlgaeBloomsEval(EvalTask):
multilabel = False
num_outputs = 1

def __init__(self, seeds: List[int] = [DEFAULT_SEED]) -> None:
def __init__(self, seed: int = DEFAULT_SEED) -> None:
self.labels = self.load_labels()
super().__init__(seeds)
super().__init__(seed)

@staticmethod
def load_labels():
Expand Down Expand Up @@ -212,12 +211,14 @@ def finetune(self, pretrained_model, mask: Optional[np.ndarray] = None) -> FineT
hyperparams = Hyperparams(max_epochs=200, patience=10, batch_size=64)
model = self._construct_finetuning_model(pretrained_model)

parameters = param_groups_lrd(model)
opt = AdamW(parameters, lr=hyperparams.lr)
opt = AdamW(model.parameters(), lr=hyperparams.lr, weight_decay=hyperparams.weight_decay)

def loss_fn(preds, target):
return nn.functional.huber_loss(preds.flatten(), target)

def val_loss_fn(preds, target):
return mean_squared_error(preds.cpu().numpy(), target.cpu().numpy())

x_np, dw_np, month_np, target_np, latlon_np = self.load_npys(test=False)

val_filter = (latlon_np[:, 0] > 42.5) & (latlon_np[:, 1] < -92.5)
Expand Down Expand Up @@ -252,7 +253,7 @@ def loss_fn(preds, target):
)

return self.finetune_pytorch_model(
model, hyperparams, opt, train_dl, val_dl, loss_fn, mean_squared_error, mask
model, hyperparams, opt, train_dl, val_dl, loss_fn, val_loss_fn, mask
)

def finetuning_results(
Expand Down
Loading

0 comments on commit 572be39

Please sign in to comment.