Iluise/head #24

Merged (7 commits, Aug 12, 2024)
10 changes: 6 additions & 4 deletions atmorep/core/atmorep_model.py
@@ -186,16 +186,18 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non
self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train,
cf.batch_size,
pre_batch, cf.n_size, cf.num_samples_per_epoch,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True )
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params,
sampler = None)

self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val,
cf.batch_size_validation,
pre_batch, cf.n_size, cf.num_samples_validate,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True )
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params,
sampler = None)
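Note: the new compute_weights keyword is derived directly from the configured losses, so per-sample weights are only built when a weighted MSE loss is actually requested. A minimal sketch of the flag derivation (not part of the PR; the inline expression in the diff and the membership test below are equivalent):

    losses = ['mse_ensemble', 'stats', 'weighted_mse']   # example loss configuration
    compute_weights = losses.count('weighted_mse') > 0   # expression used in the diff -> True
    compute_weights = 'weighted_mse' in losses           # equivalent, arguably clearer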

15 changes: 7 additions & 8 deletions atmorep/core/evaluator.py
@@ -55,11 +55,6 @@ def parse_args( cf, args) :
@staticmethod
def run( cf, model_id, model_epoch, devices) :

if not hasattr(cf, 'batch_size'):
cf.batch_size = cf.batch_size_max
if not hasattr(cf, 'batch_size_validation'):
cf.batch_size_validation = cf.batch_size_max

cf.with_mixed_precision = True

# set/over-write options as desired
@@ -82,7 +77,7 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) :
else :
num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] ))
devices = init_torch( num_accs_per_task)
#devices = ['cuda']
#devices = ['cuda:1']

par_rank, par_size = setup_ddp( with_ddp)
cf = Config().load_json( model_id)
@@ -112,6 +107,12 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) :
cf.with_mixed_precision = False
if not hasattr(cf, 'with_pytest'):
cf.with_pytest = False
if not hasattr(cf, 'batch_size'):
cf.batch_size = cf.batch_size_max
if not hasattr(cf, 'batch_size_validation'):
cf.batch_size_validation = cf.batch_size_max
if not hasattr(cf, 'years_val'):
cf.years_val = cf.years_test

func = getattr( Evaluator, mode)
func( cf, model_id, model_epoch, devices, args)
@@ -159,8 +160,6 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) :
cf.batch_size = 196 #14
if not hasattr(cf, 'batch_size_validation'):
cf.batch_size_validation = 1 #64
if not hasattr(cf, 'batch_size_delta'):
cf.batch_size_delta = 8
if not hasattr(cf, 'num_samples_validate'):
cf.num_samples_validate = 196
#if not hasattr(cf,'with_mixed_precision'):
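The hasattr guards added to evaluate() above (batch_size, batch_size_validation, years_val) let older run configurations, which predate these attributes, load cleanly. A small hypothetical helper expressing the same backward-compatibility pattern (not part of this PR; Config is assumed to behave like a plain attribute container):

    def set_default(cf, name, value):
        # Fill an attribute only if an older config does not define it yet.
        if not hasattr(cf, name):
            setattr(cf, name, value)

    # Usage mirroring the checks added in evaluate():
    # set_default(cf, 'batch_size', cf.batch_size_max)
    # set_default(cf, 'batch_size_validation', cf.batch_size_max)
    # set_default(cf, 'years_val', cf.years_test)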
60 changes: 30 additions & 30 deletions atmorep/core/train.py
@@ -15,21 +15,17 @@
####################################################################################################

import torch
import numpy as np
import os
import sys
import pdb
import traceback

import pdb
import wandb

import atmorep.config.config as config
from atmorep.core.trainer import Trainer_BERT
from atmorep.utils.utils import Config
from atmorep.utils.utils import setup_ddp
from atmorep.utils.utils import setup_wandb
from atmorep.utils.utils import init_torch
import atmorep.utils.utils as utils


####################################################################################################
@@ -110,40 +106,43 @@ def train() :
# [ total masking rate, rate masking, rate noising, rate for multi-res distortion]
# ]

cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ],
[ 96, 105, 114, 123, 137 ],
[12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]
# cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ]
# cf.fields_prediction = [ [cf.fields[0][0], 1.] ]

cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ],
[ 96, 105, 114, 123, 137 ],
[12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ]

cf.fields_prediction = [ [cf.fields[0][0], 1.] ]

# cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ] ]

# cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ],

# cf.fields = [ [ 'velocity_v', [ 1, 1024, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]
# [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ]

# cf.fields = [ [ 'velocity_z', [ 1, 1024, [ ], 0 ],
# cf.fields = [ [ 'velocity_z', [ 1, 512, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]

# cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ],
# cf.fields = [ [ 'specific_humidity', [ 1, 1024, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ]
# [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ]

cf.fields_targets = []

cf.years_train = [2021] # list( range( 1980, 2018))
cf.years_train = list( range( 2010, 2021))
cf.years_val = [2021] #[2018]
cf.month = None
cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]]
cf.time_sampling = 1 # sampling rate for time steps
# random seeds
cf.torch_seed = torch.initial_seed()
# training params
cf.batch_size_validation = 64
cf.batch_size = 32
cf.batch_size_validation = 1 #64
cf.batch_size = 96
cf.num_epochs = 128
cf.num_samples_per_epoch = 4096*12
cf.num_samples_validate = 128*12
cf.dropout_rate = 0.05
cf.with_qk_lnorm = False
# encoder
cf.encoder_num_layers = 4
cf.encoder_num_layers = 6
cf.encoder_num_heads = 16
cf.encoder_num_mlp_layers = 2
cf.encoder_att_type = 'dense'
# decoder
cf.decoder_num_layers = 4
cf.decoder_num_layers = 6
cf.decoder_num_heads = 16
cf.decoder_num_mlp_layers = 2
cf.decoder_self_att = False
cf.net_tail_num_nets = 16
cf.net_tail_num_layers = 0
# loss
cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps
cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps, weighted_mse
# training
cf.optimizer_zero = False
cf.lr_start = 5. * 10e-7
cf.lr_max = 0.00005
cf.lr_min = 0.00002
cf.weight_decay = 0.1
cf.lr_max = 0.00005*3
cf.lr_min = 0.00004 #0.00002
cf.weight_decay = 0.05 #0.1
cf.lr_decay_rate = 1.025
cf.lr_start_epochs = 3
# BERT
# strategies: 'BERT', 'forecast', 'temporal_interpolation'
cf.BERT_strategy = 'BERT'
cf.forecast_num_tokens = 1 # only needed / used for BERT_strategy 'forecast
cf.forecast_num_tokens = 2 # only needed / used for BERT_strategy 'forecast
cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields
# (fields need to have same BERT params for this to have effect)
cf.BERT_mr_max = 2 # maximum reduction rate for resolution
@@ -219,12 +218,13 @@ def train() :
# # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr'
# # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr'
# # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr'
cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/'
# # # in steps x lat_degrees x lon_degrees
# cf.n_size = [36, 0.25*9*6, 0.25*9*12]
cf.n_size = [36, 0.25*9*6, 0.25*9*12]

# cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr'
cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr'
cf.n_size = [36, 1*9*6, 1.*9*12]
#cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr'
#cf.n_size = [36, 1*9*6, 1.*9*12]

if cf.with_wandb and 0 == cf.par_rank :
cf.write_json( wandb)
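With 'weighted_mse' now listed among the loss options and the samplers returning per-latitude weights, the loss side is expected to weight squared errors by those values. The trainer implementation is not part of this diff; a minimal sketch of a weighted MSE, assuming a weight tensor broadcastable to the prediction, could look like:

    import torch

    def weighted_mse(pred, target, weights):
        # weights: e.g. per-latitude weights broadcastable to pred; normalizing by
        # their sum keeps the magnitude comparable to an unweighted MSE.
        err = (pred - target) ** 2
        return (weights * err).sum() / weights.sum()

    # Toy usage: 4 samples x 8 latitude rows with linearly varying weights.
    pred, target = torch.randn(4, 8), torch.randn(4, 8)
    weights = torch.linspace(0.1, 1.0, steps=8).expand(4, 8)
    loss = weighted_mse(pred, target, weights)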
47 changes: 24 additions & 23 deletions atmorep/datasets/multifield_data_sampler.py
@@ -32,7 +32,7 @@ class MultifieldDataSampler( torch.utils.data.IterableDataset):

###################################################
def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size,
num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False,
num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False, compute_weights = False,
fields_targets = None, pre_batch_targets = None ) :
'''
Data set for single dynamic field at an arbitrary number of vertical levels
Expand All @@ -46,6 +46,7 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size,
self.n_size = n_size
self.num_samples = num_samples
self.with_source_idxs = with_source_idxs
self.compute_weights = compute_weights
self.with_shuffle = with_shuffle
self.pre_batch = pre_batch

@@ -185,11 +186,11 @@ def __iter__(self):

source_data, tok_info = [], []
# extract data, normalize and tokenize
cdata = data_t[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]
cdata = data_t[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]

normalizer = self.normalizers[ifield][ilevel]
if corr_type != 'global':
normalizer = normalizer[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]
if corr_type != 'global':
normalizer = normalizer[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]
cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base)

source_data = tokenize( torch.from_numpy( cdata), tok_size )
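The indexing change above, from data_t[ :, ...] to data_t[ ..., ...], decouples the lat/lon selection from the number of leading axes (e.g. time and vertical level), since Ellipsis expands to however many leading dimensions are present. A small NumPy illustration (shapes are made up for the example):

    import numpy as np

    data = np.zeros((6, 5, 64, 128))                 # e.g. (time, level, lat, lon)
    lat_ran, lon_ran = np.arange(10, 13), np.arange(20, 22)

    # '...' always targets the trailing lat/lon axes, regardless of leading dims:
    sub = data[..., lat_ran[:, np.newaxis], lon_ran[np.newaxis, :]]
    print(sub.shape)                                 # (6, 5, 3, 2)

    # With ':' the lat indices would land on the level axis instead
    # (and here would even raise IndexError, since axis 1 only has length 5):
    # data[:, lat_ran[:, np.newaxis], lon_ran[np.newaxis, :]]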
@@ -217,29 +218,29 @@ def __iter__(self):

tmidx_list = sources[-1]
weights_idx_list = []
if self.compute_weights:
for ifield, field_info in enumerate(self.fields):
weights = []
for ilevel, vl in enumerate(field_info[2]):
for ibatch in range(self.batch_size):

lats_idx = source_idxs[ibatch][1]
lons_idx = source_idxs[ibatch][2]

for ifield, field_info in enumerate(self.fields):
weights = []
for ilevel, vl in enumerate(field_info[2]):
for ibatch in range(self.batch_size):

lats_idx = source_idxs[ibatch][1]
lons_idx = source_idxs[ibatch][2]

idx_base = tmidx_list[ifield][ilevel][ibatch]
idx_loc = idx_base - np.prod(num_tokens) * ibatch

grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1
grid = torch.from_numpy( np.array( np.broadcast_to( grid,
shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1))
idx_base = tmidx_list[ifield][ilevel][ibatch]
idx_loc = idx_base - np.prod(num_tokens) * ibatch

grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1
grid = torch.from_numpy( np.array( np.broadcast_to( grid,
shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1))

grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2)
grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2)

lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()])
lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()])

weights.append([get_weights(la) for la in lats_mskd_b])
weights.append([get_weights(la) for la in lats_mskd_b])

weights_idx_list.append(weights)
weights_idx_list.append(weights)
sources = (*sources, weights_idx_list)

# TODO: implement (only required when prediction target comes from different data stream)
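For reference on the weights themselves: the loop above recovers, per masked token, the latitude rows it covers (lats_mskd_b) and maps each through get_weights. get_weights is defined elsewhere and not shown in this diff; a common choice for such latitude weights is cosine-of-latitude area weighting, sketched below under that assumption (the 0.25 degree grid spacing and the index-to-degree conversion are illustrative only):

    import numpy as np

    def get_weights_sketch(lat_idxs, res_deg=0.25, lat0_deg=90.0):
        # Hypothetical stand-in for get_weights: convert grid row indices to degrees
        # latitude, then use cos(lat) area weights, clipped so poles stay non-zero.
        lats_deg = lat0_deg - np.asarray(lat_idxs) * res_deg
        return np.clip(np.cos(np.deg2rad(lats_deg)), 1e-4, None)

    # Toy usage: one masked token covering three latitude rows (pole, equator, pole).
    print(get_weights_sketch([0, 360, 720]))          # ~[1e-4, 1.0, 1e-4]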