diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 35fb3204..3abb2d4e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,9 +13,9 @@ repos:
   hooks:
     - id: black
       language_version: python3.9
-      exclude: ^src/ml_downscaling_emulator/score_sde_pytorch/
+      exclude: ^src/ml_downscaling_emulator/(run_lib.py|sde_lib.py|likelihood.py|sampling.py|losses.py|models|op|configs)
 - repo: https://github.com/pycqa/flake8
   rev: '6.0.0'  # pick a git hash / tag to point to
   hooks:
     - id: flake8
-      exclude: ^src/ml_downscaling_emulator/score_sde_pytorch/
+      exclude: ^src/ml_downscaling_emulator/(run_lib.py|sde_lib.py|likelihood.py|sampling.py|losses.py|models|op|configs)
diff --git a/README.md b/README.md
index dc68773e..7d71d03c 100644
--- a/README.md
+++ b/README.md
@@ -41,7 +41,7 @@ Recommended to run with a sample of the dataset.
 Train models through `bin/main.py`, e.g. to train the model used in the paper use
 
 ```sh
-python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
+python bin/main.py --config src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
 ```
 
 ```sh
diff --git a/bin/bp/queue-training b/bin/bp/queue-training
index ab8e219e..be1a1b3b 100755
--- a/bin/bp/queue-training
+++ b/bin/bp/queue-training
@@ -14,7 +14,7 @@ def train_cmd(sde, workdir, config, config_overrides=list):
     train_basecmd = ["python", f"bin/main.py"]
 
     train_opts = {
-        "--config": f"src/ml_downscaling_emulator/score_sde_pytorch/configs/{sde}/{config}.py",
+        "--config": f"src/ml_downscaling_emulator/configs/{sde}/{config}.py",
         "--workdir": workdir,
         "--mode": "train",
     }
diff --git a/bin/deterministic/bp/queue-sampling b/bin/deterministic/bp/queue-sampling
deleted file mode 100755
index fbf90515..00000000
--- a/bin/deterministic/bp/queue-sampling
+++ /dev/null
@@ -1,99 +0,0 @@
-#!/usr/bin/env python
-# setup jobs for sampling from a model
-
-import os
-import subprocess
-import sys
-
-import typer
-
-app = typer.Typer()
-
-
-def sample_cmd(
-    dataset, epoch, samples_per_job, workdir, ensemble_member, input_transform_key=None
-):
-    batch_size = 1024
-
-    sample_basecmd = ["mlde", "evaluate", "sample"]
-
-    sample_opts = {
-        "--epoch": str(epoch),
-        "--num-samples": str(samples_per_job),
-        "--dataset": dataset,
-        "--batch-size": str(batch_size),
-        "--ensemble-member": str(ensemble_member),
-    }
-
-    if input_transform_key is not None:
-        sample_opts["--input-transform-key"] = input_transform_key
-
-    return (
-        sample_basecmd
-        + [arg for item in sample_opts.items() for arg in item]
-        + [workdir]
-    )
-
-
-def queue_cmd(depends_on, sampling_jobs, sampling_duration):
-    queue_basecmd = ["lbatch"]
-
-    queue_opts = {
-        "-a": os.getenv("HPC_PROJECT_CODE"),
-        "-g": "1",
-        "-m": "16",
-        "-q": "cnu,gpu",
-        "-t": str(sampling_duration),
-        "--condaenv": "cuda-downscaling",
-        "--array": f"1-{sampling_jobs}",
-    }
-    if depends_on is not None:
-        queue_opts["-d"] = str(depends_on)
-
-    return queue_basecmd + [arg for item in queue_opts.items() for arg in item]
-
-
-@app.command()
-def main(
-    model_run_id: str,
-    cpm_dataset: str,
-    gcm_dataset: str,
-    epoch: int,
-    ensemble_member: str,
-    depends_on: int = None,
-    input_transform_key: str = None,
-):
-
-    sampling_jobs = 1
-    samples_per_job = 3
-    sampling_duration = 18 * samples_per_job
-
-    workdir = f"{os.getenv('DERIVED_DATA')}/workdirs/u-net/{model_run_id}"
-
-    shared_queue_cmd = queue_cmd(depends_on, sampling_jobs, sampling_duration)
-
-    # sample CPM
-    full_cmd = (
-        shared_queue_cmd
-        + ["--"]
-        + sample_cmd(cpm_dataset, epoch, samples_per_job, workdir, ensemble_member)
-    )
-    print(" ".join(full_cmd).strip(), file=sys.stderr)
-    output = subprocess.run(full_cmd, capture_output=True)
-    print(output.stderr.decode("utf8").strip(), file=sys.stderr)
-    print(output.stdout.decode("utf8").strip())
-
-    # sample GCM
-    full_cmd = (
-        shared_queue_cmd
-        + ["--"]
-        + sample_cmd(gcm_dataset, epoch, samples_per_job, workdir, input_transform_key)
-    )
-    print(" ".join(full_cmd).strip(), file=sys.stderr)
-    output = subprocess.run(full_cmd, capture_output=True)
-    print(output.stderr.decode("utf8").strip(), file=sys.stderr)
-    print(output.stdout.decode("utf8").strip())
-
-
-if __name__ == "__main__":
-    app()
diff --git a/bin/deterministic/bp/queue-training b/bin/deterministic/bp/queue-training
deleted file mode 100755
index 184d1e84..00000000
--- a/bin/deterministic/bp/queue-training
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/usr/bin/env python
-# setup jobs for training a model
-
-import os
-import subprocess
-import sys
-
-import typer
-
-app = typer.Typer()
-
-
-def train_cmd(dataset, workdir, config_overrides=list):
-    train_basecmd = ["python", "bin/deterministic/main.py"]
-
-    train_opts = {
-        "--config": "src/ml_downscaling_emulator/deterministic/configs/default.py",
-        "--workdir": workdir,
-        "--mode": "train",
-    }
-
-    return (
-        train_basecmd
-        + [arg for item in train_opts.items() for arg in item]
-        + [f"--config.data.dataset_name={dataset}"]
-        + config_overrides
-    )
-
-
-def queue_cmd(duration, memory):
-    queue_basecmd = ["lbatch"]
-
-    queue_opts = {
-        "-a": os.getenv("HPC_PROJECT_CODE"),
-        "-g": "1",
-        "-m": str(memory),
-        "-q": "cnu,gpu",
-        "-t": str(duration),
-        "--condaenv": "cuda-downscaling",
-    }
-
-    return queue_basecmd + [arg for item in queue_opts.items() for arg in item]
-
-
-@app.command(
-    context_settings={
-        "allow_extra_args": True,
-        "ignore_unknown_options": True,
-    }
-)
-def main(
-    ctx: typer.Context,
-    model_run_id: str,
-    cpm_dataset: str,
-    memory: int = 64,
-    duration: int = 72,
-):
-    # Add any other config on the commandline for training
-    # --config.data.input_transform_key=spatial
-
-    workdir = f"{os.getenv('DERIVED_DATA')}/workdirs/u-net/{model_run_id}"
-
-    full_cmd = (
-        queue_cmd(duration=duration, memory=memory)
-        + ["--"]
-        + train_cmd(cpm_dataset, workdir, ctx.args)
-    )
-    print(" ".join(full_cmd).strip(), file=sys.stderr)
-    output = subprocess.run(full_cmd, capture_output=True)
-    print(output.stderr.decode("utf8").strip(), file=sys.stderr)
-    print(output.stdout.decode("utf8").strip())
-
-
-if __name__ == "__main__":
-    app()
diff --git a/bin/deterministic/bp/train-sample b/bin/deterministic/bp/train-sample
deleted file mode 100755
index fe450ad3..00000000
--- a/bin/deterministic/bp/train-sample
+++ /dev/null
@@ -1,57 +0,0 @@
-#!/usr/bin/env python
-# setup jobs for training and then sampling from a model
-
-import os
-import subprocess
-import sys
-
-import typer
-
-app = typer.Typer()
-
-
-@app.command(
-    context_settings={
-        "allow_extra_args": True,
-        "ignore_unknown_options": True,
-    }
-)
-def main(
-    ctx: typer.Context,
-    run_id: str = typer.Option(...),
-    cpm_dataset: str = typer.Option(...),
-    gcm_dataset: str = typer.Option(...),
-    epochs: int = typer.Option(...),
-):
-    # Add any other config on the commandline for training
-    # --config.data.input_transform_key=spatial
-
-    # train
-    train_cmd = (
-        [f"{os.path.dirname(__file__)}/queue-training"]
-        + [run_id, cpm_dataset, "--epochs", str(epochs)]
-        + ctx.args
-    )
-    print(" ".join(train_cmd).strip(), file=sys.stderr)
-    output = subprocess.run(train_cmd, capture_output=True)
-    print(output.stderr.decode("utf8").strip(), file=sys.stderr)
-    training_job_id = output.stdout.decode("utf8").strip()
-    print(training_job_id)
-
-    # sample
-    sample_cmd = [f"{os.path.dirname(__file__)}/queue-sampling"] + [
-        run_id,
-        cpm_dataset,
-        gcm_dataset,
-        str(epochs),
-        "--depends-on",
-        training_job_id,
-    ]
-    print(" ".join(sample_cmd).strip(), file=sys.stderr)
-    output = subprocess.run(sample_cmd, capture_output=True)
-    print(output.stderr.decode("utf8").strip(), file=sys.stderr)
-    print(output.stdout.decode("utf8").strip())
-
-
-if __name__ == "__main__":
-    app()
diff --git a/bin/deterministic/main.py b/bin/deterministic/main.py
deleted file mode 100644
index 80153160..00000000
--- a/bin/deterministic/main.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import ml_downscaling_emulator.deterministic.run_lib as run_lib
-from absl import app
-from absl import flags
-from ml_collections.config_flags import config_flags
-import logging
-import os
-
-from knockknock import slack_sender
-
-FLAGS = flags.FLAGS
-
-config_flags.DEFINE_config_file(
-    "config", None, "Training configuration.", lock_config=True
-)
-flags.DEFINE_string("workdir", None, "Work directory.")
-flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.")
-flags.mark_flags_as_required(["workdir", "config", "mode"])
-
-
-@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general")
-def main(argv):
-    if FLAGS.mode == "train":
-        # Create the working directory
-        os.makedirs(FLAGS.workdir, exist_ok=True)
-        # Set logger so that it outputs to both console and file
-        # Make logging work for both disk and Google Cloud Storage
-        gfile_stream = open(os.path.join(FLAGS.workdir, "stdout.txt"), "w")
-        handler = logging.StreamHandler(gfile_stream)
-        formatter = logging.Formatter(
-            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
-        )
-        handler.setFormatter(formatter)
-        logger = logging.getLogger()
-        logger.addHandler(handler)
-        logger.setLevel("INFO")
-        # Run the training pipeline
-        run_lib.train(FLAGS.config, FLAGS.workdir)
-    else:
-        raise ValueError(f"Mode {FLAGS.mode} not recognized.")
-
-
-if __name__ == "__main__":
-    app.run(main)
diff --git a/bin/deterministic/model-size b/bin/deterministic/model-size
deleted file mode 100755
index ab955065..00000000
--- a/bin/deterministic/model-size
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/usr/bin/env python
-# calculate the number of parameters in a deterministic model
-
-import logging
-import os
-from pathlib import Path
-
-from ml_collections import config_dict
-from mlde_utils.training.dataset import get_variables
-import torch
-import typer
-import yaml
-
-
-from ml_downscaling_emulator.deterministic.utils import create_model
-from ml_downscaling_emulator.utils import model_size, param_count
-
-
-logger = logging.getLogger()
-logger.setLevel("INFO")
-
-app = typer.Typer()
-
-
-def load_config(config_path):
-    logger.info(f"Loading config from {config_path}")
-    with open(config_path) as f:
-        config = config_dict.ConfigDict(yaml.unsafe_load(f))
-
-    return config
-
-
-def load_model(config):
-    num_predictors = len(get_variables(config.data.dataset_name)[0])
-    if config.data.time_inputs:
-        num_predictors += 3
-    model = torch.nn.DataParallel(
-        create_model(config, num_predictors).to(device=config.device)
-    )
-    optimizer = torch.optim.Adam(model.parameters())
-    state = dict(step=0, epoch=0, optimizer=optimizer, model=model)
-
-    return state
-
-
-@app.command()
-def main(
-    workdir: Path,
-):
-    config_path = os.path.join(workdir, "config.yml")
-    config = load_config(config_path)
-    model = load_model(config)["model"]
-    num_score_model_parameters = param_count(model)
-
-    typer.echo(f"Model has {num_score_model_parameters} parameters")
-
-    size_all_mb = model_size(model)
-
-    typer.echo("model size: {:.3f}MB".format(size_all_mb))
-
-
-if __name__ == "__main__":
-    app()
diff --git a/bin/main.py b/bin/main.py
index ab44e086..d16c1129 100644
--- a/bin/main.py
+++ b/bin/main.py
@@ -16,7 +16,7 @@
 
 """Training"""
 
-import ml_downscaling_emulator.score_sde_pytorch.run_lib as run_lib
+import ml_downscaling_emulator.run_lib as run_lib
 from absl import app
 from absl import flags
 from ml_collections.config_flags import config_flags
diff --git a/bin/model-size b/bin/model-size
index 1acb7efd..9a48d05e 100755
--- a/bin/model-size
+++ b/bin/model-size
@@ -10,20 +10,20 @@ import typer
 import logging
 import yaml
 
-from ml_downscaling_emulator.score_sde_pytorch.models.location_params import (
+from ml_downscaling_emulator.models.location_params import (
     LocationParams,
 )
 
-from ml_downscaling_emulator.score_sde_pytorch.models import utils as mutils
+from ml_downscaling_emulator.models import utils as mutils
 
-from ml_downscaling_emulator.score_sde_pytorch.models import cncsnpp  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import cunet  # noqa: F401
+from ml_downscaling_emulator.models import cncsnpp  # noqa: F401
+from ml_downscaling_emulator.models import cunet  # noqa: F401
 
-from ml_downscaling_emulator.score_sde_pytorch.models import (  # noqa: F401
+from ml_downscaling_emulator.models import (  # noqa: F401
     layerspp,  # noqa: F401
 )  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import layers  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import (  # noqa: F401
+from ml_downscaling_emulator.models import layers  # noqa: F401
+from ml_downscaling_emulator.models import (  # noqa: F401
     normalization,  # noqa: F401
 )  # noqa: F401
 
diff --git a/bin/predict.py b/bin/predict.py
index ed6b6ba1..d7388823 100644
--- a/bin/predict.py
+++ b/bin/predict.py
@@ -22,33 +22,33 @@
 from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
 from mlde_utils.training.dataset import get_variables
 
-from ml_downscaling_emulator.score_sde_pytorch.losses import get_optimizer
-from ml_downscaling_emulator.score_sde_pytorch.models.ema import (
+from ml_downscaling_emulator.losses import get_optimizer
+from ml_downscaling_emulator.models.ema import (
     ExponentialMovingAverage,
 )
-from ml_downscaling_emulator.score_sde_pytorch.models.location_params import (
+from ml_downscaling_emulator.models.location_params import (
     LocationParams,
 )
-from ml_downscaling_emulator.score_sde_pytorch.utils import restore_checkpoint
+from ml_downscaling_emulator.utils import restore_checkpoint
 
-import ml_downscaling_emulator.score_sde_pytorch.models as models  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import utils as mutils
+import ml_downscaling_emulator.models as models  # noqa: F401
+from ml_downscaling_emulator.models import utils as mutils
 
-from ml_downscaling_emulator.score_sde_pytorch.models import cncsnpp  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import cunet  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import det_cunet  # noqa: F401
+from ml_downscaling_emulator.models import cncsnpp  # noqa: F401
+from ml_downscaling_emulator.models import cunet  # noqa: F401
+from ml_downscaling_emulator.models import det_cunet  # noqa: F401
 
-from ml_downscaling_emulator.score_sde_pytorch.models import (  # noqa: F401
+from ml_downscaling_emulator.models import (  # noqa: F401
     layerspp,  # noqa: F401
 )  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import layers  # noqa: F401
-from ml_downscaling_emulator.score_sde_pytorch.models import (  # noqa: F401
+from ml_downscaling_emulator.models import layers  # noqa: F401
+from ml_downscaling_emulator.models import (  # noqa: F401
     normalization,  # noqa: F401
 )  # noqa: F401
 
-import ml_downscaling_emulator.score_sde_pytorch.sampling as sampling
+import ml_downscaling_emulator.sampling as sampling
 
-from ml_downscaling_emulator.score_sde_pytorch.sde_lib import (
+from ml_downscaling_emulator.sde_lib import (
     VESDE,
     VPSDE,
     subVPSDE,
diff --git a/pyproject.toml b/pyproject.toml
index 00867574..887570ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,4 +29,4 @@ dynamic = ["dependencies"]
 dependencies = { file = ["requirements.txt"] }
 
 [tool.black]
-extend-exclude = '^/src/ml_downscaling_emulator/score_sde_pytorch/'
+extend-exclude = '^/src/ml_downscaling_emulator/(run_lib.py|sde_lib.py|likelihood.py|sampling.py|losses.py|models|op|configs)'
diff --git a/src/ml_downscaling_emulator/bin/__init__.py b/src/ml_downscaling_emulator/bin/__init__.py
index dba70cbc..7fd0be43 100644
--- a/src/ml_downscaling_emulator/bin/__init__.py
+++ b/src/ml_downscaling_emulator/bin/__init__.py
@@ -1,9 +1,9 @@
 import typer
 
-from . import evaluate, postprocess
+from . import postprocess, sample
 
 app = typer.Typer()
 
-app.add_typer(evaluate.app, name="evaluate")
+app.add_typer(sample.app, name="sample")
 app.add_typer(postprocess.app, name="postprocess")
 
diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py
deleted file mode 100644
index 5146ddbd..00000000
--- a/src/ml_downscaling_emulator/bin/evaluate.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from codetiming import Timer
-import logging
-from knockknock import slack_sender
-from ml_collections import config_dict
-import os
-from pathlib import Path
-import shortuuid
-import torch
-import typer
-import yaml
-
-from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
-from mlde_utils.training.dataset import load_raw_dataset_split
-from ..deterministic import sampling
-from ..deterministic.utils import create_model, restore_checkpoint
-from ..data import get_dataloader
-
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(levelname)s - %(filename)s - %(asctime)s - %(message)s",
-)
-logger = logging.getLogger()
-logger.setLevel("INFO")
-
-app = typer.Typer()
-
-
-@app.callback()
-def callback():
-    pass
-
-
-def load_config(config_path):
-    logger.info(f"Loading config from {config_path}")
-    with open(config_path) as f:
-        config = config_dict.ConfigDict(yaml.unsafe_load(f))
-
-    return config
-
-
-def load_model(config, num_predictors, ckpt_filename):
-    model = torch.nn.DataParallel(
-        create_model(config, num_predictors).to(device=config.device)
-    )
-    optimizer = torch.optim.Adam(model.parameters())
-    state = dict(step=0, epoch=0, optimizer=optimizer, model=model)
-    state, loaded = restore_checkpoint(ckpt_filename, state, config.device)
-    assert loaded, "Did not load state from checkpoint"
-
-    return state
-
-
-@app.command()
-@Timer(name="sample", text="{name}: {minutes:.1f} minutes", logger=logging.info)
-@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general")
-def sample(
-    workdir: Path,
-    dataset: str = typer.Option(...),
-    split: str = "val",
-    checkpoint: str = typer.Option(...),
-    batch_size: int = None,
-    num_samples: int = 1,
-    input_transform_dataset: str = None,
-    input_transform_key: str = None,
-    ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
-):
-
-    config_path = os.path.join(workdir, "config.yml")
-    config = load_config(config_path)
-
-    if batch_size is not None:
-        config.eval.batch_size = batch_size
-    with config.unlocked():
-        if input_transform_dataset is not None:
-            config.data.input_transform_dataset = input_transform_dataset
-        else:
-            config.data.input_transform_dataset = dataset
-    if input_transform_key is not None:
-        config.data.input_transform_key = input_transform_key
-
-    output_dirpath = samples_path(
-        workdir=workdir,
-        checkpoint=checkpoint,
-        dataset=dataset,
-        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
-        split=split,
-        ensemble_member=ensemble_member,
-    )
-    os.makedirs(output_dirpath, exist_ok=True)
-
-    transform_dir = os.path.join(workdir, "transforms")
-
-    eval_dl, _, target_transform = get_dataloader(
-        dataset,
-        config.data.dataset_name,
-        config.data.input_transform_dataset,
-        config.data.input_transform_key,
-        config.data.target_transform_key,
-        transform_dir,
-        split=split,
-        ensemble_members=[ensemble_member],
-        include_time_inputs=config.data.time_inputs,
-        evaluation=True,
-        batch_size=config.eval.batch_size,
-        shuffle=False,
-    )
-
-    ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
-    num_predictors = eval_dl.dataset[0][0].shape[0]
-    state = load_model(config, num_predictors, ckpt_filename)
-
-    for sample_id in range(num_samples):
-        typer.echo(f"Sample run {sample_id}...")
-        xr_samples = sampling.sample(state["model"], eval_dl, target_transform)
-
-        output_filepath = output_dirpath / f"predictions-{shortuuid.uuid()}.nc"
-
-        logger.info(f"Saving predictions to {output_filepath}")
-        xr_samples.to_netcdf(output_filepath)
-
-
-@app.command()
-@Timer(name="sample", text="{name}: {minutes:.1f} minutes", logger=logging.info)
-@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general")
-def sample_id(
-    workdir: Path,
-    dataset: str = typer.Option(...),
-    variable: str = "pr",
-    split: str = "val",
-    ensemble_member: str = "01",
-):
-
-    output_dirpath = samples_path(
-        workdir=workdir,
-        checkpoint=f"epoch-0",
-        dataset=dataset,
-        input_xfm="none",
-        split=split,
-        ensemble_member=ensemble_member,
-    )
-    os.makedirs(output_dirpath, exist_ok=True)
-
-    eval_ds = load_raw_dataset_split(dataset, split).sel(
-        ensemble_member=[ensemble_member]
-    )
-    xr_samples = sampling.sample_id(variable, eval_ds)
-
-    output_filepath = os.path.join(output_dirpath, f"predictions-{shortuuid.uuid()}.nc")
-
-    logger.info(f"Saving predictions to {output_filepath}")
-    xr_samples.to_netcdf(output_filepath)
diff --git a/src/ml_downscaling_emulator/bin/sample.py b/src/ml_downscaling_emulator/bin/sample.py
new file mode 100644
index 00000000..1eea615c
--- /dev/null
+++ b/src/ml_downscaling_emulator/bin/sample.py
@@ -0,0 +1,105 @@
+from codetiming import Timer
+import logging
+from knockknock import slack_sender
+import os
+from pathlib import Path
+import shortuuid
+import typer
+import xarray as xr
+
+from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
+from mlde_utils.training.dataset import load_raw_dataset_split
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(levelname)s - %(filename)s - %(asctime)s - %(message)s",
+)
+logger = logging.getLogger()
+logger.setLevel("INFO")
+
+app = typer.Typer()
+
+
+@app.callback()
+def callback():
+    pass
+
+
+def _np_samples_to_xr(np_samples, coords, target_transform, cf_data_vars):
+    coords = {**dict(coords)}
+
+    pred_pr_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
+    pred_pr_attrs = {
+        "grid_mapping": "rotated_latitude_longitude",
+        "standard_name": "pred_pr",
+        "units": "kg m-2 s-1",
+    }
+    pred_pr_var = (pred_pr_dims, np_samples, pred_pr_attrs)
+
+    data_vars = {**cf_data_vars, "target_pr": pred_pr_var}
+
+    pred_ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
+
+    if target_transform is not None:
+        pred_ds = target_transform.invert(pred_ds)
+
+    pred_ds = pred_ds.rename({"target_pr": "pred_pr"})
+
+    return pred_ds
+
+
+def _sample_id(variable: str, eval_ds: xr.Dataset) -> xr.Dataset:
+    """Create a Dataset of pr samples set to the values the given variable from the dataset."""
+    cf_data_vars = {
+        key: eval_ds.data_vars[key]
+        for key in [
+            "rotated_latitude_longitude",
+            "time_bnds",
+            "grid_latitude_bnds",
+            "grid_longitude_bnds",
+        ]
+        if key in eval_ds.variables
+    }
+    coords = eval_ds.coords
+    np_samples = eval_ds[variable].data
+    xr_samples = _np_samples_to_xr(
+        np_samples, coords=coords, target_transform=None, cf_data_vars=cf_data_vars
+    )
+
+    return xr_samples
+
+
+@app.command()
+@Timer(name="sample", text="{name}: {minutes:.1f} minutes", logger=logging.info)
+@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general")
+def as_input(
+    workdir: Path,
+    dataset: str = typer.Option(...),
+    variable: str = "pr",
+    split: str = "val",
+    ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
+):
+    """
+    Use a given variable from the dataset to create a file of prediction samples.
+
+    Commonly used to create samples based on an already processed variable like using a bilinearly interpolated coarse resolution variable as the predicted "high-resolution" value directly.
+    """
+    output_dirpath = samples_path(
+        workdir=workdir,
+        checkpoint=f"epoch-0",
+        dataset=dataset,
+        input_xfm="none",
+        split=split,
+        ensemble_member=ensemble_member,
+    )
+    os.makedirs(output_dirpath, exist_ok=True)
+
+    eval_ds = load_raw_dataset_split(dataset, split).sel(
+        ensemble_member=[ensemble_member]
+    )
+    xr_samples = _sample_id(variable, eval_ds)
+
+    output_filepath = os.path.join(output_dirpath, f"predictions-{shortuuid.uuid()}.nc")
+
+    logger.info(f"Saving predictions to {output_filepath}")
+    xr_samples.to_netcdf(output_filepath)
diff --git a/src/ml_downscaling_emulator/deterministic/__init__.py b/src/ml_downscaling_emulator/configs/__init__.py
similarity index 100%
rename from src/ml_downscaling_emulator/deterministic/__init__.py
rename to src/ml_downscaling_emulator/configs/__init__.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_12em_configs.py b/src/ml_downscaling_emulator/configs/default_ukcp_local_pr_12em_configs.py
similarity index 71%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_12em_configs.py
rename to src/ml_downscaling_emulator/configs/default_ukcp_local_pr_12em_configs.py
index 6679e4b0..2bf47e88 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_12em_configs.py
+++ b/src/ml_downscaling_emulator/configs/default_ukcp_local_pr_12em_configs.py
@@ -1,7 +1,7 @@
 import ml_collections
 import torch
 
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs as get_base_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs as get_base_configs
 
 
 def get_default_configs():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_1em_configs.py b/src/ml_downscaling_emulator/configs/default_ukcp_local_pr_1em_configs.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_1em_configs.py
rename to src/ml_downscaling_emulator/configs/default_ukcp_local_pr_1em_configs.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/__init__.py b/src/ml_downscaling_emulator/configs/deterministic/__init__.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/__init__.py
rename to src/ml_downscaling_emulator/configs/deterministic/__init__.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/configs/deterministic/default_configs.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py
rename to src/ml_downscaling_emulator/configs/deterministic/default_configs.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py
rename to src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py
index bb0d55bc..034eaf53 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py
+++ b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py
@@ -17,7 +17,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data in a deterministic fashion."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
+from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs
 
 def get_config():
     config = get_default_configs()
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
similarity index 93%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
rename to src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
index fc5ce7c6..aad404b0 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
+++ b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
@@ -21,7 +21,7 @@
 
 but training it in a deterministic fashion.
 """
-from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
+from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs
 
 def get_config():
     config = get_default_configs()
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
similarity index 93%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
rename to src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
index 2cde4593..ceb235c8 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
+++ b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
@@ -21,7 +21,7 @@
 
 but training it in a deterministic fashion.
 """
-from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
+from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs
 
 def get_config():
     config = get_default_configs()
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py
rename to src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py
index 8a442220..b661fa5a 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py
+++ b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py
@@ -17,7 +17,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data in a deterministic fashion."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
+from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs
 
 def get_config():
     config = get_default_configs()
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_debug.py
similarity index 90%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py
rename to src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_debug.py
index 2a8153fa..6430a23c 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py
+++ b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_debug.py
@@ -17,7 +17,7 @@
 
 # Lint as: python3
 """Debug config for training in a deterministic fashion."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
+from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs
 
 def get_config():
     config = get_default_configs()
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_plain_unet_debug.py
similarity index 89%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py
rename to src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_plain_unet_debug.py
index f442f45a..904766a8 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py
+++ b/src/ml_downscaling_emulator/configs/deterministic/ukcp_local_pr_plain_unet_debug.py
@@ -21,7 +21,7 @@
 
 but training it in a deterministic fashion.
 """
-from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.ukcp_local_pr_12em_tuned_plain_unet import get_config as get_default_configs
+from ml_downscaling_emulator.configs.deterministic.ukcp_local_pr_12em_tuned_plain_unet import get_config as get_default_configs
 
 def get_config():
     config = get_default_configs()
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/__init__.py b/src/ml_downscaling_emulator/configs/subvpsde/__init__.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/__init__.py
rename to src/ml_downscaling_emulator/configs/subvpsde/__init__.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py
index 997b3911..fa867aaf 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with sub-VP SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_debug.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_mv_debug.py
similarity index 91%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_debug.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_mv_debug.py
index 15d0a436..0abcaefc 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_debug.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_mv_debug.py
@@ -16,7 +16,7 @@
 # Lint as: python3
 """Training conditional U-Net on precip data with sub-VP SDE.
 DEBUGGING ONLY"""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py
similarity index 93%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py
index b6f6f394..3e203001 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with sub-VP SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py
index 37875926..fc4eebd1 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with sub-VP SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py
index a060f971..ad063779 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with sub-VP SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_debug.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_debug.py
similarity index 91%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_debug.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_debug.py
index 832cabe7..469dd0b5 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_debug.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_debug.py
@@ -16,7 +16,7 @@
 # Lint as: python3
 """Training conditional U-Net on precip data with sub-VP SDE.
 DEBUGGING ONLY"""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_rh_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_rh_12em_cncsnpp_continuous.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_rh_12em_cncsnpp_continuous.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_rh_12em_cncsnpp_continuous.py
index 2ddd6ea9..bcde1198 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_rh_12em_cncsnpp_continuous.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_rh_12em_cncsnpp_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with sub-VP SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_tmean_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_tmean_12em_cncsnpp_continuous.py
similarity index 94%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_tmean_12em_cncsnpp_continuous.py
rename to src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_tmean_12em_cncsnpp_continuous.py
index e815d14c..82f03cd4 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_tmean_12em_cncsnpp_continuous.py
+++ b/src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_tmean_12em_cncsnpp_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with sub-VP SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/__init__.py b/src/ml_downscaling_emulator/configs/vesde/__init__.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/__init__.py
rename to src/ml_downscaling_emulator/configs/vesde/__init__.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py b/src/ml_downscaling_emulator/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py
similarity index 93%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py
rename to src/ml_downscaling_emulator/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py
index a1086709..138643d4 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py
+++ b/src/ml_downscaling_emulator/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training NCSN++ on precip data with VE SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/ukcp_local_pr_cunet_continuous.py b/src/ml_downscaling_emulator/configs/vesde/ukcp_local_pr_cunet_continuous.py
similarity index 93%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/ukcp_local_pr_cunet_continuous.py
rename to src/ml_downscaling_emulator/configs/vesde/ukcp_local_pr_cunet_continuous.py
index 58e3842e..48cf5cbe 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/ukcp_local_pr_cunet_continuous.py
+++ b/src/ml_downscaling_emulator/configs/vesde/ukcp_local_pr_cunet_continuous.py
@@ -15,7 +15,7 @@
 
 # Lint as: python3
 """Training UNet on XArray with VE SDE."""
-from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
+from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs
 
 
 def get_config():
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/__init__.py b/src/ml_downscaling_emulator/configs/vpsde/__init__.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/__init__.py
rename to src/ml_downscaling_emulator/configs/vpsde/__init__.py
diff --git a/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py b/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py
deleted file mode 100644
index 3cf49481..00000000
--- a/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import ml_collections
-import torch
-
-
-def get_config():
-    config = ml_collections.ConfigDict()
-
-    config.training = training = ml_collections.ConfigDict()
-    training.n_epochs = 100
-    training.batch_size = 64
-    training.snapshot_freq = 25
-    training.log_freq = 50
-    training.eval_freq = 1000
-
-    config.eval = evaluate = ml_collections.ConfigDict()
-    evaluate.batch_size = 64
-
-    config.data = data = ml_collections.ConfigDict()
-    data.dataset_name = "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr"
-    data.input_transform_key = "stan"
-    data.target_transform_key = "sqrturrecen"
-    data.input_transform_dataset = None
-    data.time_inputs = False
-
-    config.model = model = ml_collections.ConfigDict()
-    model.name = "u-net"
-    model.loss = "MSELoss"
-
-    config.optim = optim = ml_collections.ConfigDict()
-    optim.optimizer = "Adam"
-    optim.lr = 2e-4
-
-    config.seed = 42
-    config.device = (
-        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
-    )
-
-    return config
diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py
deleted file mode 100644
index 4f0b3c04..00000000
--- a/src/ml_downscaling_emulator/deterministic/run_lib.py
+++ /dev/null
@@ -1,262 +0,0 @@
-import logging
-import os
-
-from absl import flags
-from codetiming import Timer
-import numpy as np
-import torch
-from tqdm import tqdm
-from tqdm.contrib.logging import logging_redirect_tqdm
-import yaml
-
-from mlde_utils import DatasetMetadata
-
-from ..training import log_epoch, track_run
-from .utils import restore_checkpoint, save_checkpoint, create_model
-from ..data import get_dataloader
-
-FLAGS = flags.FLAGS
-EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(levelname)s - %(filename)s - %(asctime)s - %(message)s",
-)
-logger = logging.getLogger()
-logger.setLevel("INFO")
-
-
-def val_loss(config, val_dl, eval_step_fn, state):
-    val_set_loss = 0.0
-    for val_cond_batch, val_x_batch, val_time_batch in val_dl:
-        val_x_batch = val_x_batch.to(config.device)
-        val_cond_batch = val_cond_batch.to(config.device)
-
-        val_batch_loss = eval_step_fn(state, val_x_batch, val_cond_batch)
-
-        # Progress
-        val_set_loss += val_batch_loss.item()
-    val_set_loss = val_set_loss / len(val_dl)
-
-    return val_set_loss
-
-
-@Timer(name="train", text="{name}: {minutes:.1f} minutes", logger=logging.info)
-def train(config, workdir):
-    os.makedirs(workdir, exist_ok=True)
-
-    gfile_stream = open(os.path.join(workdir, "stdout.txt"), "w")
-    handler = logging.StreamHandler(gfile_stream)
-    formatter = logging.Formatter(
-        "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
-    )
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
-
-    # Create transform saving directory
-    transform_dir = os.path.join(workdir, "transforms")
-    os.makedirs(transform_dir, exist_ok=True)
-
-    # Create directories for experimental logs
-    sample_dir = os.path.join(workdir, "samples")
-    os.makedirs(sample_dir, exist_ok=True)
-
-    tb_dir = os.path.join(workdir, "tensorboard")
-    os.makedirs(tb_dir, exist_ok=True)
-
-    logging.info(f"Starting {os.path.basename(__file__)}")
-
-    # Create checkpoints directory
-    checkpoint_dir = os.path.join(workdir, "checkpoints")
-    os.makedirs(checkpoint_dir, exist_ok=True)
-    # Intermediate checkpoints to resume training after pre-emption in cloud environments
-    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
-    os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
-
-    dataset_meta = DatasetMetadata(config.data.dataset_name)
-
-    # Build dataloaders
-    train_dl, _, _ = get_dataloader(
-        config.data.dataset_name,
-        config.data.dataset_name,
-        config.data.dataset_name,
-        config.data.input_transform_key,
-        config.data.target_transform_key,
-        transform_dir,
-        batch_size=config.training.batch_size,
-        split="train",
-        ensemble_members=dataset_meta.ensemble_members(),
-        include_time_inputs=config.data.time_inputs,
-        evaluation=False,
-    )
-    val_dl, _, _ = get_dataloader(
-        config.data.dataset_name,
-        config.data.dataset_name,
-        config.data.dataset_name,
-        config.data.input_transform_key,
-        config.data.target_transform_key,
-        transform_dir,
-        batch_size=config.training.batch_size,
-        split="val",
-        ensemble_members=dataset_meta.ensemble_members(),
-        include_time_inputs=config.data.time_inputs,
-        evaluation=False,
-    )
-
-    # Setup model, loss and optimiser
-    num_predictors = train_dl.dataset[0][0].shape[0]
-    model = torch.nn.DataParallel(
-        create_model(config, num_predictors).to(device=config.device)
-    )
-
-    if config.model.loss == "MSELoss":
-        criterion = torch.nn.MSELoss().to(config.device)
-    else:
-        raise NotImplementedError(f"Loss {config.model.loss} not supported yet!")
-
-    if config.optim.optimizer == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=config.optim.lr)
-    else:
-        raise NotImplementedError(
-            f"Optimizer {config.optim.optimizer} not supported yet!"
- ) - - state = dict(optimizer=optimizer, model=model, step=0, epoch=0) - # Resume training when intermediate checkpoints are detected - state, _ = restore_checkpoint(checkpoint_meta_dir, state, config.device) - initial_epoch = ( - int(state["epoch"]) + 1 - ) # start from the epoch after the one currently reached - - initial_epoch = ( - int(state["epoch"]) + 1 - ) # start from the epoch after the one currently reached - # step = state["step"] - - def loss_fn(model, batch, cond): - return criterion(model(cond), batch) - - def optimize_fn(optimizer, params, step, lr, warmup=5000, grad_clip=1.0): - """Optimizes with warmup and gradient clipping (disabled if negative).""" - if warmup > 0: - for g in optimizer.param_groups: - g["lr"] = lr * np.minimum(step / warmup, 1.0) - if grad_clip >= 0: - torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) - optimizer.step() - - # Compute validation loss - def eval_step_fn(state, batch, cond): - """Running one step of training or evaluation. - - Args: - state: A dictionary of training information, containing the score model, optimizer, - EMA status, and number of optimization steps. - batch: A mini-batch of training/evaluation data to model. - cond: A mini-batch of conditioning inputs. - - Returns: - loss: The average loss value of this state. - """ - model = state["model"] - with torch.no_grad(): - loss = loss_fn(model, batch, cond) - - return loss - - def train_step_fn(state, batch, cond): - """Running one step of training or evaluation. - - Args: - state: A dictionary of training information, containing the score model, optimizer, - EMA status, and number of optimization steps. - batch: A mini-batch of training/evaluation data to model. - cond: A mini-batch of conditioning inputs. - - Returns: - loss: The average loss value of this state. - """ - model = state["model"] - optimizer = state["optimizer"] - optimizer.zero_grad() - loss = loss_fn(model, batch, cond) - loss.backward() - optimize_fn( - optimizer, model.parameters(), step=state["step"], lr=config.optim.lr - ) - state["step"] += 1 - - return loss - - # save the config - config_path = os.path.join(workdir, "config.yml") - with open(config_path, "w") as f: - yaml.dump(config, f) - - run_name = os.path.basename(workdir) - run_config = dict( - dataset=config.data.dataset_name, - input_transform_key=config.data.input_transform_key, - target_transform_key=config.data.target_transform_key, - architecture=config.model.name, - name=run_name, - loss=config.model.loss, - time_inputs=config.data.time_inputs, - ) - - with track_run( - EXPERIMENT_NAME, run_name, run_config, [config.model.name, "baseline"], tb_dir - ) as (wandb_run, tb_writer): - # Fit model - - logging.info("Starting training loop at epoch %d." 
% (initial_epoch,)) - - for epoch in range(initial_epoch, config.training.n_epochs + 1): - state["epoch"] = epoch - # Update model based on training data - model.train() - - train_set_loss = 0.0 - with logging_redirect_tqdm(): - with tqdm( - total=len(train_dl.dataset), - desc=f"Epoch {state['epoch']}", - unit=" timesteps", - ) as pbar: - for (cond_batch, x_batch, time_batch) in train_dl: - cond_batch = cond_batch.to(config.device) - x_batch = x_batch.to(config.device) - - train_batch_loss = train_step_fn(state, x_batch, cond_batch) - train_set_loss += train_batch_loss.item() - - # Log progress so far on epoch - pbar.update(cond_batch.shape[0]) - - train_set_loss = train_set_loss / len(train_dl) - - # Save a temporary checkpoint to resume training after each epoch - save_checkpoint(checkpoint_meta_dir, state) - - # Report the loss on an validation dataset each epoch - model.eval() - val_set_loss = val_loss(config, val_dl, eval_step_fn, state) - epoch_metrics = { - "epoch/train/loss": train_set_loss, - "epoch/val/loss": val_set_loss, - } - log_epoch(state["epoch"], epoch_metrics, wandb_run, tb_writer) - # Checkpoint model - if ( - state["epoch"] != 0 - and state["epoch"] % config.training.snapshot_freq == 0 - ) or state["epoch"] == config.training.n_epochs: - checkpoint_path = os.path.join( - checkpoint_dir, f"epoch_{state['epoch']}.pth" - ) - save_checkpoint(checkpoint_path, state) - logging.info( - f"epoch: {state['epoch']}, checkpoint saved to {checkpoint_path}" - ) - - logging.info(f"Finished {os.path.basename(__file__)}") diff --git a/src/ml_downscaling_emulator/deterministic/sampling.py b/src/ml_downscaling_emulator/deterministic/sampling.py deleted file mode 100644 index f5737719..00000000 --- a/src/ml_downscaling_emulator/deterministic/sampling.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -import torch -from tqdm import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm -import xarray as xr - - -def generate_np_samples(model, cond_batch): - model.eval() - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - cond_batch = cond_batch.to(device) - - samples = model(cond_batch) - # drop the feature channel dimension (only have target pr as output) - samples = samples.squeeze(dim=1) - # extract numpy array - samples = samples.cpu().detach().numpy() - return samples - - -def np_samples_to_xr(np_samples, coords, target_transform, cf_data_vars): - coords = {**dict(coords)} - - pred_pr_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"] - pred_pr_attrs = { - "grid_mapping": "rotated_latitude_longitude", - "standard_name": "pred_pr", - "units": "kg m-2 s-1", - } - pred_pr_var = (pred_pr_dims, np_samples, pred_pr_attrs) - - data_vars = {**cf_data_vars, "target_pr": pred_pr_var} - - pred_ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs={}) - - if target_transform is not None: - pred_ds = target_transform.invert(pred_ds) - - pred_ds = pred_ds.rename({"target_pr": "pred_pr"}) - - return pred_ds - - -def sample(model, eval_dl, target_transform): - cf_data_vars = { - key: eval_dl.dataset.ds.data_vars[key] - for key in [ - "rotated_latitude_longitude", - "time_bnds", - "grid_latitude_bnds", - "grid_longitude_bnds", - ] - } - preds = [] - with logging_redirect_tqdm(): - with tqdm( - total=len(eval_dl.dataset), desc=f"Sampling", unit=" timesteps" - ) as pbar: - with torch.no_grad(): - for cond_batch, _, time_batch in eval_dl: - coords = eval_dl.dataset.ds.sel(time=time_batch).coords - batch_np_samples = generate_np_samples(model, 
cond_batch) - # add ensemble member axis to np samples - batch_np_samples = batch_np_samples[np.newaxis, :] - xr_samples = np_samples_to_xr( - batch_np_samples, coords, target_transform, cf_data_vars - ) - preds.append(xr_samples) - - pbar.update(cond_batch.shape[0]) - - ds = xr.combine_by_coords( - preds, - compat="no_conflicts", - combine_attrs="drop_conflicts", - coords="all", - join="inner", - data_vars="all", - ) - - return ds - - -def sample_id(variable: str, eval_ds: xr.Dataset) -> xr.Dataset: - """Create a Dataset of pr samples set to the values the given variable from the dataset.""" - cf_data_vars = { - key: eval_ds.data_vars[key] - for key in [ - "rotated_latitude_longitude", - "time_bnds", - "grid_latitude_bnds", - "grid_longitude_bnds", - ] - if key in eval_ds.variables - } - coords = eval_ds.coords - np_samples = eval_ds[variable].data - xr_samples = np_samples_to_xr( - np_samples, coords=coords, target_transform=None, cf_data_vars=cf_data_vars - ) - - return xr_samples diff --git a/src/ml_downscaling_emulator/deterministic/utils.py b/src/ml_downscaling_emulator/deterministic/utils.py deleted file mode 100644 index 5ee423c6..00000000 --- a/src/ml_downscaling_emulator/deterministic/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -import os -import torch.nn as nn - -from ..unet import unet - - -def create_model(config, num_predictors): - if config.model.name == "u-net": - return unet.UNet(num_predictors, 1) - if config.model.name == "debug": - return nn.Conv2d(num_predictors, 1, 3, stride=1, padding=1) - raise NotImplementedError(f"Model {config.model.name} not supported yet!") - - -def restore_checkpoint(ckpt_dir, state, device): - import torch - - if not os.path.exists(ckpt_dir): - os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True) - logging.warning( - f"No checkpoint found at {ckpt_dir}." f"Returned the same state as input" - ) - return state, False - else: - loaded_state = torch.load(ckpt_dir, map_location=device) - state["optimizer"].load_state_dict(loaded_state["optimizer"]) - state["model"].load_state_dict(loaded_state["model"], strict=False) - state["step"] = loaded_state["step"] - state["epoch"] = loaded_state["epoch"] - logging.info( - f"Checkpoint found at {ckpt_dir}. 
" - f"Returned the state from {state['epoch']}/{state['step']}" - ) - return state, True - - -def save_checkpoint(ckpt_dir, state): - import torch - - saved_state = { - "optimizer": state["optimizer"].state_dict(), - "model": state["model"].state_dict(), - "step": state["step"], - "epoch": state["epoch"], - } - torch.save(saved_state, ckpt_dir) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/likelihood.py b/src/ml_downscaling_emulator/likelihood.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/likelihood.py rename to src/ml_downscaling_emulator/likelihood.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py b/src/ml_downscaling_emulator/losses.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/losses.py rename to src/ml_downscaling_emulator/losses.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/__init__.py b/src/ml_downscaling_emulator/models/__init__.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/__init__.py rename to src/ml_downscaling_emulator/models/__init__.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/cncsnpp.py b/src/ml_downscaling_emulator/models/cncsnpp.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/cncsnpp.py rename to src/ml_downscaling_emulator/models/cncsnpp.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/cunet.py b/src/ml_downscaling_emulator/models/cunet.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/cunet.py rename to src/ml_downscaling_emulator/models/cunet.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ddpm.py b/src/ml_downscaling_emulator/models/ddpm.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/ddpm.py rename to src/ml_downscaling_emulator/models/ddpm.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/det_cunet.py b/src/ml_downscaling_emulator/models/det_cunet.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/det_cunet.py rename to src/ml_downscaling_emulator/models/det_cunet.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py b/src/ml_downscaling_emulator/models/ema.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py rename to src/ml_downscaling_emulator/models/ema.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/layers.py b/src/ml_downscaling_emulator/models/layers.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/layers.py rename to src/ml_downscaling_emulator/models/layers.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/layerspp.py b/src/ml_downscaling_emulator/models/layerspp.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/layerspp.py rename to src/ml_downscaling_emulator/models/layerspp.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py b/src/ml_downscaling_emulator/models/location_params.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py rename to src/ml_downscaling_emulator/models/location_params.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ncsnpp.py b/src/ml_downscaling_emulator/models/ncsnpp.py similarity index 100% rename from 
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ncsnpp.py b/src/ml_downscaling_emulator/models/ncsnpp.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/models/ncsnpp.py
rename to src/ml_downscaling_emulator/models/ncsnpp.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ncsnv2.py b/src/ml_downscaling_emulator/models/ncsnv2.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/models/ncsnv2.py
rename to src/ml_downscaling_emulator/models/ncsnv2.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/normalization.py b/src/ml_downscaling_emulator/models/normalization.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/models/normalization.py
rename to src/ml_downscaling_emulator/models/normalization.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/up_or_down_sampling.py b/src/ml_downscaling_emulator/models/up_or_down_sampling.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/models/up_or_down_sampling.py
rename to src/ml_downscaling_emulator/models/up_or_down_sampling.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/utils.py b/src/ml_downscaling_emulator/models/utils.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/models/utils.py
rename to src/ml_downscaling_emulator/models/utils.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/__init__.py b/src/ml_downscaling_emulator/op/__init__.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/__init__.py
rename to src/ml_downscaling_emulator/op/__init__.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/fused_act.py b/src/ml_downscaling_emulator/op/fused_act.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/fused_act.py
rename to src/ml_downscaling_emulator/op/fused_act.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/fused_bias_act.cpp b/src/ml_downscaling_emulator/op/fused_bias_act.cpp
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/fused_bias_act.cpp
rename to src/ml_downscaling_emulator/op/fused_bias_act.cpp
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/fused_bias_act_kernel.cu b/src/ml_downscaling_emulator/op/fused_bias_act_kernel.cu
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/fused_bias_act_kernel.cu
rename to src/ml_downscaling_emulator/op/fused_bias_act_kernel.cu
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/upfirdn2d.cpp b/src/ml_downscaling_emulator/op/upfirdn2d.cpp
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/upfirdn2d.cpp
rename to src/ml_downscaling_emulator/op/upfirdn2d.cpp
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/upfirdn2d.py b/src/ml_downscaling_emulator/op/upfirdn2d.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/upfirdn2d.py
rename to src/ml_downscaling_emulator/op/upfirdn2d.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/op/upfirdn2d_kernel.cu b/src/ml_downscaling_emulator/op/upfirdn2d_kernel.cu
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/op/upfirdn2d_kernel.cu
rename to src/ml_downscaling_emulator/op/upfirdn2d_kernel.cu
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/run_lib.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
rename to src/ml_downscaling_emulator/run_lib.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py b/src/ml_downscaling_emulator/sampling.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/sampling.py
rename to src/ml_downscaling_emulator/sampling.py
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/__init__.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/vesde/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/vpsde/__init__.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/vpsde/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/utils.py b/src/ml_downscaling_emulator/score_sde_pytorch/utils.py
deleted file mode 100644
index 97f4b681..00000000
--- a/src/ml_downscaling_emulator/score_sde_pytorch/utils.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import os
-import logging
-
-
-def restore_checkpoint(ckpt_dir, state, device):
-    if not os.path.exists(ckpt_dir):
-        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
-        logging.warning(f"No checkpoint found at {ckpt_dir}. "
-                        f"Returned the same state as input")
-        return state, False
-    else:
-        loaded_state = torch.load(ckpt_dir, map_location=device)
-        state['optimizer'].load_state_dict(loaded_state['optimizer'])
-        state['model'].load_state_dict(loaded_state['model'], strict=False)
-        state['ema'].load_state_dict(loaded_state['ema'])
-        state['location_params'].load_state_dict(loaded_state['location_params'])
-        state['step'] = loaded_state['step']
-        state['epoch'] = loaded_state['epoch']
-        logging.info(
-            f"Checkpoint found at {ckpt_dir}. "
-            f"Returned the state from {state['epoch']}/{state['step']}"
-        )
-        return state, True
-
-
-def save_checkpoint(ckpt_dir, state):
-    saved_state = {
-        'optimizer': state['optimizer'].state_dict(),
-        'model': state['model'].state_dict(),
-        'ema': state['ema'].state_dict(),
-        'step': state['step'],
-        'epoch': state['epoch'],
-        'location_params': state['location_params'].state_dict(),
-    }
-    torch.save(saved_state, ckpt_dir)
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/sde_lib.py b/src/ml_downscaling_emulator/sde_lib.py
similarity index 100%
rename from src/ml_downscaling_emulator/score_sde_pytorch/sde_lib.py
rename to src/ml_downscaling_emulator/sde_lib.py
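With the renames above, the score-SDE modules now sit at the top level of the package, so downstream imports drop the `score_sde_pytorch` segment. A sketch of the before/after (module names are taken from the renames in this diff; the exact import style naturally varies by caller):

```python
# Before this diff:
# from ml_downscaling_emulator.score_sde_pytorch import losses, sampling, sde_lib
# from ml_downscaling_emulator.score_sde_pytorch.models import utils as mutils

# After this diff:
from ml_downscaling_emulator import losses, sampling, sde_lib
from ml_downscaling_emulator.models import utils as mutils
```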
diff --git a/src/ml_downscaling_emulator/utils.py b/src/ml_downscaling_emulator/utils.py
index a15ab7a5..5dc3270e 100644
--- a/src/ml_downscaling_emulator/utils.py
+++ b/src/ml_downscaling_emulator/utils.py
@@ -1,4 +1,60 @@
-"""Helper methods"""
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+# Modifications copyright 2024 Henry Addison
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Significant modifications to the original work have been made by Henry Addison
+# to allow for location-specific parameters and iterating by epoch using PyTorch
+# DataLoaders and helpers for determining a model size.
+
+import torch
+import os
+import logging
+
+
+def restore_checkpoint(ckpt_dir, state, device):
+    if not os.path.exists(ckpt_dir):
+        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
+        logging.warning(
+            f"No checkpoint found at {ckpt_dir}. " f"Returned the same state as input"
+        )
+        return state, False
+    else:
+        loaded_state = torch.load(ckpt_dir, map_location=device)
+        state["optimizer"].load_state_dict(loaded_state["optimizer"])
+        state["model"].load_state_dict(loaded_state["model"], strict=False)
+        state["ema"].load_state_dict(loaded_state["ema"])
+        state["location_params"].load_state_dict(loaded_state["location_params"])
+        state["step"] = loaded_state["step"]
+        state["epoch"] = loaded_state["epoch"]
+        logging.info(
+            f"Checkpoint found at {ckpt_dir}. "
+            f"Returned the state from {state['epoch']}/{state['step']}"
+        )
+        return state, True
+
+
+def save_checkpoint(ckpt_dir, state):
+    saved_state = {
+        "optimizer": state["optimizer"].state_dict(),
+        "model": state["model"].state_dict(),
+        "ema": state["ema"].state_dict(),
+        "step": state["step"],
+        "epoch": state["epoch"],
+        "location_params": state["location_params"].state_dict(),
+    }
+    torch.save(saved_state, ckpt_dir)
 
 
 def param_count(model):
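This hunk consolidates the checkpoint helpers deleted from `deterministic/utils.py` and `score_sde_pytorch/utils.py` into one place, keeping the score-SDE variant that also round-trips the EMA and the location-specific parameters. A minimal usage sketch, assuming the `ExponentialMovingAverage` in `models/ema.py` keeps the upstream score_sde signature; the model and location-params modules below are stand-ins, not the repo's real ones:

```python
import os

import torch
import torch.nn as nn

from ml_downscaling_emulator.models.ema import ExponentialMovingAverage
from ml_downscaling_emulator.utils import restore_checkpoint, save_checkpoint

workdir = "output/example"  # illustrative path
os.makedirs(workdir, exist_ok=True)
ckpt_path = os.path.join(workdir, "checkpoint.pth")

model = nn.Conv2d(1, 1, 3, padding=1)  # stand-in for the score network
location_params = nn.Linear(2, 2)      # stand-in for the location-specific params module
optimizer = torch.optim.Adam(list(model.parameters()) + list(location_params.parameters()))

# The state dict uses exactly the keys save_checkpoint/restore_checkpoint expect.
state = {
    "optimizer": optimizer,
    "model": model,
    "ema": ExponentialMovingAverage(model.parameters(), decay=0.999),
    "location_params": location_params,
    "step": 0,
    "epoch": 0,
}

save_checkpoint(ckpt_path, state)
state, found = restore_checkpoint(ckpt_path, state, torch.device("cpu"))
assert found  # False would mean no checkpoint existed at ckpt_path
```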
diff --git a/tests/ml_downscaling_emulator/deterministic/test_sampling.py b/tests/ml_downscaling_emulator/bin/test_sample.py
similarity index 60%
rename from tests/ml_downscaling_emulator/deterministic/test_sampling.py
rename to tests/ml_downscaling_emulator/bin/test_sample.py
index 4c633a96..574b4710 100644
--- a/tests/ml_downscaling_emulator/deterministic/test_sampling.py
+++ b/tests/ml_downscaling_emulator/bin/test_sample.py
@@ -1,14 +1,14 @@
 import xarray as xr
 
-from ml_downscaling_emulator.deterministic.sampling import sample_id
+from ml_downscaling_emulator.bin.sample import _sample_id
 
 
 def test_sample_id(dataset: xr.Dataset):
-    """Ensure the sample_id function creates a set of predictions using the values of the given variable."""
+    """Ensure the _sample_id bin function creates a set of predictions using the values of the given variable."""
     variable = "linpr"
 
     em_dataset = dataset.sel(ensemble_member=["01"])
-    xr_samples = sample_id(variable, em_dataset)
+    xr_samples = _sample_id(variable, em_dataset)
 
     assert (xr_samples["pred_pr"].values == em_dataset["linpr"].values).all()
     for dim in ["time", "grid_latitude", "grid_longitude"]:
diff --git a/tests/smoke-tests/test-det-debug-cunet b/tests/smoke-tests/test-det-debug-cunet
index 9decc4d3..078b92a5 100755
--- a/tests/smoke-tests/test-det-debug-cunet
+++ b/tests/smoke-tests/test-det-debug-cunet
@@ -5,7 +5,7 @@ set -euo pipefail
 
 config_name="ukcp_local_pr_debug"
 workdir="output/test/deterministic/${config_name}/test-run"
-config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py"
+config_path="src/ml_downscaling_emulator/configs/deterministic/${config_name}.py"
 
 loc_spec_channels=0
 
diff --git a/tests/smoke-tests/test-det-det_cunet b/tests/smoke-tests/test-det-det_cunet
index 3ca5acba..0ce613ad 100755
--- a/tests/smoke-tests/test-det-det_cunet
+++ b/tests/smoke-tests/test-det-det_cunet
@@ -5,7 +5,7 @@ set -euo pipefail
 
 config_name="ukcp_local_pr_plain_unet_debug"
 workdir="output/test/deterministic/${config_name}/test-run"
-config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py"
+config_path="src/ml_downscaling_emulator/configs/deterministic/${config_name}.py"
 
 loc_spec_channels=2
 
diff --git a/tests/smoke-tests/test-subvpsde-debug-cunet b/tests/smoke-tests/test-subvpsde-debug-cunet
index 4ae83206..ce537fc2 100755
--- a/tests/smoke-tests/test-subvpsde-debug-cunet
+++ b/tests/smoke-tests/test-subvpsde-debug-cunet
@@ -7,7 +7,7 @@
 config_name="ukcp_local_mv_debug"
 dataset="debug-sample-mv"
 workdir="output/test/${sde}/${config_name}/test-run"
-config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${config_name}.py"
+config_path="src/ml_downscaling_emulator/configs/${sde}/${config_name}.py"
 loc_spec_channels=2
 train_batch_size=2
 
diff --git a/tox.ini b/tox.ini
index 15d31e5a..ce29ec1d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [flake8]
 max-line-length = 88
 extend-ignore = E203,E501,F541
-exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,src/ml_downscaling_emulator/unet,src/ml_downscaling_emulator/score_sde_pytorch
+exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,src/ml_downscaling_emulator/unet,src/ml_downscaling_emulator/run_lib.py,src/ml_downscaling_emulator/sde_lib.py,src/ml_downscaling_emulator/likelihood.py,src/ml_downscaling_emulator/sampling.py,src/ml_downscaling_emulator/losses.py,src/ml_downscaling_emulator/models,src/ml_downscaling_emulator/op,src/ml_downscaling_emulator/configs
 max-complexity = 20
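One endnote on the test move above (`test_sampling.py` → `bin/test_sample.py`): the identity baseline it exercises simply reuses an existing variable (e.g. `linpr`) as the `pr` prediction. A stripped-down sketch of that idea; the real `_sample_id` in `ml_downscaling_emulator.bin.sample` also carries CF metadata variables through `np_samples_to_xr`, which this sketch omits, and the function name here is illustrative:

```python
import xarray as xr


def identity_pr_samples(variable: str, eval_ds: xr.Dataset) -> xr.Dataset:
    """Trivial baseline: predict pr as a straight copy of `variable`."""
    # pred_pr is set to the chosen variable's values, so e.g.
    # ds["pred_pr"] == ds["linpr"] everywhere, as the renamed test asserts.
    return eval_ds[variable].rename("pred_pr").to_dataset()
```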