23 plot a validation sample after long rollout (ecmwf#26)
* feature: long rollout plots

* incorporate review from Lorenzo; correct ocean variables so they are not treated as pressure-level variables; small fix for grouping of fewer than 15 variables in the loss-contribution histogram

* backward compatibility for config files without longrolloutplot configuration

Reviewers: @lzampier, @theissenhelen, @JesperDramsch
sahahner authored Sep 23, 2024
1 parent e9c5b55 commit a599dfd
Showing 4 changed files with 143 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -25,6 +25,7 @@ Keep it human-readable, your future self will thank you!
- Fix: Inference checkpoints are now saved according to the frequency settings defined in the config [#37](https://github.com/ecmwf/anemoi-training/pull/37)
- Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50)
- Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48)
- Feature: Long rollout plots

### Fixed

@@ -47,7 +48,6 @@ Keep it human-readable, your future self will thank you!
- Subcommand for checkpoint handling

#### Functionality

- Searchpaths for Hydra configs, to enable configs in CWD, `ANEMOI_CONFIG_PATH` env, and `.config/anemoi/training` in addition to package defaults
- MlFlow token authentication
- Configurable pressure level scaling
4 changes: 4 additions & 0 deletions src/anemoi/training/config/diagnostics/eval_rollout.yaml
@@ -44,6 +44,10 @@ plot:
moisture: [tp, cp, tcw]
sfc_wind: [10u, 10v]
learned_features: False
longrollout:
enabled: False
rollout: [60]
frequency: 20 # every X epochs

debug:
# this will detect and trace back NaNs / Infs etc. but will slow down training
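For readers of the config change above, a minimal sketch (hypothetical values mirroring the defaults added here; the `LongRolloutPlots` callback further down is the actual implementation) of how the two new keys interact: `rollout` lists the rollout steps at which a validation sample is plotted, and `frequency` limits plotting to every N-th epoch.

```python
# Sketch only: hypothetical config values mirroring the YAML defaults above.
longrollout = {"enabled": True, "rollout": [60], "frequency": 20}

def should_plot(epoch: int, rollout_step: int) -> bool:
    """Plot a sample only at the configured rollout steps and every `frequency`-th epoch."""
    return (
        longrollout["enabled"]
        and (epoch + 1) % longrollout["frequency"] == 0
        and (rollout_step + 1) in longrollout["rollout"]
    )

print(should_plot(epoch=19, rollout_step=59))  # True: 20th epoch, rollout step 60
print(should_plot(epoch=5, rollout_step=59))   # False: not a multiple of 20 epochs
```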
2 changes: 2 additions & 0 deletions src/anemoi/training/data/datamodule.py
@@ -125,6 +125,8 @@ def ds_valid(self) -> NativeGridDataset:
r = self.rollout
if self.config.diagnostics.eval.enabled:
r = max(r, self.config.diagnostics.eval.rollout)
if self.config.diagnostics.plot.get("longrollout") and self.config.diagnostics.plot.longrollout.enabled:
r = max(r, max(self.config.diagnostics.plot.longrollout.rollout))
assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
f"validation start date {self.config.dataloader.validation.start}"
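A small worked example of the guard above, with hypothetical numbers standing in for the Hydra config: the validation rollout length becomes the maximum of the training rollout, the eval rollout, and the longest `longrollout.rollout` entry, while `.get("longrollout")` keeps older configs without the new block working.

```python
# Hypothetical values standing in for the config entries read in ds_valid().
training_rollout = 1                                            # self.rollout
eval_rollout = 12                                               # diagnostics.eval.rollout
plot_cfg = {"longrollout": {"enabled": True, "rollout": [60]}}  # may be missing in older configs

r = training_rollout
r = max(r, eval_rollout)
if plot_cfg.get("longrollout") and plot_cfg["longrollout"]["enabled"]:
    r = max(r, max(plot_cfg["longrollout"]["rollout"]))

print(r)  # 60 -> the validation set must provide enough time steps for the longest plotted rollout
```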
139 changes: 136 additions & 3 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -279,6 +279,137 @@ def on_validation_batch_end(
self._eval(pl_module, batch)


class LongRolloutPlots(BasePlotCallback):
"""Evaluates the model performance over a (longer) rollout window."""

def __init__(self, config) -> None:
"""Initialize RolloutEval callback.
Parameters
----------
config : dict
Dictionary with configuration settings
"""
super().__init__(config)

LOGGER.debug(
"Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...",
config.diagnostics.plot.longrollout.rollout,
config.diagnostics.plot.longrollout.frequency,
)
self.rollout = config.diagnostics.plot.longrollout.rollout
self.eval_frequency = config.diagnostics.plot.longrollout.frequency
self.sample_idx = self.config.diagnostics.plot.sample_idx

@rank_zero_only
def _plot(
self,
trainer,
pl_module: pl.LightningModule,
batch: torch.Tensor,
batch_idx,
epoch,
) -> None:

start_time = time.time()

logger = trainer.logger

# Build dictionary of indices and parameters to be plotted
plot_parameters_dict = {
pl_module.data_indices.model.output.name_to_index[name]: (
name,
name not in self.config.data.get("diagnostic", []),
)
for name in self.config.diagnostics.plot.parameters
}

if self.post_processors is None:
# Copy to be used across the entire training cycle
self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
if self.latlons is None:
self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
local_rank = pl_module.local_rank

batch = pl_module.model.pre_processors(batch, in_place=False)
# prepare input tensor for rollout from preprocessed batch
x = batch[
:,
0 : pl_module.multi_step,
...,
pl_module.data_indices.internal_data.input.full,
] # (bs, multi_step, latlon, nvar)
assert (
batch.shape[1] >= max(self.rollout) + pl_module.multi_step
), "Batch length not sufficient for requested rollout length!"

# prepare input tensor for plotting
input_tensor_0 = batch[
self.sample_idx,
pl_module.multi_step - 1,
...,
pl_module.data_indices.internal_data.output.full,
].cpu()
data_0 = self.post_processors(input_tensor_0).numpy()

# start rollout
with torch.no_grad():
for rollout_step in range(max(self.rollout)):
y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar)

x = pl_module.advance_input(x, y_pred, batch, rollout_step)

if (rollout_step + 1) in self.rollout:
# prepare true output tensor for plotting
input_tensor_rollout_step = batch[
self.sample_idx,
pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1)
...,
pl_module.data_indices.internal_data.output.full,
].cpu()
data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy()

# prepare predicted output tensor for plotting
output_tensor = self.post_processors(
y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()
).numpy()

fig = plot_predicted_multilevel_flat_sample(
plot_parameters_dict,
self.config.diagnostics.plot.per_sample,
self.latlons,
self.config.diagnostics.plot.get("accumulation_levels_plot", None),
self.config.diagnostics.plot.get("cmap_accumulation", None),
data_0.squeeze(),
data_rollout_step.squeeze(),
output_tensor[0, 0, :, :], # rollout step, first member
# force_global_view=self.show_entire_globe,
)

self._output_figure(
logger,
fig,
epoch=epoch,
tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0",
exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}",
)
LOGGER.info(f"Time taken to plot samples after longer rollout: {int(time.time() - start_time)} seconds")

@rank_zero_only
def on_validation_batch_end(self, trainer, pl_module, output, batch, batch_idx) -> None:
if batch_idx % self.plot_frequency == 0 and (trainer.current_epoch + 1) % self.eval_frequency == 0:
precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,
}
prec = trainer.precision
dtype = precision_mapping.get(prec)
context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext()

with context:
self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch)


class GraphTrainableFeaturesPlot(BasePlotCallback):
"""Visualize the trainable features defined at the data and hidden graph nodes.
@@ -420,15 +551,15 @@ def automatically_determine_group(name: str) -> str:
return_counts=True,
)

LOGGER.info("Order of parameters in loss histogram: %s", sorted_parameter_names)

# get a color per group and project to parameter list
cmap = "tab10" if len(unique_group_list) <= 10 else "tab20"
if len(unique_group_list) > 20:
LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.")
# if all groups have count 1 use black color
bar_color_per_group = (
"k" if not np.any(group_counts - 1) else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list)))
np.tile("k", len(group_counts))
if not np.any(group_counts - 1)
else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list)))
)

# set x-ticks
@@ -980,6 +1111,8 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901
)
if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None:
trainer_callbacks.extend([PlotAdditionalMetrics(config)])
if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled:
trainer_callbacks.extend([LongRolloutPlots(config)])

if config.training.swa.enabled:
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
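One detail of the callback worth spelling out: the assertion `batch.shape[1] >= max(self.rollout) + pl_module.multi_step` ties the validation batch length to the longest plotted rollout. A sketch with hypothetical numbers:

```python
# Hypothetical shapes illustrating the batch-length assertion in LongRolloutPlots._plot.
multi_step = 2        # number of input time steps fed to the model
longest_rollout = 60  # max(self.rollout)

required_steps = longest_rollout + multi_step
batch_time_steps = 62  # batch.shape[1] as provided by the (extended) validation dataloader

assert batch_time_steps >= required_steps, "Batch length not sufficient for requested rollout length!"
print(required_steps)  # 62: two input steps plus sixty prediction steps to verify against
```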
