-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
executable file
·122 lines (103 loc) · 3.58 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
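#!/usr/bin/env python3
"""Test script: evaluate a model checkpoint using the settings from a YAML file.

Usage: test.py path/to/setting.yaml path/to/checkpoint [--track ...] [--pred_dir ...]
[--data_dir ...] [--gpus ...]
"""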
from argparse import ArgumentParser
import pytorch_lightning as pl
from earthnet_models_pytorch.data import DATASETS
from earthnet_models_pytorch.model import MODELS
from earthnet_models_pytorch.task import SpatioTemporalTask
from earthnet_models_pytorch.utils import parse_setting
from pytorch_lightning.callbacks import TQDMProgressBar
def _settings_to_cli_args(section: dict) -> list:
    """Render one settings section as a list of ``--key=value`` strings."""
    return [f"--{key}={value}" for key, value in section.items()]


def test_model(setting_dict: dict, checkpoint: str) -> None:
    """Build the data module, model and task from settings and run the test loop.

    Each settings section is converted to ``--key=value`` strings and parsed
    by the argparse parser that the corresponding class populates, so the
    classes receive the same kind of namespace as on the command line.

    Args:
        setting_dict: Parsed settings. The keys read here are "Setting",
            "Architecture", "Data", "Model", "Task" (including
            "context_length" and "target_length" when a checkpoint is given)
            and "Trainer". The dict is not modified.
        checkpoint: Path to a checkpoint file, or the literal string "None"
            to test the freshly constructed task without loading weights.
    """
    # Data
    dataset_cls = DATASETS[setting_dict["Setting"]]
    data_parser = dataset_cls.add_data_specific_args(ArgumentParser())
    data_params = data_parser.parse_args(_settings_to_cli_args(setting_dict["Data"]))
    dm = dataset_cls(data_params)

    # Model
    model_cls = MODELS[setting_dict["Architecture"]]
    model_parser = model_cls.add_model_specific_args(ArgumentParser())
    model_params = model_parser.parse_args(
        _settings_to_cli_args(setting_dict["Model"])
    )
    model = model_cls(model_params)

    # Task
    task_parser = SpatioTemporalTask.add_task_specific_args(ArgumentParser())
    task_params = task_parser.parse_args(_settings_to_cli_args(setting_dict["Task"]))
    task = SpatioTemporalTask(model=model, hparams=task_params)

    if checkpoint != "None":
        # load_from_checkpoint is a classmethod that RETURNS the restored
        # module; it does not update the instance it is called on. Call it on
        # the class and keep the result so the tested task is the restored one.
        task = SpatioTemporalTask.load_from_checkpoint(
            checkpoint_path=checkpoint,
            context_length=setting_dict["Task"]["context_length"],
            target_length=setting_dict["Task"]["target_length"],
            model=model,
            hparams=task_params,
        )

    # Trainer
    # Work on a copy so the caller's settings are not altered; logging is
    # switched off for test runs.
    trainer_dict = {**setting_dict["Trainer"], "logger": False}
    trainer = pl.Trainer(callbacks=TQDMProgressBar(refresh_rate=10), **trainer_dict)

    dm.setup("test")
    # ckpt_path=None: test the task object passed in (weights already restored
    # above) rather than asking the trainer to load a checkpoint itself.
    trainer.test(model=task, datamodule=dm, ckpt_path=None)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, apply the overrides to the
    # parsed settings and run the test loop.
    parser = ArgumentParser()
    parser.add_argument(
        "setting",
        type=str,
        metavar="path/to/setting.yaml",
        help="yaml with all settings",
    )
    parser.add_argument(
        "checkpoint", type=str, metavar="path/to/checkpoint", help="checkpoint file"
    )
    # NOTE(review): the default "ood-t_chopped" is not among the values named
    # in metavar/help - confirm the valid track names and update the text.
    parser.add_argument(
        "--track",
        type=str,
        metavar="iid|ood|ex|sea",
        default="ood-t_chopped",
        help="which track to test: either iid, ood, ex or sea",
    )
    parser.add_argument(
        "--pred_dir",
        type=str,
        default="preds/",
        metavar="path/to/prediction/dir",
        help="Path where to save predictions",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/greenearthnet/",
        metavar="path/to/dataset",
        help="Path where dataset is located",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        metavar="n gpus",
        default=1,
        help="how many gpus to use",
    )
    args = parser.parse_args()

    import os

    # Remove every SLURM* variable from this process's environment
    # (presumably so the trainer does not pick up the cluster environment -
    # TODO confirm). The matching keys are collected first so the mapping is
    # never mutated while it is being iterated.
    slurm_keys = [key for key in os.environ if key.startswith("SLURM")]
    for key in slurm_keys:
        del os.environ[key]

    setting_dict = parse_setting(args.setting, track=args.track)

    # Command-line values override the corresponding entries from the YAML.
    if args.pred_dir is not None:
        setting_dict["Task"]["pred_dir"] = args.pred_dir

    if args.data_dir is not None:
        setting_dict["Data"]["base_dir"] = args.data_dir

    # Only override the GPU count when the settings already define one.
    if "gpus" in setting_dict["Trainer"]:
        setting_dict["Trainer"]["gpus"] = args.gpus

    test_model(setting_dict, args.checkpoint)