diff --git a/bin/deterministic/bp/queue-sampling b/bin/deterministic/bp/queue-sampling deleted file mode 100755 index fbf90515..00000000 --- a/bin/deterministic/bp/queue-sampling +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python -# setup jobs for sampling from a model - -import os -import subprocess -import sys - -import typer - -app = typer.Typer() - - -def sample_cmd( - dataset, epoch, samples_per_job, workdir, ensemble_member, input_transform_key=None -): - batch_size = 1024 - - sample_basecmd = ["mlde", "evaluate", "sample"] - - sample_opts = { - "--epoch": str(epoch), - "--num-samples": str(samples_per_job), - "--dataset": dataset, - "--batch-size": str(batch_size), - "--ensemble-member": str(ensemble_member), - } - - if input_transform_key is not None: - sample_opts["--input-transform-key"] = input_transform_key - - return ( - sample_basecmd - + [arg for item in sample_opts.items() for arg in item] - + [workdir] - ) - - -def queue_cmd(depends_on, sampling_jobs, sampling_duration): - queue_basecmd = ["lbatch"] - - queue_opts = { - "-a": os.getenv("HPC_PROJECT_CODE"), - "-g": "1", - "-m": "16", - "-q": "cnu,gpu", - "-t": str(sampling_duration), - "--condaenv": "cuda-downscaling", - "--array": f"1-{sampling_jobs}", - } - if depends_on is not None: - queue_opts["-d"] = str(depends_on) - - return queue_basecmd + [arg for item in queue_opts.items() for arg in item] - - -@app.command() -def main( - model_run_id: str, - cpm_dataset: str, - gcm_dataset: str, - epoch: int, - ensemble_member: str, - depends_on: int = None, - input_transform_key: str = None, -): - - sampling_jobs = 1 - samples_per_job = 3 - sampling_duration = 18 * samples_per_job - - workdir = f"{os.getenv('DERIVED_DATA')}/workdirs/u-net/{model_run_id}" - - shared_queue_cmd = queue_cmd(depends_on, sampling_jobs, sampling_duration) - - # sample CPM - full_cmd = ( - shared_queue_cmd - + ["--"] - + sample_cmd(cpm_dataset, epoch, samples_per_job, workdir, ensemble_member) - ) - print(" ".join(full_cmd).strip(), file=sys.stderr) - output = subprocess.run(full_cmd, capture_output=True) - print(output.stderr.decode("utf8").strip(), file=sys.stderr) - print(output.stdout.decode("utf8").strip()) - - # sample GCM - full_cmd = ( - shared_queue_cmd - + ["--"] - + sample_cmd(gcm_dataset, epoch, samples_per_job, workdir, input_transform_key) - ) - print(" ".join(full_cmd).strip(), file=sys.stderr) - output = subprocess.run(full_cmd, capture_output=True) - print(output.stderr.decode("utf8").strip(), file=sys.stderr) - print(output.stdout.decode("utf8").strip()) - - -if __name__ == "__main__": - app() diff --git a/bin/deterministic/bp/queue-training b/bin/deterministic/bp/queue-training deleted file mode 100755 index 184d1e84..00000000 --- a/bin/deterministic/bp/queue-training +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python -# setup jobs for training a model - -import os -import subprocess -import sys - -import typer - -app = typer.Typer() - - -def train_cmd(dataset, workdir, config_overrides=list): - train_basecmd = ["python", "bin/deterministic/main.py"] - - train_opts = { - "--config": "src/ml_downscaling_emulator/deterministic/configs/default.py", - "--workdir": workdir, - "--mode": "train", - } - - return ( - train_basecmd - + [arg for item in train_opts.items() for arg in item] - + [f"--config.data.dataset_name={dataset}"] - + config_overrides - ) - - -def queue_cmd(duration, memory): - queue_basecmd = ["lbatch"] - - queue_opts = { - "-a": os.getenv("HPC_PROJECT_CODE"), - "-g": "1", - "-m": str(memory), - "-q": "cnu,gpu", - "-t": str(duration), - "--condaenv": "cuda-downscaling", - } - - return queue_basecmd + [arg for item in queue_opts.items() for arg in item] - - -@app.command( - context_settings={ - "allow_extra_args": True, - "ignore_unknown_options": True, - } -) -def main( - ctx: typer.Context, - model_run_id: str, - cpm_dataset: str, - memory: int = 64, - duration: int = 72, -): - # Add any other config on the commandline for training - # --config.data.input_transform_key=spatial - - workdir = f"{os.getenv('DERIVED_DATA')}/workdirs/u-net/{model_run_id}" - - full_cmd = ( - queue_cmd(duration=duration, memory=memory) - + ["--"] - + train_cmd(cpm_dataset, workdir, ctx.args) - ) - print(" ".join(full_cmd).strip(), file=sys.stderr) - output = subprocess.run(full_cmd, capture_output=True) - print(output.stderr.decode("utf8").strip(), file=sys.stderr) - print(output.stdout.decode("utf8").strip()) - - -if __name__ == "__main__": - app() diff --git a/bin/deterministic/bp/train-sample b/bin/deterministic/bp/train-sample deleted file mode 100755 index fe450ad3..00000000 --- a/bin/deterministic/bp/train-sample +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python -# setup jobs for training and then sampling from a model - -import os -import subprocess -import sys - -import typer - -app = typer.Typer() - - -@app.command( - context_settings={ - "allow_extra_args": True, - "ignore_unknown_options": True, - } -) -def main( - ctx: typer.Context, - run_id: str = typer.Option(...), - cpm_dataset: str = typer.Option(...), - gcm_dataset: str = typer.Option(...), - epochs: int = typer.Option(...), -): - # Add any other config on the commandline for training - # --config.data.input_transform_key=spatial - - # train - train_cmd = ( - [f"{os.path.dirname(__file__)}/queue-training"] - + [run_id, cpm_dataset, "--epochs", str(epochs)] - + ctx.args - ) - print(" ".join(train_cmd).strip(), file=sys.stderr) - output = subprocess.run(train_cmd, capture_output=True) - print(output.stderr.decode("utf8").strip(), file=sys.stderr) - training_job_id = output.stdout.decode("utf8").strip() - print(training_job_id) - - # sample - sample_cmd = [f"{os.path.dirname(__file__)}/queue-sampling"] + [ - run_id, - cpm_dataset, - gcm_dataset, - str(epochs), - "--depends-on", - training_job_id, - ] - print(" ".join(sample_cmd).strip(), file=sys.stderr) - output = subprocess.run(sample_cmd, capture_output=True) - print(output.stderr.decode("utf8").strip(), file=sys.stderr) - print(output.stdout.decode("utf8").strip()) - - -if __name__ == "__main__": - app() diff --git a/bin/deterministic/main.py b/bin/deterministic/main.py deleted file mode 100644 index 80153160..00000000 --- a/bin/deterministic/main.py +++ /dev/null @@ -1,43 +0,0 @@ -import ml_downscaling_emulator.deterministic.run_lib as run_lib -from absl import app -from absl import flags -from ml_collections.config_flags import config_flags -import logging -import os - -from knockknock import slack_sender - -FLAGS = flags.FLAGS - -config_flags.DEFINE_config_file( - "config", None, "Training configuration.", lock_config=True -) -flags.DEFINE_string("workdir", None, "Work directory.") -flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.") -flags.mark_flags_as_required(["workdir", "config", "mode"]) - - -@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general") -def main(argv): - if FLAGS.mode == "train": - # Create the working directory - os.makedirs(FLAGS.workdir, exist_ok=True) - # Set logger so that it outputs to both console and file - # Make logging work for both disk and Google Cloud Storage - gfile_stream = open(os.path.join(FLAGS.workdir, "stdout.txt"), "w") - handler = logging.StreamHandler(gfile_stream) - formatter = logging.Formatter( - "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" - ) - handler.setFormatter(formatter) - logger = logging.getLogger() - logger.addHandler(handler) - logger.setLevel("INFO") - # Run the training pipeline - run_lib.train(FLAGS.config, FLAGS.workdir) - else: - raise ValueError(f"Mode {FLAGS.mode} not recognized.") - - -if __name__ == "__main__": - app.run(main) diff --git a/bin/deterministic/model-size b/bin/deterministic/model-size deleted file mode 100755 index ab955065..00000000 --- a/bin/deterministic/model-size +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python -# calculate the number of parameters in a deterministic model - -import logging -import os -from pathlib import Path - -from ml_collections import config_dict -from mlde_utils.training.dataset import get_variables -import torch -import typer -import yaml - - -from ml_downscaling_emulator.deterministic.utils import create_model -from ml_downscaling_emulator.utils import model_size, param_count - - -logger = logging.getLogger() -logger.setLevel("INFO") - -app = typer.Typer() - - -def load_config(config_path): - logger.info(f"Loading config from {config_path}") - with open(config_path) as f: - config = config_dict.ConfigDict(yaml.unsafe_load(f)) - - return config - - -def load_model(config): - num_predictors = len(get_variables(config.data.dataset_name)[0]) - if config.data.time_inputs: - num_predictors += 3 - model = torch.nn.DataParallel( - create_model(config, num_predictors).to(device=config.device) - ) - optimizer = torch.optim.Adam(model.parameters()) - state = dict(step=0, epoch=0, optimizer=optimizer, model=model) - - return state - - -@app.command() -def main( - workdir: Path, -): - config_path = os.path.join(workdir, "config.yml") - config = load_config(config_path) - model = load_model(config)["model"] - num_score_model_parameters = param_count(model) - - typer.echo(f"Model has {num_score_model_parameters} parameters") - - size_all_mb = model_size(model) - - typer.echo("model size: {:.3f}MB".format(size_all_mb)) - - -if __name__ == "__main__": - app() diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py index 5146ddbd..f38ec3d7 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/evaluate.py @@ -1,20 +1,14 @@ from codetiming import Timer import logging from knockknock import slack_sender -from ml_collections import config_dict import os from pathlib import Path import shortuuid -import torch import typer -import yaml +import xarray as xr from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER from mlde_utils.training.dataset import load_raw_dataset_split -from ..deterministic import sampling -from ..deterministic.utils import create_model, restore_checkpoint -from ..data import get_dataloader - logging.basicConfig( level=logging.INFO, @@ -31,93 +25,48 @@ def callback(): pass -def load_config(config_path): - logger.info(f"Loading config from {config_path}") - with open(config_path) as f: - config = config_dict.ConfigDict(yaml.unsafe_load(f)) +def _np_samples_to_xr(np_samples, coords, target_transform, cf_data_vars): + coords = {**dict(coords)} - return config + pred_pr_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"] + pred_pr_attrs = { + "grid_mapping": "rotated_latitude_longitude", + "standard_name": "pred_pr", + "units": "kg m-2 s-1", + } + pred_pr_var = (pred_pr_dims, np_samples, pred_pr_attrs) + data_vars = {**cf_data_vars, "target_pr": pred_pr_var} -def load_model(config, num_predictors, ckpt_filename): - model = torch.nn.DataParallel( - create_model(config, num_predictors).to(device=config.device) - ) - optimizer = torch.optim.Adam(model.parameters()) - state = dict(step=0, epoch=0, optimizer=optimizer, model=model) - state, loaded = restore_checkpoint(ckpt_filename, state, config.device) - assert loaded, "Did not load state from checkpoint" + pred_ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs={}) - return state + if target_transform is not None: + pred_ds = target_transform.invert(pred_ds) + pred_ds = pred_ds.rename({"target_pr": "pred_pr"}) -@app.command() -@Timer(name="sample", text="{name}: {minutes:.1f} minutes", logger=logging.info) -@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general") -def sample( - workdir: Path, - dataset: str = typer.Option(...), - split: str = "val", - checkpoint: str = typer.Option(...), - batch_size: int = None, - num_samples: int = 1, - input_transform_dataset: str = None, - input_transform_key: str = None, - ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, -): + return pred_ds - config_path = os.path.join(workdir, "config.yml") - config = load_config(config_path) - if batch_size is not None: - config.eval.batch_size = batch_size - with config.unlocked(): - if input_transform_dataset is not None: - config.data.input_transform_dataset = input_transform_dataset - else: - config.data.input_transform_dataset = dataset - if input_transform_key is not None: - config.data.input_transform_key = input_transform_key - - output_dirpath = samples_path( - workdir=workdir, - checkpoint=checkpoint, - dataset=dataset, - input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", - split=split, - ensemble_member=ensemble_member, - ) - os.makedirs(output_dirpath, exist_ok=True) - - transform_dir = os.path.join(workdir, "transforms") - - eval_dl, _, target_transform = get_dataloader( - dataset, - config.data.dataset_name, - config.data.input_transform_dataset, - config.data.input_transform_key, - config.data.target_transform_key, - transform_dir, - split=split, - ensemble_members=[ensemble_member], - include_time_inputs=config.data.time_inputs, - evaluation=True, - batch_size=config.eval.batch_size, - shuffle=False, +def _sample_id(variable: str, eval_ds: xr.Dataset) -> xr.Dataset: + """Create a Dataset of pr samples set to the values the given variable from the dataset.""" + cf_data_vars = { + key: eval_ds.data_vars[key] + for key in [ + "rotated_latitude_longitude", + "time_bnds", + "grid_latitude_bnds", + "grid_longitude_bnds", + ] + if key in eval_ds.variables + } + coords = eval_ds.coords + np_samples = eval_ds[variable].data + xr_samples = _np_samples_to_xr( + np_samples, coords=coords, target_transform=None, cf_data_vars=cf_data_vars ) - ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth") - num_predictors = eval_dl.dataset[0][0].shape[0] - state = load_model(config, num_predictors, ckpt_filename) - - for sample_id in range(num_samples): - typer.echo(f"Sample run {sample_id}...") - xr_samples = sampling.sample(state["model"], eval_dl, target_transform) - - output_filepath = output_dirpath / f"predictions-{shortuuid.uuid()}.nc" - - logger.info(f"Saving predictions to {output_filepath}") - xr_samples.to_netcdf(output_filepath) + return xr_samples @app.command() @@ -128,7 +77,7 @@ def sample_id( dataset: str = typer.Option(...), variable: str = "pr", split: str = "val", - ensemble_member: str = "01", + ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, ): output_dirpath = samples_path( @@ -144,7 +93,7 @@ def sample_id( eval_ds = load_raw_dataset_split(dataset, split).sel( ensemble_member=[ensemble_member] ) - xr_samples = sampling.sample_id(variable, eval_ds) + xr_samples = _sample_id(variable, eval_ds) output_filepath = os.path.join(output_dirpath, f"predictions-{shortuuid.uuid()}.nc") diff --git a/src/ml_downscaling_emulator/deterministic/__init__.py b/src/ml_downscaling_emulator/deterministic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py b/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py deleted file mode 100644 index 3cf49481..00000000 --- a/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py +++ /dev/null @@ -1,38 +0,0 @@ -import ml_collections -import torch - - -def get_config(): - config = ml_collections.ConfigDict() - - config.training = training = ml_collections.ConfigDict() - training.n_epochs = 100 - training.batch_size = 64 - training.snapshot_freq = 25 - training.log_freq = 50 - training.eval_freq = 1000 - - config.eval = evaluate = ml_collections.ConfigDict() - evaluate.batch_size = 64 - - config.data = data = ml_collections.ConfigDict() - data.dataset_name = "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr" - data.input_transform_key = "stan" - data.target_transform_key = "sqrturrecen" - data.input_transform_dataset = None - data.time_inputs = False - - config.model = model = ml_collections.ConfigDict() - model.name = "u-net" - model.loss = "MSELoss" - - config.optim = optim = ml_collections.ConfigDict() - optim.optimizer = "Adam" - optim.lr = 2e-4 - - config.seed = 42 - config.device = ( - torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") - ) - - return config diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py deleted file mode 100644 index 4f0b3c04..00000000 --- a/src/ml_downscaling_emulator/deterministic/run_lib.py +++ /dev/null @@ -1,262 +0,0 @@ -import logging -import os - -from absl import flags -from codetiming import Timer -import numpy as np -import torch -from tqdm import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm -import yaml - -from mlde_utils import DatasetMetadata - -from ..training import log_epoch, track_run -from .utils import restore_checkpoint, save_checkpoint, create_model -from ..data import get_dataloader - -FLAGS = flags.FLAGS -EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME") - -logging.basicConfig( - level=logging.INFO, - format="%(levelname)s - %(filename)s - %(asctime)s - %(message)s", -) -logger = logging.getLogger() -logger.setLevel("INFO") - - -def val_loss(config, val_dl, eval_step_fn, state): - val_set_loss = 0.0 - for val_cond_batch, val_x_batch, val_time_batch in val_dl: - val_x_batch = val_x_batch.to(config.device) - val_cond_batch = val_cond_batch.to(config.device) - - val_batch_loss = eval_step_fn(state, val_x_batch, val_cond_batch) - - # Progress - val_set_loss += val_batch_loss.item() - val_set_loss = val_set_loss / len(val_dl) - - return val_set_loss - - -@Timer(name="train", text="{name}: {minutes:.1f} minutes", logger=logging.info) -def train(config, workdir): - os.makedirs(workdir, exist_ok=True) - - gfile_stream = open(os.path.join(workdir, "stdout.txt"), "w") - handler = logging.StreamHandler(gfile_stream) - formatter = logging.Formatter( - "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" - ) - handler.setFormatter(formatter) - logger.addHandler(handler) - - # Create transform saving directory - transform_dir = os.path.join(workdir, "transforms") - os.makedirs(transform_dir, exist_ok=True) - - # Create directories for experimental logs - sample_dir = os.path.join(workdir, "samples") - os.makedirs(sample_dir, exist_ok=True) - - tb_dir = os.path.join(workdir, "tensorboard") - os.makedirs(tb_dir, exist_ok=True) - - logging.info(f"Starting {os.path.basename(__file__)}") - - # Create checkpoints directory - checkpoint_dir = os.path.join(workdir, "checkpoints") - os.makedirs(checkpoint_dir, exist_ok=True) - # Intermediate checkpoints to resume training after pre-emption in cloud environments - checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth") - os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True) - - dataset_meta = DatasetMetadata(config.data.dataset_name) - - # Build dataloaders - train_dl, _, _ = get_dataloader( - config.data.dataset_name, - config.data.dataset_name, - config.data.dataset_name, - config.data.input_transform_key, - config.data.target_transform_key, - transform_dir, - batch_size=config.training.batch_size, - split="train", - ensemble_members=dataset_meta.ensemble_members(), - include_time_inputs=config.data.time_inputs, - evaluation=False, - ) - val_dl, _, _ = get_dataloader( - config.data.dataset_name, - config.data.dataset_name, - config.data.dataset_name, - config.data.input_transform_key, - config.data.target_transform_key, - transform_dir, - batch_size=config.training.batch_size, - split="val", - ensemble_members=dataset_meta.ensemble_members(), - include_time_inputs=config.data.time_inputs, - evaluation=False, - ) - - # Setup model, loss and optimiser - num_predictors = train_dl.dataset[0][0].shape[0] - model = torch.nn.DataParallel( - create_model(config, num_predictors).to(device=config.device) - ) - - if config.model.loss == "MSELoss": - criterion = torch.nn.MSELoss().to(config.device) - else: - raise NotImplementedError(f"Loss {config.model.loss} not supported yet!") - - if config.optim.optimizer == "Adam": - optimizer = torch.optim.Adam(model.parameters(), lr=config.optim.lr) - else: - raise NotImplementedError( - f"Optimizer {config.optim.optimizer} not supported yet!" - ) - - state = dict(optimizer=optimizer, model=model, step=0, epoch=0) - # Resume training when intermediate checkpoints are detected - state, _ = restore_checkpoint(checkpoint_meta_dir, state, config.device) - initial_epoch = ( - int(state["epoch"]) + 1 - ) # start from the epoch after the one currently reached - - initial_epoch = ( - int(state["epoch"]) + 1 - ) # start from the epoch after the one currently reached - # step = state["step"] - - def loss_fn(model, batch, cond): - return criterion(model(cond), batch) - - def optimize_fn(optimizer, params, step, lr, warmup=5000, grad_clip=1.0): - """Optimizes with warmup and gradient clipping (disabled if negative).""" - if warmup > 0: - for g in optimizer.param_groups: - g["lr"] = lr * np.minimum(step / warmup, 1.0) - if grad_clip >= 0: - torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) - optimizer.step() - - # Compute validation loss - def eval_step_fn(state, batch, cond): - """Running one step of training or evaluation. - - Args: - state: A dictionary of training information, containing the score model, optimizer, - EMA status, and number of optimization steps. - batch: A mini-batch of training/evaluation data to model. - cond: A mini-batch of conditioning inputs. - - Returns: - loss: The average loss value of this state. - """ - model = state["model"] - with torch.no_grad(): - loss = loss_fn(model, batch, cond) - - return loss - - def train_step_fn(state, batch, cond): - """Running one step of training or evaluation. - - Args: - state: A dictionary of training information, containing the score model, optimizer, - EMA status, and number of optimization steps. - batch: A mini-batch of training/evaluation data to model. - cond: A mini-batch of conditioning inputs. - - Returns: - loss: The average loss value of this state. - """ - model = state["model"] - optimizer = state["optimizer"] - optimizer.zero_grad() - loss = loss_fn(model, batch, cond) - loss.backward() - optimize_fn( - optimizer, model.parameters(), step=state["step"], lr=config.optim.lr - ) - state["step"] += 1 - - return loss - - # save the config - config_path = os.path.join(workdir, "config.yml") - with open(config_path, "w") as f: - yaml.dump(config, f) - - run_name = os.path.basename(workdir) - run_config = dict( - dataset=config.data.dataset_name, - input_transform_key=config.data.input_transform_key, - target_transform_key=config.data.target_transform_key, - architecture=config.model.name, - name=run_name, - loss=config.model.loss, - time_inputs=config.data.time_inputs, - ) - - with track_run( - EXPERIMENT_NAME, run_name, run_config, [config.model.name, "baseline"], tb_dir - ) as (wandb_run, tb_writer): - # Fit model - - logging.info("Starting training loop at epoch %d." % (initial_epoch,)) - - for epoch in range(initial_epoch, config.training.n_epochs + 1): - state["epoch"] = epoch - # Update model based on training data - model.train() - - train_set_loss = 0.0 - with logging_redirect_tqdm(): - with tqdm( - total=len(train_dl.dataset), - desc=f"Epoch {state['epoch']}", - unit=" timesteps", - ) as pbar: - for (cond_batch, x_batch, time_batch) in train_dl: - cond_batch = cond_batch.to(config.device) - x_batch = x_batch.to(config.device) - - train_batch_loss = train_step_fn(state, x_batch, cond_batch) - train_set_loss += train_batch_loss.item() - - # Log progress so far on epoch - pbar.update(cond_batch.shape[0]) - - train_set_loss = train_set_loss / len(train_dl) - - # Save a temporary checkpoint to resume training after each epoch - save_checkpoint(checkpoint_meta_dir, state) - - # Report the loss on an validation dataset each epoch - model.eval() - val_set_loss = val_loss(config, val_dl, eval_step_fn, state) - epoch_metrics = { - "epoch/train/loss": train_set_loss, - "epoch/val/loss": val_set_loss, - } - log_epoch(state["epoch"], epoch_metrics, wandb_run, tb_writer) - # Checkpoint model - if ( - state["epoch"] != 0 - and state["epoch"] % config.training.snapshot_freq == 0 - ) or state["epoch"] == config.training.n_epochs: - checkpoint_path = os.path.join( - checkpoint_dir, f"epoch_{state['epoch']}.pth" - ) - save_checkpoint(checkpoint_path, state) - logging.info( - f"epoch: {state['epoch']}, checkpoint saved to {checkpoint_path}" - ) - - logging.info(f"Finished {os.path.basename(__file__)}") diff --git a/src/ml_downscaling_emulator/deterministic/sampling.py b/src/ml_downscaling_emulator/deterministic/sampling.py deleted file mode 100644 index f5737719..00000000 --- a/src/ml_downscaling_emulator/deterministic/sampling.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -import torch -from tqdm import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm -import xarray as xr - - -def generate_np_samples(model, cond_batch): - model.eval() - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - cond_batch = cond_batch.to(device) - - samples = model(cond_batch) - # drop the feature channel dimension (only have target pr as output) - samples = samples.squeeze(dim=1) - # extract numpy array - samples = samples.cpu().detach().numpy() - return samples - - -def np_samples_to_xr(np_samples, coords, target_transform, cf_data_vars): - coords = {**dict(coords)} - - pred_pr_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"] - pred_pr_attrs = { - "grid_mapping": "rotated_latitude_longitude", - "standard_name": "pred_pr", - "units": "kg m-2 s-1", - } - pred_pr_var = (pred_pr_dims, np_samples, pred_pr_attrs) - - data_vars = {**cf_data_vars, "target_pr": pred_pr_var} - - pred_ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs={}) - - if target_transform is not None: - pred_ds = target_transform.invert(pred_ds) - - pred_ds = pred_ds.rename({"target_pr": "pred_pr"}) - - return pred_ds - - -def sample(model, eval_dl, target_transform): - cf_data_vars = { - key: eval_dl.dataset.ds.data_vars[key] - for key in [ - "rotated_latitude_longitude", - "time_bnds", - "grid_latitude_bnds", - "grid_longitude_bnds", - ] - } - preds = [] - with logging_redirect_tqdm(): - with tqdm( - total=len(eval_dl.dataset), desc=f"Sampling", unit=" timesteps" - ) as pbar: - with torch.no_grad(): - for cond_batch, _, time_batch in eval_dl: - coords = eval_dl.dataset.ds.sel(time=time_batch).coords - batch_np_samples = generate_np_samples(model, cond_batch) - # add ensemble member axis to np samples - batch_np_samples = batch_np_samples[np.newaxis, :] - xr_samples = np_samples_to_xr( - batch_np_samples, coords, target_transform, cf_data_vars - ) - preds.append(xr_samples) - - pbar.update(cond_batch.shape[0]) - - ds = xr.combine_by_coords( - preds, - compat="no_conflicts", - combine_attrs="drop_conflicts", - coords="all", - join="inner", - data_vars="all", - ) - - return ds - - -def sample_id(variable: str, eval_ds: xr.Dataset) -> xr.Dataset: - """Create a Dataset of pr samples set to the values the given variable from the dataset.""" - cf_data_vars = { - key: eval_ds.data_vars[key] - for key in [ - "rotated_latitude_longitude", - "time_bnds", - "grid_latitude_bnds", - "grid_longitude_bnds", - ] - if key in eval_ds.variables - } - coords = eval_ds.coords - np_samples = eval_ds[variable].data - xr_samples = np_samples_to_xr( - np_samples, coords=coords, target_transform=None, cf_data_vars=cf_data_vars - ) - - return xr_samples diff --git a/src/ml_downscaling_emulator/deterministic/utils.py b/src/ml_downscaling_emulator/deterministic/utils.py deleted file mode 100644 index 5ee423c6..00000000 --- a/src/ml_downscaling_emulator/deterministic/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -import os -import torch.nn as nn - -from ..unet import unet - - -def create_model(config, num_predictors): - if config.model.name == "u-net": - return unet.UNet(num_predictors, 1) - if config.model.name == "debug": - return nn.Conv2d(num_predictors, 1, 3, stride=1, padding=1) - raise NotImplementedError(f"Model {config.model.name} not supported yet!") - - -def restore_checkpoint(ckpt_dir, state, device): - import torch - - if not os.path.exists(ckpt_dir): - os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True) - logging.warning( - f"No checkpoint found at {ckpt_dir}." f"Returned the same state as input" - ) - return state, False - else: - loaded_state = torch.load(ckpt_dir, map_location=device) - state["optimizer"].load_state_dict(loaded_state["optimizer"]) - state["model"].load_state_dict(loaded_state["model"], strict=False) - state["step"] = loaded_state["step"] - state["epoch"] = loaded_state["epoch"] - logging.info( - f"Checkpoint found at {ckpt_dir}. " - f"Returned the state from {state['epoch']}/{state['step']}" - ) - return state, True - - -def save_checkpoint(ckpt_dir, state): - import torch - - saved_state = { - "optimizer": state["optimizer"].state_dict(), - "model": state["model"].state_dict(), - "step": state["step"], - "epoch": state["epoch"], - } - torch.save(saved_state, ckpt_dir)