From ed65f3983231398d00873d0708c83cfe44670e6f Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 8 Mar 2024 19:27:52 +0100 Subject: [PATCH 01/66] - Added support for separate data directories for separate datasets. - Fixed two smaller issues (issues #8 and #9) --- atmorep/config/config.py | 25 +++++++++++++++++-- atmorep/core/evaluate.py | 2 +- atmorep/datasets/data_loader.py | 31 +++++++++++++----------- atmorep/datasets/dynamic_field_level.py | 2 +- atmorep/datasets/normalizer_global.py | 4 ++-- atmorep/datasets/normalizer_local.py | 32 +++++++++++++++++++------ atmorep/utils/utils.py | 12 ++++++---- 7 files changed, 77 insertions(+), 31 deletions(-) diff --git a/atmorep/config/config.py b/atmorep/config/config.py index 53755f2..d582d30 100644 --- a/atmorep/config/config.py +++ b/atmorep/config/config.py @@ -4,10 +4,10 @@ fpath = os.path.dirname(os.path.realpath(__file__)) year_base = 1979 -year_last = 2022 +year_last = 2021 path_models = Path( fpath, '../../models/') -path_results = Path( fpath, '../../results/') +path_results = Path( fpath, '../../results') path_data = Path( fpath, '../../data/') path_plots = Path( fpath, '../results/plots/') @@ -17,3 +17,24 @@ 'velocity_u' : 'u', 'velocity_v': 'v', 'velocity_z' : 'w', 'total_precip' : 'tp', 'radar_precip' : 'yw_hourly', 't2m' : 't_2m', 'u_10m' : 'u_10m', 'v_10m' : 'v_10m', } + +# TODO: extract this info from the datasets +datasets = {} +# +datasets['era5'] = {} +datasets['era5']['resolution'] = [1, 0.25, 0.25] +datasets['era5']['extent'] = [ [1979, 2022], [90., -90], [0.0, 360] ] +datasets['era5']['is_global'] = True +datasets['era5']['file_size'] = [ -1, 721, 1440] +# +datasets['cosmo_rea6'] = {} +datasets['cosmo_rea6']['resolution'] = [1, 0.0625, 0.0625] +datasets['cosmo_rea6']['extent'] = [ [1997, 2017], [27.5,70.25], [-12.5,37.0] ] +datasets['cosmo_rea6']['is_global'] = False +datasets['cosmo_rea6']['file_size'] = [ -1, 685, 793] +# +datasets['cerra'] = {} +datasets['cerra']['resolution'] = [3, 0.25, 0.25] +datasets['cerra']['extent'] = [ [1985, 2001], [75.25,20.5], [-58.0,74.0] ] +datasets['cerra']['is_global'] = False +datasets['cerra']['file_size'] = [ -1, 220, 529] diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 3033040..b8178f2 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -30,7 +30,7 @@ # model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence - # model_id = '1jh2qvrx' # multiformer, velocity + model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity # model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting diff --git a/atmorep/datasets/data_loader.py b/atmorep/datasets/data_loader.py index 4ddeae3..9c2638f 100644 --- a/atmorep/datasets/data_loader.py +++ b/atmorep/datasets/data_loader.py @@ -20,6 +20,7 @@ import xarray as xr from functools import partial +import atmorep.config.config as config import atmorep.utils.utils as utils from atmorep.config.config import year_base from atmorep.utils.utils import tokenize @@ -30,19 +31,22 @@ class DataLoader: - def __init__(self, path, file_shape, data_type = 'reanalysis', + def __init__(self, path, file_shape, data_type, field_info, file_format = 'grib', level_type = 'pl', - fname_base = '{}/{}/{}{}/{}_{}_y{}_m{}_{}{}', + fname_base = '{}/{}/{}/{}{}/{}_{}_y{}_m{}_{}{}', smoothing = 0, - log_transform = False): + log_transform = False, + partial_load = 0): self.path = path 
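    # Note on the new path layout: the extra '{}' segment in fname_base inserts
    # self.data_type, so every dataset resolves to its own directory. A rough,
    # hypothetical worked example (illustrative values only; the appended file
    # extension is assumed to be '.grib'):
    #   base = '{}/{}/{}/{}{}/{}_{}_y{}_m{}_{}{}' + '.grib'
    #   base.format('/data', 'era5', 'temperature', 'ml', 137,
    #               'era5', 'temperature', '{}', '{}', '{}', '{}')
    #   -> '/data/era5/temperature/ml137/era5_temperature_y{}_m{}_{}{}.grib'
    #   get_field() later fills the remaining slots, e.g. format(2021, '03', 'ml', 137)
    #   -> '/data/era5/temperature/ml137/era5_temperature_y2021_m03_ml137.grib'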
self.data_type = data_type + self.field_info = field_info self.file_format = file_format self.file_shape = file_shape self.fname_base = fname_base self.smoothing = smoothing self.log_transform = log_transform + self.partial_load = partial_load if 'grib' == file_format : self.file_ext = '.grib' @@ -58,16 +62,11 @@ def __init__(self, path, file_shape, data_type = 'reanalysis', self.file_loader = netcdf_file_loader else : raise ValueError('Unsupported file format.') - - self.fname_base = fname_base + self.file_ext - self.grib_index = { 'vorticity' : 'vo', 'divergence' : 'd', 'geopotential' : 'z', - 'orography' : 'z', 'temperature': 't', 'specific_humidity' : 'q', - 'mean_top_net_long_wave_radiation_flux' : 'mtnlwrf', - 'velocity_u' : 'u', 'velocity_v': 'v', 'velocity_z' : 'w', - 'total_precip' : 'tp', 'radar_precip' : 'yw_hourly', - 't2m' : 't_2m', 'u_10m' : 'u_10m', 'v_10m' : 'v_10m', } + self.fname_base = fname_base + self.file_ext + self.grib_index = config.grib_index + def get_field( self, year, month, field, level_type, vl, token_size = [-1, -1], t_pad = [-1, -1, 1]): @@ -75,8 +74,7 @@ def get_field( self, year, month, field, level_type, vl, data_ym = torch.zeros( (0, self.file_shape[1], self.file_shape[2])) # pre-fill fixed values - # fname_base = self.fname_base.format( self.path, self.data_type, field, level_type, vl, - fname_base = self.fname_base.format( self.path, field, level_type, vl, + fname_base = self.fname_base.format( self.path, self.data_type, field, level_type, vl, self.data_type, field, {},{},{},{}) # padding pre @@ -98,7 +96,7 @@ def get_field( self, year, month, field, level_type, vl, # data fname = fname_base.format( year, str(month).zfill(2), level_type, vl) days_month = utils.days_in_month( year, month) - x = self.file_loader(fname, self.grib_index[field], [0, 0, t_srate], days_month) + x = self.file_loader(fname,self.grib_index[field], [0, self.partial_load, t_srate],days_month) data_ym = torch.cat((data_ym,x),0) @@ -137,7 +135,12 @@ def get_single_field( self, years_months, field = 'vorticity', level_type = 'pl' token_size = [-1, -1], t_pad = [-1, -1, 1]): data_field = [] + extent_t = config.datasets[self.data_type]['extent'][0] for year, month in years_months : + # skip loading when the year is not available for the dataset + if year < extent_t[0] or year > extent_t[1] : + data_field.append( []) + continue data_field.append( self.get_field( year, month, field, level_type, vl, token_size, t_pad)) return data_field diff --git a/atmorep/datasets/dynamic_field_level.py b/atmorep/datasets/dynamic_field_level.py index 600a181..0a55d99 100644 --- a/atmorep/datasets/dynamic_field_level.py +++ b/atmorep/datasets/dynamic_field_level.py @@ -95,7 +95,7 @@ def __init__( self, file_path, years_data, field_info, else : self.normalizer = NormalizerLocal( field_info, vl, self.file_shape, data_type) - self.loader = DataLoader( self.file_path, self.file_shape, data_type, + self.loader = DataLoader( self.file_path, self.file_shape, data_type, field_info, file_format = self.file_format, level_type = self.level_type, smoothing = self.smoothing, log_transform=self.log_transform_data) diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py index 2a844ce..8bf997a 100644 --- a/atmorep/datasets/normalizer_global.py +++ b/atmorep/datasets/normalizer_global.py @@ -23,10 +23,10 @@ class NormalizerGlobal() : def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : # TODO: use path from config and pathlib.Path() - fname_base = 
'{}/normalization/{}/global_normalization_mean_var_{}_{}{}.bin' + fname_base = '{}/{}/normalization/{}/global_normalization_mean_var_{}_{}{}.bin' fn = field_info[0] - corr_fname = fname_base.format( str(config.path_data), fn, fn, level_type, vlevel) + corr_fname = fname_base.format( str(config.path_data), data_type, fn, fn, level_type, vlevel) self.corr_data = np.fromfile(corr_fname, dtype=np.float32).reshape( (-1, 4)) def normalize( self, year, month, data, coords = None) : diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py index ffc8c37..7cac3ed 100644 --- a/atmorep/datasets/normalizer_local.py +++ b/atmorep/datasets/normalizer_local.py @@ -24,24 +24,42 @@ class NormalizerLocal() : def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : - fname_base = '{}/normalization/{}/normalization_mean_var_{}_y{}_m{:02d}_{}{}.bin' + fname_base = './data/{}/normalization/{}/normalization_mean_var_{}_y{}_m{:02d}_{}{}.bin' + self.year_base = config.datasets[data_type]['extent'][0][0] + self.year_last = config.datasets[data_type]['extent'][0][1] + lat_min, lat_max = config.datasets[data_type]['extent'][1] + lat_min, lat_max = 90. - lat_min, 90. - lat_max + lat_min, lat_max = (lat_min, lat_max) if lat_min < lat_max else (lat_max, lat_min) + lon_min, lon_max = config.datasets[data_type]['extent'][2] + res = config.datasets[data_type]['resolution'][1] + is_global = config.datasets[data_type]['is_global'] self.corr_data = [ ] - for year in range( config.year_base, config.year_last+1) : + for year in range( self.year_base, self.year_last+1) : for month in range( 1, 12+1) : - corr_fname = fname_base.format( str(config.path_data), field_info[0], field_info[0], + corr_fname = fname_base.format( data_type, field_info[0], field_info[0], year, month, level_type, vlevel) + ns_lat = int( (lat_max-lat_min) / res + 1) + ns_lon = int( (lon_max-lon_min) / res + (0 if is_global else 1) ) + # TODO: remove file_shape (ns_lat, ns_lon contains same information) x = np.fromfile( corr_fname, dtype=np.float32).reshape( (file_shape[1], file_shape[2], 2)) - x = xr.DataArray( x, [ ('lat', np.linspace( 0., 180., num=180*4+1, endpoint=True)), - ('lon', np.linspace( 0., 360., num=360*4, endpoint=False)), + # TODO, TODO, TODO: remove once recomputed + if 'cerra' == data_type : + x[:,:,0] = 340. + x[:,:,1] = 600. 
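        # The DataArray axes built below use "colatitude" (90. - lat) and longitude,
        # both at the dataset's native resolution. Worked check with the era5 entry
        # from config.datasets (lat extent [90, -90], lon extent [0, 360], res 0.25,
        # is_global):
        #   ns_lat = (180 - 0) / 0.25 + 1 = 721
        #   ns_lon = (360 - 0) / 0.25     = 1440   (no +1 since lon is periodic)
        # which matches the era5 file_size of [-1, 721, 1440].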
+ x = xr.DataArray( x, [ ('lat', np.linspace( lat_min, lat_max, num=ns_lat, endpoint=True)), + ('lon', np.linspace( lon_min, lon_max, num=ns_lon, endpoint=False)), ('data', ['mean', 'var']) ]) self.corr_data.append( x) def normalize( self, year, month, data, coords) : - corr_data_ym = self.corr_data[ (year - config.year_base) * 12 + month ] + corr_data_ym = self.corr_data[ (year - self.year_base) * 12 + (month-1) ] mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values + if (var == 0.).all() : + print( f'var == 0 :: ym : {year} / {month}') + assert False if len(data.shape) > 2 : for i in range( data.shape[0]) : @@ -53,7 +71,7 @@ def normalize( self, year, month, data, coords) : def denormalize( self, year, month, data, coords) : - corr_data_ym = self.corr_data[ (year - config.year_base) * 12 + month ] + corr_data_ym = self.corr_data[ (year - self.year_base) * 12 + (month-1) ] mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index a91ee0a..9fe5ae6 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -107,11 +107,15 @@ def load_json( self, wandb_id) : try : with open(fname, 'r') as f : json_str = f.readlines() - except IOError : + except (OSError, IOError) as e: # try path used for logging training results and checkpoints - fname = Path( config.path_results, '/models/id{}/model_id{}.json'.format( wandb_id, wandb_id)) - with open(fname, 'r') as f : - json_str = f.readlines() + try : + fname = Path( config.path_results, '/models/id{}/model_id{}.json'.format(wandb_id,wandb_id)) + with open(fname, 'r') as f : + json_str = f.readlines() + except (OSError, IOError) as e: + print( f'Could not find fname due to {e}. Aborting.') + quit() self.__dict__ = json.loads( json_str[0]) From 38ab84b59a3da2f5db2f99ba4ca813faa0f0d859 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 11 Mar 2024 13:47:08 +0100 Subject: [PATCH 02/66] - Switched to training data in zarr - Not fully functional yet, e.g. support for surface fields is missing or global forecast and similar things. - Added support for mixed precision training. 
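For reference, the mixed-precision path added in trainer.py follows the standard
torch.cuda.amp recipe: autocast around the forward pass and loss, GradScaler around
the backward pass and optimizer step. The sketch below is a minimal, self-contained
illustration of that pattern; the linear model, random batch and squared-error loss
are placeholders and not part of AtmoRep:

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    for _ in range(4):                                    # stand-in for the batch loop
        x = torch.randn(8, 16, device='cuda')             # stand-in for a data batch
        optimizer.zero_grad()
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            loss = (model(x) - x).pow(2).mean()           # stand-in for the training loss
        scaler.scale(loss).backward()                      # scale so fp16 grads do not underflow
        scaler.step(optimizer)                             # unscales grads, skips step on inf/nan
        scaler.update()                                    # adapt the loss scale for the next step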
--- atmorep/core/atmorep_model.py | 44 +- atmorep/core/train.py | 23 +- atmorep/core/trainer.py | 102 ++-- atmorep/datasets/multifield_data_sampler.py | 570 +++++--------------- atmorep/training/bert.py | 117 ++-- atmorep/utils/utils.py | 11 +- 6 files changed, 275 insertions(+), 592 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 3c4de6a..010919d 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -35,6 +35,7 @@ from atmorep.transformer.transformer_decoder import TransformerDecoder from atmorep.transformer.tail_ensemble import TailEnsemble + #################################################################################################### class AtmoRepData( torch.nn.Module) : @@ -77,8 +78,6 @@ def load_data( self, mode : NetMode, batch_size = -1, num_loader_workers = -1) : def _load_data( self, dataset, batch_size, num_loader_workers) : '''Private implementation for load''' - dataset.load_data( batch_size) - loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False, 'num_workers': num_loader_workers, 'pin_memory': True} data_loader = torch.utils.data.DataLoader( dataset, **loader_params, sampler = None) @@ -208,40 +207,14 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.pre_batch_targets = pre_batch_targets cf = self.net.cf - self.dataset_train = MultifieldDataSampler( cf.data_dir, cf.years_train, cf.fields, - batch_size = cf.batch_size_start, - num_t_samples = cf.num_t_samples, - num_patches_per_t = cf.num_patches_per_t_train, - num_load = cf.num_files_train, - pre_batch = self.pre_batch, - rng_seed = self.rng_seed, - file_shape = cf.file_shape, - smoothing = cf.data_smoothing, - level_type = cf.level_type, - file_format = cf.file_format, - month = cf.month, - time_sampling = cf.time_sampling, - geo_range = cf.geo_range_sampling, - fields_targets = cf.fields_targets, - pre_batch_targets = self.pre_batch_targets ) + self.dataset_train = MultifieldDataSampler( cf.fields, cf.levels, cf.years_train, + cf.batch_size_start, + pre_batch, cf.n_size, cf.num_samples_per_epoch ) - self.dataset_test = MultifieldDataSampler( cf.data_dir, cf.years_test, cf.fields, - batch_size = cf.batch_size_test, - num_t_samples = cf.num_t_samples, - num_patches_per_t = cf.num_patches_per_t_test, - num_load = cf.num_files_test, - pre_batch = self.pre_batch, - rng_seed = self.rng_seed, - file_shape = cf.file_shape, - smoothing = cf.data_smoothing, - level_type = cf.level_type, - file_format = cf.file_format, - month = cf.month, - time_sampling = cf.time_sampling, - geo_range = cf.geo_range_sampling, - lat_sampling_weighted = cf.lat_sampling_weighted, - fields_targets = cf.fields_targets, - pre_batch_targets = self.pre_batch_targets ) + self.dataset_test = MultifieldDataSampler( cf.fields, cf.levels, cf.years_test, + cf.batch_size_start, + pre_batch, cf.n_size, cf.num_samples_validate, + with_source_idxs = True ) return self @@ -261,7 +234,6 @@ def create( self, devices, load_pretrained=True) : cf = self.cf self.devices = devices - size_token_info = 6 self.fields_coupling_idx = [] self.fields_index = {} diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 51f290e..4ac97d7 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -136,9 +136,17 @@ def train() : # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ] # cf.fields_prediction = [ ['geopotential', 1.] ] - cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] + cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0], + [ 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], + [ 'velocity_v', [ 1, 1024, [ ], 0 ], + [ 114, 123, 137 ], + [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + cf.fields_prediction = [ [cf.fields[0][0], 0.5], [cf.fields[1][0], 0.5] ] + + cf.fields_targets = [] cf.years_train = [2021] # list( range( 1980, 2018)) cf.years_test = [2021] #[2018] @@ -229,6 +237,19 @@ def train() : cf.write_json( wandb) cf.print() + cf.levels = [114, 123, 137] + cf.with_mixed_precision = True + # cf.n_size = [36, 1*9*6, 1.*9*12] + # in steps x lat_degrees x lon_degrees + # cf.n_size = [36, 0.25*9*6, 0.25*9*12] + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + cf.num_samples_per_epoch = 1024 + cf.num_samples_validate = 128 + cf.num_loader_workers = 8 + + cf.years_train = [2021] # list( range( 1980, 2018)) + cf.years_test = [2021] #[2018] + trainer = Trainer_BERT( cf, device).create() trainer.run() diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 61f2574..e9b21ec 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -162,6 +162,8 @@ def run( self, epoch = -1) : self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=cf.lr_start, weight_decay=cf.weight_decay) + self.grad_scaler = torch.cuda.amp.GradScaler(enabled=cf.with_mixed_precision) + if 0 == cf.par_rank : # print( self.model.net) model_parameters = filter(lambda p: p.requires_grad, self.model_ddp.parameters()) @@ -240,17 +242,21 @@ def train( self, epoch): grad_loss_total = [] ctr = 0 + self.optimizer.zero_grad() + for batch_idx in range( model.len( NetMode.train)) : batch_data = self.model.next() - batch_data = self.prepare_batch( batch_data) - preds, _ = self.model_ddp( batch_data) + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=cf.with_mixed_precision): + batch_data = self.prepare_batch( batch_data) + preds, _ = self.model_ddp( batch_data) + loss, mse_loss, losses = self.loss( preds, batch_idx) - loss, mse_loss, losses = self.loss( preds, batch_idx) + self.grad_scaler.scale(loss).backward() + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() [loss_total[idx].append( losses[key]) for idx, key in enumerate(losses)] mse_loss_total.append( mse_loss.detach().cpu() ) @@ -371,21 +377,11 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): test_len = 0 self.mode_test = True - - # run in training mode - offset = 0 - if -1 == epoch and 0 == cf.par_rank : - if 1 == cf.num_accs_per_task : # bug in torchinfo; fixed in v1.8.0 - offset += 1 - print( 'Network size:') - batch_data = self.model.next() - batch_data = self.prepare_batch( batch_data) - torchinfo.summary( self.model, input_data=[batch_data]) # run test set evaluation with torch.no_grad() : - for it in range( self.model.len( NetMode.test) - offset) : + for it in range( self.model.len( NetMode.test)) : batch_data = self.model.next() if cf.par_rank < cf.log_test_num_ranks : @@ -393,19 +389,18 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] # targets if len(batch_data[1]) > 0 : - if type(batch_data[1][0][0]) is list : - targets = [batch_data[1][i][0][0] for i in range( len(batch_data[1]))] - else : - targets = batch_data[1][0] + targets = [] + for target_field in batch_data[1] : + targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) # store on cpu log_sources = ( 
[source.detach().clone().cpu() for source in sources ], [ti.detach().clone().cpu() for ti in token_infos], [target.detach().clone().cpu() for target in targets ], - tmis, tmis_list ) - - batch_data = self.prepare_batch( batch_data) + tmis, tmis_list ) - preds, atts = self.model( batch_data) + with torch.autocast(device_type='cuda',dtype=torch.float16,enabled=cf.with_mixed_precision): + batch_data = self.prepare_batch( batch_data) + preds, atts = self.model( batch_data) loss = torch.tensor( 0.) ifield = 0 @@ -492,6 +487,7 @@ def evaluate( self, data_idx = 0, log = True): for target_field in batch_data[1] : targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) # store on cpu + # TODO: is this still all needed with self.sources_idx log_sources = ( [source.detach().clone().cpu() for source in sources ], [ti.detach().clone().cpu() for ti in token_infos], [target.detach().clone().cpu() for target in targets ], @@ -624,6 +620,7 @@ def prepare_batch( self, xin) : # unpack loader output # xin[0] since BERT does not have targets (sources, token_infos, targets, fields_tokens_masked_idx,fields_tokens_masked_idx_list) = xin[0] + self.sources_idxs = xin[2] # network input batch_data = [ ( sources[i].to( devs[ cf.fields[i][1][3] ], non_blocking=True), @@ -641,10 +638,12 @@ def prepare_batch( self, xin) : self.targets.append( targets[ifield].to( devs[cf.fields[ifield][1][3]], non_blocking=True )) # idxs of masked tokens - tmi_out = [[] for _ in range(len(fields_tokens_masked_idx))] - for i,tmi in enumerate(fields_tokens_masked_idx) : - tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] - self.tokens_masked_idx = tmi_out + # tmi_out = [[] for _ in range(len(fields_tokens_masked_idx))] + # for i,tmi in enumerate(fields_tokens_masked_idx) : + # tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] + # self.tokens_masked_idx = tmi_out + self.tokens_masked_idx = [tmi.to(devs[cf.fields[i][1][3]], non_blocking=True) + for i,tmi in enumerate(fields_tokens_masked_idx)] # idxs of masked tokens per batch entry self.fields_tokens_masked_idx_list = fields_tokens_masked_idx_list @@ -674,51 +673,11 @@ def decoder_to_tail( self, idx_pred, pred) : # select "fixed" masked tokens for loss computation - # recover vertical level dimension - num_tokens = self.num_tokens[field_idx] - num_vlevels = len(self.cf.fields[field_idx][2]) # flatten token dimensions: remove space-time separation pred = torch.flatten( pred, 2, 3).to( dev) - # extract masked token level by level - pred_masked = [] - for lidx, level in enumerate(self.cf.fields[field_idx][2]) : - - # select masked tokens, flattened along batch dimension for easier indexing and processing - pred_l = torch.flatten( pred[:,lidx], 0, 1) - pred_masked_l = pred_l[ target_idx[lidx] ] - target_idx_l = target_idx[lidx] - - # add positional encoding of masked tokens - - # # TODO: do we need the positional encoding? 
- - # compute space time indices of all tokens - target_idxs_v = level * torch.ones( target_idx_l.shape[0], device=dev) - num_tokens_space = num_tokens[1] * num_tokens[2] - # remove offset introduced by linearization - target_idx_l = torch.remainder( target_idx_l, np.prod(num_tokens)) - target_idxs_t = (target_idx_l / num_tokens_space).int() - temp = torch.remainder( target_idx_l, num_tokens_space) - target_idxs_x = (temp / num_tokens[1]).int() - target_idxs_y = torch.remainder( temp, num_tokens[2]) - - # apply harmonic positional encoding - dim_embed = pred.shape[-1] - pe = torch.zeros( pred_masked_l.shape[0], dim_embed, device=dev) - xs = (2. * np.pi / dim_embed) * torch.arange( 0, dim_embed, 2, device=dev) - pe[:, 0::2] = 0.5 * torch.sin( torch.outer( 8 * target_idxs_x, xs) ) \ - + torch.sin( torch.outer( target_idxs_t, xs) ) - pe[:, 1::2] = 0.5 * torch.cos( torch.outer( 8 * target_idxs_y, xs) ) \ - + torch.cos( torch.outer( target_idxs_v, xs) ) - - # TODO: with or without final positional encoding? - # pred_masked.append( pred_masked_l + pe) - pred_masked.append( pred_masked_l) - - # flatten along level dimension, for loss evaluation we effectively have level, batch, ... - # as ordering of dimensions - pred_masked = torch.cat( pred_masked, 0) + pred_masked = torch.flatten( pred, 0, 2) + pred_masked = pred_masked[ target_idx ] return pred_masked @@ -739,6 +698,9 @@ def log_validate( self, epoch, bidx, log_sources, log_preds) : ################################################### def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : '''Logging for BERT_strategy=forecast.''' + + # TODO, TODO: use sources_idx + cf = self.cf detok = utils.detokenize diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 7e0c629..50de481 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -21,473 +21,198 @@ import code # code.interact(local=locals()) -from atmorep.datasets.dynamic_field_level import DynamicFieldLevel -from atmorep.datasets.static_field import StaticField +import zarr +import pandas as pd from atmorep.utils.utils import days_until_month_in_year from atmorep.utils.utils import days_in_month import atmorep.config.config as config +from atmorep.datasets.normalizer_global import NormalizerGlobal +from atmorep.datasets.normalizer_local import NormalizerLocal + + class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### - def __init__( self, file_path, years_data, fields, batch_size, - num_t_samples, num_patches_per_t, num_load, pre_batch, - rng_seed = None, file_shape = (-1, 721, 1440), - level_type = 'ml', time_sampling = 1, - smoothing = 0, file_format = 'grib', month = None, lat_sampling_weighted = True, - geo_range = [[-90.,90.], [0.,360.]], - fields_targets = [], pre_batch_targets = None - ) : + def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_samples_per_epoch, + rng_seed = None, time_sampling = 1, with_source_idxs = False, + fields_targets = None, pre_batch_targets = None ) : ''' Data set for single dynamic field at an arbitrary number of vertical levels + + nsize : neighborhood in (tsteps, deg_lat, deg_lon) ''' super( MultifieldDataSampler).__init__() self.fields = fields self.batch_size = batch_size + self.n_size = n_size + self.num_samples = num_samples_per_epoch + self.with_source_idxs = with_source_idxs self.pre_batch = pre_batch - self.years_data = years_data + # 
create (source) fields + # config.path_data + fname_source = '/p/scratch/atmo-rep/era5_res0025_1979.zarr' + fname_source = '/p/scratch/atmo-rep/era5_res0025_2021.zarr' + fname_source = '/p/scratch/atmo-rep/era5_res0025_2021_t5.zarr' + # fname_source = '/p/scratch/atmo-rep/era5_res0100_2021_t5.zarr' + self.ds = zarr.open( fname_source) + self.ds_global = self.ds.attrs['is_global'] + self.ds_len = self.ds['data'].shape[0] + + # sanity checking + # assert self.ds['data'].shape[0] == self.ds['time'].shape[0] + # assert self.ds_len >= num_samples_per_epoch + + self.lats = np.array( self.ds['lats']) + self.lons = np.array( self.ds['lons']) + + sh = self.ds['data'].shape + st = self.ds['time'].shape + print( f'self.ds[\'data\'] : {sh} :: {st}') + print( f'self.lats : {self.lats.shape}', flush=True) + print( f'self.lons : {self.lons.shape}', flush=True) + + self.fields_idxs = np.array( [self.ds.attrs['fields'].index( f[0]) for f in fields]) + self.levels_idxs = np.array( [self.ds.attrs['levels'].index( ll) for ll in levels]) + # self.fields_idxs = [0, 1, 2] + # self.levels_idxs = [0, 1] + self.levels = levels #[123, 137] # self.ds['levels'] + + # TODO + # # create (target) fields + # self.datasets_targets = self.create_loaders( fields_targets) + # self.fields_targets = fields_targets + # self.pre_batch_targets = pre_batch_targets + self.time_sampling = time_sampling - self.month = month - self.range_lat = 90. - np.array( geo_range[0]) - self.range_lon = np.array( geo_range[1]) - self.geo_range = geo_range - - # order North to South - self.range_lat = np.flip(self.range_lat) if self.range_lat[1] < self.range_lat[0] \ - else self.range_lat - - # prepare range_lat and range_lon for sampling - self.is_global = 0 == self.range_lat[0] and self.range_lon[0] == 0. \ - and 180. == self.range_lat[1] and 360. == self.range_lon[1] - - # TODO: this assumes file_shape is set correctly and not just per field and it defines a - # reference grid, likely has to be the coarsest - self.res = 360. 
/ file_shape[2] + self.range_lat = np.array( self.lats[ [0,-1] ]) + self.range_lon = np.array( self.lons[ [0,-1] ]) + + self.res = np.zeros( 2) + self.res[0] = (self.range_lat[1]-self.range_lat[0]) / (self.ds['data'].shape[-2]-1) + self.res[1] = (self.range_lon[1]-self.range_lon[0]) / (self.ds['data'].shape[-1]-1) - # avoid wrap around at poles - pole_offset = np.ceil(fields[0][3][1] * fields[0][4][1] / 2) * self.res - self.range_lat[0] = pole_offset if self.range_lat[0] < pole_offset else self.range_lat[0] - self.range_lat[1] =180.-pole_offset if 180.-self.range_lat[1] South ordering - # shrink so that cookie cutting based on sampling does not exceed domain if it is not global - if not self.is_global : - # TODO: check that field data is consistent and covers the same spatial domain - # TODO: code below assumes that fields[0] is global - # TODO: code below does not handle anisotropic grids - finfo = self.fields[0] - # ensure that delta is a multiple of the coarse grid resolution - ngrid1 = finfo[3][1] * finfo[4][1] - ngrid2 = finfo[3][2] * finfo[4][2] - delta1 = 0.5 * self.res * (ngrid1-1 if ngrid1 % 2==0 else ngrid1+1) - delta2 = 0.5 * self.res * (ngrid2-1 if ngrid2 % 2==0 else ngrid2+1) - self.range_lat += np.array([delta1, -delta1]) - self.range_lon += np.array([delta2, -delta2]) + # ensure neighborhood does not exceed domain (either at pole or for finite domains) + self.range_lat += np.array([n_size[1] / 2., -n_size[1] / 2.]) + # lon: no change for periodic case + if self.ds_global < 1.: + self.range_lon += np.array([n_size[2]/2., -n_size[2]/2.]) # ensure all data loaders use same rng_seed and hence generate consistent data if not rng_seed : rng_seed = np.random.randint( 0, 100000, 1)[0] self.rng = np.random.default_rng( rng_seed) - # create (source) fields - self.datasets = self.create_loaders( fields) - - # create (target) fields - self.datasets_targets = self.create_loaders( fields_targets) - - ################################################### - def create_loaders( self, fields ) : - - datasets = [] - for field_idx, field_info in enumerate(fields) : - - datasets.append( []) - - # extract field info - (vls, num_tokens, token_size) = field_info[2:5] - - if len(field_info) > 6 : - corr_type = field_info[6] - else: - corr_type = 'global' - - smoothing = self.smoothing - log_transform_data = False - if len(field_info) > 7 : - (data_type, file_shape, file_geo_range, file_format) = field_info[7][:4] - if len( field_info[7]) > 6 : - smoothing = field_info[7][6] - print( '{} : smoothing = {}'.format( field_info[0], smoothing) ) - if len( field_info[7]) > 7 : - log_transform_data = field_info[7][7] - print( '{} : log_transform_data = {}'.format( field_info[0], log_transform_data) ) - else : - data_type = 'era5' - file_format = self.file_format - file_shape = self.file_shape - file_geo_range = [[90.,-90.], [0.,360.]] - - # static fields - if 0 == field_info[1][0] : - datasets[-1].append( StaticField( self.file_path, field_info, self.batch_size, data_type, - file_shape, file_geo_range, - num_tokens, token_size, smoothing, file_format, corr_type) ) - - # dynamic fields - elif 1 == field_info[1][0] : - for vlevel in vls : - datasets[-1].append( DynamicFieldLevel( self.file_path, self.years_data, field_info, - self.batch_size, data_type, - file_shape, file_geo_range, - num_tokens, token_size, - self.level_type, vlevel, self.time_sampling, - smoothing, file_format, corr_type, - log_transform_data ) ) - - else : - assert False - - return datasets + # data normalizers + self.normalizers = [] + for 
_, field_info in enumerate(fields) : + self.normalizers.append( []) + corr_type = 'global' if len(field_info) <= 6 else field_info[6] + ner = NormalizerGlobal if corr_type == 'global' else NormalizerLocal + for vl in self.levels : + self.normalizers[-1] += [ ner( field_info, vl, + np.array(self.ds['data'].shape)[[0,-2,-1]]) ] + + # extract indices for selected years + self.times = pd.DatetimeIndex( self.ds['time']) + # idxs = np.zeros( self.ds['time'].shape[0], dtype=np.bool_) + # self.idxs_years = np.array( []) + # for year in years : + # idxs = np.where( (self.times >= f'{year}-1-1') & (self.times <= f'{year}-12-31'))[0] + # assert idxs.shape[0] > 0, f'Requested year is not in dataset {fname_source}. Aborting.' + # self.idxs_years = np.append( self.idxs_years, idxs[::self.time_sampling]) + # TODO, TODO, TODO: + self.idxs_years = np.arange( self.ds_len) ################################################### def shuffle( self) : - # ensure that different parallel loaders create independent random shuffles - delta = torch.randint( 0, 100000, (1,)).item() - self.rng.bit_generator.advance( delta) + rng = self.rng + self.idxs_perm_t = rng.permutation( self.idxs_years)[:(self.num_samples // self.batch_size)] - self.idxs_perm = np.zeros( (0, 4), dtype=np.int64) + lats = rng.random(self.num_samples) * (self.range_lat[1] - self.range_lat[0]) +self.range_lat[0] + lons = rng.random(self.num_samples) * (self.range_lon[1] - self.range_lon[0]) +self.range_lon[0] - # latitude, first map to mathematical lat coords in [0,180.], then to [0,pi] then - # to z-value in [-1,1] - if self.lat_sampling_weighted : - lat_r = np.cos( self.range_lat/180. * np.pi) - else : - lat_r = self.range_lat - - # 1.00001 is a fudge factor since np.round(*.5) leads to flooring instead of proper up-rounding + # align with grid res_inv = 1.0 / self.res * 1.00001 + lats = self.res[0] * np.round( lats * res_inv[0]) + lons = self.res[1] * np.round( lons * res_inv[1]) - # loop over individual data year-month items - for i_ym in range( len(self.years_months)) : - - ym = self.years_months[i_ym] - - # ensure a constant size of work load of data loader independent of the month length - # factor of 128 is a fudge parameter to ensure that mod-ing leads to sufficiently - # random wrap-around (with 1 instead of 128 there is clustering on the first days) - hours_in_day = int( 24 / self.time_sampling) - time_slices = 128 * 31 * hours_in_day - time_slices_i_ym = hours_in_day * days_in_month( ym[0], ym[1]) - idxs_perm_temp = np.mod(self.rng.permutation(time_slices), time_slices_i_ym) - # fixed number of time samples independent of length of month - idxs_perm_temp = idxs_perm_temp[:self.num_t_samples] - idxs_perm = np.zeros( (self.num_patches_per_t *idxs_perm_temp.shape[0],4) ) - - # split up into file index and local index - idx = 0 - for it in idxs_perm_temp : - - idx_patches = self.rng.random( (self.num_patches_per_t, 2) ) - # for jj in idx_patches : - for jj in idx_patches : - # area consistent sampling on the sphere (with less patches close to the pole) - # see https://graphics.stanford.edu/courses/cs448-97-fall/notes.html , Lecture 7 - # for area preserving sampling of the sphere - # py \in [0,180], px \in [0,360] (possibly with negative values for lon) - if self.lat_sampling_weighted : - py = ((np.arccos(lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) / np.pi) * 180.) 
- else : - py = (lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) - px = jj[1] * (self.range_lon[1] - self.range_lon[0]) + self.range_lon[0] - - # align with grid - py = self.res * np.round( py * res_inv) - px = self.res * np.round( px * res_inv) - - idxs_perm[idx] = np.array( [i_ym, it, py, px]) - idx = idx + 1 - - self.idxs_perm = np.concatenate( (self.idxs_perm, idxs_perm[:idx])) - - # shuffle again to avoid clustering of patches by loop over idx_patches above - self.idxs_perm = self.idxs_perm[self.rng.permutation(self.idxs_perm.shape[0])] - self.idxs_perm = self.idxs_perm[self.rng.permutation(self.idxs_perm.shape[0])] - # restrict to multiples of batch size - lenbatch = int(math.floor(self.idxs_perm.shape[0] / self.batch_size)) * self.batch_size - self.idxs_perm = self.idxs_perm[:lenbatch] - # # DEBUG - # print( 'self.idxs_perm.shape = {}'.format(self.idxs_perm.shape )) - # rank = torch.distributed.get_rank() - # fname = 'idxs_perm_rank{}_{}.dat'.format( rank, shape_to_str( self.idxs_perm.shape)) - # self.idxs_perm.tofile( fname) + self.idxs_perm = np.stack( [lats, lons], axis=1) ################################################### - def set_full_time_range( self) : - - self.idxs_perm = np.zeros( (0, 4), dtype=np.int64) - - # latitude, first map to mathematical lat coords in [0,180.], then to [0,pi] then - # to z-value in [-1,1] - if self.lat_sampling_weighted : - lat_r = np.cos( self.range_lat/180. * np.pi) - else : - lat_r = self.range_lat - - # 1.00001 is a fudge factor since np.round(*.5) leads to flooring instead of proper up-rounding - res_inv = 1.0 / self.res * 1.00001 - - # loop over individual data year-month items - for i_ym in range( len(self.years_months)) : - - ym = self.years_months[i_ym] - - hours_in_day = int( 24 / self.time_sampling) - idxs_perm_temp = np.arange( hours_in_day * days_in_month( ym[0], ym[1])) - idxs_perm = np.zeros( (self.num_patches_per_t *idxs_perm_temp.shape[0],4) ) - - # split up into file index and local index - idx = 0 - for it in idxs_perm_temp : - - idx_patches = self.rng.random( (self.num_patches_per_t, 2) ) - for jj in idx_patches : - # area consistent sampling on the sphere (with less patches close to the pole) - # see https://graphics.stanford.edu/courses/cs448-97-fall/notes.html , Lecture 7 - # for area preserving sampling of the sphere - # py \in [0,180], px \in [0,360] (possibly with negative values for lon) - if self.lat_sampling_weighted : - py = ((np.arccos(lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) / np.pi) * 180.) 
- else : - py = (lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) - px = jj[1] * (self.range_lon[1] - self.range_lon[0]) + self.range_lon[0] - - # align with grid - py = self.res * np.round( py * res_inv) - px = self.res * np.round( px * res_inv) - - idxs_perm[idx] = np.array( [i_ym, it, py, px]) - idx = idx + 1 - - self.idxs_perm = np.concatenate( (self.idxs_perm, idxs_perm[:idx])) - - # shuffle again to avoid clustering of patches by loop over idx_patches above - self.idxs_perm = self.idxs_perm[self.rng.permutation(self.idxs_perm.shape[0])] - # restrict to multiples of batch size - lenbatch = int(math.floor(self.idxs_perm.shape[0] / self.batch_size)) * self.batch_size - self.idxs_perm = self.idxs_perm[:lenbatch] - - # # DEBUG - # print( 'self.idxs_perm.shape = {}'.format(self.idxs_perm.shape )) - # fname = 'idxs_perm_{}_{}.dat'.format( self.epoch_counter, shape_to_str( self.idxs_perm.shape)) - # self.idxs_perm.tofile( fname) - - ################################################### - def load_data( self, batch_size = None) : + def __iter__(self): - years_data = self.years_data - - # ensure proper separation of different random samplers - delta = torch.randint( 0, 1000, (1,)).item() - self.rng.bit_generator.advance( delta) - - # select num_load random months and years - perms = np.concatenate( [self.rng.permutation( np.arange(len(years_data))) for i in range(64)]) - perms = perms[:self.num_load] - if self.month : - self.years_months = [ (years_data[iyear], self.month) for iyear in perms] - else : - # stratified sampling of month to ensure proper distribution, needs to be adapted for - # number of parallel workers not being divisible by 4 - # rank, ms = torch.distributed.get_rank() % 4, 3 - # perms_m = np.concatenate( [self.rng.permutation( np.arange( rank*ms+1, (rank+1)*ms+1)) - # for i in range(16)]) - perms_m = np.concatenate( [self.rng.permutation( np.arange( 1, 12+1)) for i in range(16)]) - self.years_months = [ ( years_data[iyear], perms_m[i]) for i,iyear in enumerate(perms)] - - # generate random permutations passed to the loaders for individual files - # to ensure consistent processing + # TODO: if we keep this then we should remove the rng_seed argument for the constuctor + self.rng = np.random.default_rng() self.shuffle() - # perform actual loading of data - - for ds_field in self.datasets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) - - for ds_field in self.datasets_targets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) - - ################################################### - def set_data( self, times_pos, batch_size = None) : - ''' - times_pos = np.array( [ [year, month, day, hour, lat, lon], ...] ) - - lat \in [90,-90] = [90N, 90S] - - lon \in [0,360] - - (year,month) pairs should be a limited number since all data for these is loaded - ''' - - # extract required years and months - years_months_all = np.array( [ [it[0], it[1]] for it in times_pos ], dtype=np.int64) - self.years_months = list( zip( np.unique(years_months_all[:,0]), - np.unique( years_months_all[:,1] ))) - - # generate all the data - self.idxs_perm = np.zeros( (len(times_pos), 4)) - for idx, item in enumerate( times_pos) : - - assert item[2] >= 1 and item[2] <= 31 - assert item[3] >= 0 and item[3] < int(24 / self.time_sampling) - assert item[4] >= -90. and item[4] <= 90. 
- - # find year - for i_ym, ym in enumerate( self.years_months) : - if ym[0] == item[0] and ym[1] == item[1] : - break - - # last term: correct for window from last file that is loaded - it = (item[2] - 1) * (24./self.time_sampling) + item[3] - # it = item[2] * (24./self.time_sampling) + item[3] - idx_lat = item[4] - idx_lon = item[5] - - # work with mathematical lat coordinates from here on - self.idxs_perm[idx] = np.array( [i_ym, it, 90. - idx_lat, idx_lon]) - - for ds_field in self.datasets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) - - for ds_field in self.datasets_targets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) - - ################################################### - def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : - ''' generate patch/token positions for global grid ''' - - token_overlap = torch.tensor( token_overlap).to(torch.int64) - - # assumed that sanity checking that field data is consistent has been done - ifield = 0 - field = self.fields[ifield] - + lats, lons = self.lats, self.lons + fields_idxs, levels_idxs = self.fields_idxs, self.levels_idxs + ts, n_size = self.time_sampling, self.n_size + ns_2 = np.array(self.n_size) / 2. res = self.res - side_len = torch.tensor( [field[3][1] * field[4][1] * res, field[3][2] * field[4][2] * res] ) - overlap = torch.tensor( [token_overlap[0]*field[4][1]*res, token_overlap[1]*field[4][2]*res] ) - side_len_2 = side_len / 2. - assert all( overlap <= side_len_2), 'token_overlap too large for #tokens, reduce if possible' - - # generate tiles - times_pos = [] - for ctime in times : - - lat = side_len_2[0].item() - num_tiles_lat = 0 - while (lat + side_len_2[0].item()) < 180. : - num_tiles_lat += 1 - lon = side_len_2[1].item() - overlap[1].item()/2. - num_tiles_lon = 0 - while (lon - side_len_2[1]) < 360. : - times_pos += [[*ctime, -lat + 90., np.mod(lon,360.) ]] - lon += side_len[1].item() - overlap[1].item() - num_tiles_lon += 1 - lat += side_len[0].item() - overlap[0].item() - - # add one additional row if no perfect tiling (sphere is toric in longitude so no special - # handling necessary but not in latitude) - # the added row is such that it goes exaclty down to the South pole and the offset North-wards - # is computed based on this - lat -= side_len[0] - overlap[0] - if lat - side_len_2[0] < 180. : - num_tiles_lat += 1 - lat = 180. - side_len_2[0].item() + res - lon = side_len_2[1].item() - overlap[1].item()/2. - while (lon - side_len_2[1]) < 360. : - times_pos += [[*ctime, -lat + 90., np.mod(lon,360.) 
]] - lon += side_len[1].item() - overlap[1].item() - - # adjust batch size if necessary so that the evaluations split up across batches of equal size - batch_size = num_tiles_lon - - print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) - - self.set_data( times_pos, batch_size) - - ################################################### - def set_location( self, pos, years, months, num_t_samples_per_month, batch_size = None) : - ''' random time sampling for fixed location ''' - - times_pos = [] - for i_ym, ym in enumerate(itertools.product( years, months )) : - - # ensure a constant size of work load of data loader independent of the month length - # factor of 128 is a fudge parameter to ensure that mod-ing leads to sufficiently - # random wrap-around (with 1 instead of 128 there is clustering on the first days) - hours_in_day = int( 24 / self.time_sampling) - d_i_m = days_in_month( ym[0], ym[1]) - perms = self.rng.permutation( num_t_samples_per_month * d_i_m) - # ensure that days start at 1 - perms = np.mod( perms[ : num_t_samples_per_month], (d_i_m-1) ) + 1 - rhs = self.rng.integers(low=0, high=hours_in_day, size=num_t_samples_per_month ) - - for rh, perm in zip( rhs, perms) : - times_pos += [[ ym[0], ym[1], perm, rh, pos[0], pos[1]] ] - - # adjust batch size if necessary so that the evaluations split up across batches of equal size - while 0 != (len(times_pos) % batch_size) : - batch_size -= 1 - assert batch_size >= 1 - - self.set_data( times_pos, batch_size) - - ################################################### - def __iter__(self): iter_start, iter_end = self.worker_workset() for bidx in range( iter_start, iter_end) : + + idx = self.idxs_perm_t[bidx] + idxs_t = list(np.arange( idx-n_size[0]*ts, idx, ts, dtype=np.int64)) + data_t = self.ds['data'].oindex[ idxs_t, fields_idxs , levels_idxs] + + sources, sources_infos, source_idxs = [], [], [] + for sidx in range(self.batch_size) : + + idx = self.idxs_perm[bidx*self.batch_size+sidx] + + # slight assymetry with offset by res/2 is required to match desired token count + lat_ran = np.where(np.logical_and(lats > idx[0]-ns_2[1]-res[0]/2.,lats < idx[0]+ns_2[1]))[0] + # handle periodicity of lon + assert not ((idx[1]-ns_2[2]) < 0. and (idx[1]+ns_2[2]) > 360.) + il, ir = (idx[1]-ns_2[2]-res[1]/2., idx[1]+ns_2[2]) + if il < 0. : + lon_ran = np.concatenate( [np.where( lons > il+360)[0], np.where(lons < ir)[0]], 0) + elif ir > 360. 
: + lon_ran = np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0) + else : + lon_ran = np.where(np.logical_and( lons > il, lons < ir))[0] + + # extract data + source = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) + sources += [ np.expand_dims(source, 0) ] + if self.with_source_idxs : + source_idxs += [ (idxs_t, lat_ran, lon_ran) ] + + # normalize data + # TODO: temporal window can span multiple months + year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month + for ifield, _ in enumerate(fields_idxs) : + for ilevel, _ in enumerate(levels_idxs) : + nf = self.normalizers[ifield][ilevel].normalize + source[:,ifield,ilevel] = nf( year, month, source[:,ifield,ilevel], (lat_ran, lon_ran)) + + # extract batch info + sources_infos += [ [ self.ds['time'][ idxs_t ], self.levels, + self.lats[lat_ran], self.lons[lon_ran], self.res ] ] + + # swap + sources = self.pre_batch( torch.from_numpy( np.concatenate( sources, 0)), + sources_infos ) + + # TODO: implement targets + target, target_info = None, None - sources = [] - for ds_field in self.datasets : - sources.append( [ds_level[bidx] for ds_level in ds_field]) - # perform batch pre-processing, e.g. BERT-type masking - if self.pre_batch : - sources = self.pre_batch( sources) - - targets = [] - for ds_field in self.datasets_targets : - targets.append( [ds_level[bidx] for ds_level in ds_field]) - # perform batch pre-processing, e.g. BERT-type masking - if self.pre_batch_targets : - targets = self.pre_batch_targets( targets) - - yield (sources,targets) + yield ( sources, (target, target_info), source_idxs ) ################################################### def __len__(self): - return len(self.datasets[0][0]) + return self.num_samples // self.batch_size ################################################### def worker_workset( self) : @@ -496,17 +221,16 @@ def worker_workset( self) : if worker_info is None: iter_start = 0 - iter_end = len(self.datasets[0][0]) + iter_end = self.num_samples else: # split workload - temp = len(self.datasets[0][0]) - per_worker = int( np.floor( temp / float(worker_info.num_workers) ) ) + per_worker = len(self) // worker_info.num_workers worker_id = worker_info.id iter_start = int(worker_id * per_worker) iter_end = int(iter_start + per_worker) if worker_info.id+1 == worker_info.num_workers : - iter_end = int(temp) + iter_end = len(self) return iter_start, iter_end - + diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 51ccdd2..3b2f4a8 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -22,7 +22,7 @@ from atmorep.utils.utils import tokenize #################################################################################################### -def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data) : +def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, fields_infos) : fields_tokens_masked_idx = [[] for _ in fields_data] fields_tokens_masked_idx_list = [[] for _ in fields_data] @@ -59,61 +59,65 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data) size_lat = int(rngs[0].integers( 2, fields[0][3][1]+1, 1)[0] / 2.) * 2 size_lon = int(rngs[0].integers( 2, fields[0][3][2]+1, 1)[0] / 2.) 
* 2 - rng_idx = 1 - for ifield, data_field in enumerate(fields_data) : - for ilevel, (field_data, token_info) in enumerate(data_field) : - - tok_size = fields[ifield][4] - field_data = tokenize( field_data, tok_size ) - field_data_shape = field_data.shape - - # cut neighborhood for current batch - if cf.BERT_window : - # adjust size based on token size so that one has a fixed size window in physical space - cur_size_t = int(size_t * fields[ifield][3][0] / fields[0][3][0]) - cur_size_lat = int(size_lat * fields[ifield][3][1] / fields[0][3][1]) - cur_size_lon = int(size_lon * fields[ifield][3][2] / fields[0][3][2]) - # define indices - idx_t_s = field_data.shape[1] - cur_size_t - idx_lat_s = field_data.shape[2] - cur_size_lat - idx_lon_s = field_data.shape[3] - cur_size_lon - # cut - field_data = field_data[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] - field_data = field_data.contiguous() - # for token info first recover space-time shape - token_info = token_info.reshape( list(field_data_shape[0:4]) + [token_info.shape[-1]]) - token_info = token_info[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] - token_info = torch.flatten( token_info, 1, -2) - token_info = token_info.contiguous() + # swap fields (idx=2) in first position for iteration and time (idx=1) before spatial coordinates + # fields_data = fields_data.permute( [3, 0, 2, 1, 4, 5]) + fields_data = fields_data.permute( [2, 0, 3, 1, 4, 5]) + fields_data_out, fields_infos_out, fields_targets = [], [], [] + fields_tokens_masked_idx, fields_tokens_masked_idx_list = [], [] + + rng_idx = 0 + for ifield, field_data in enumerate(fields_data) : + + tok_size = fields[ifield][4] + field_data = tokenize( field_data, tok_size ) + + # # cut neighborhood for current batch + # if cf.BERT_window : + # # adjust size based on token size so that one has a fixed size window in physical space + # cur_size_t = int(size_t * fields[ifield][3][0] / fields[0][3][0]) + # cur_size_lat = int(size_lat * fields[ifield][3][1] / fields[0][3][1]) + # cur_size_lon = int(size_lon * fields[ifield][3][2] / fields[0][3][2]) + # # define indices + # idx_t_s = field_data.shape[1] - cur_size_t + # idx_lat_s = field_data.shape[2] - cur_size_lat + # idx_lon_s = field_data.shape[3] - cur_size_lon + # # cut + # field_data = field_data[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] + # field_data = field_data.contiguous() + # # for token info first recover space-time shape + # token_info = token_info.reshape( list(field_data_shape[0:4]) + [token_info.shape[-1]]) + # token_info = token_info[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] + # token_info = torch.flatten( token_info, 1, -2) + # token_info = token_info.contiguous() + + # no masking for static fields or if masking rate = 0 + if fields[ifield][1][0] > 0 and fields[ifield][5][0] > 0. : + + ret = bert_f( cf, ifield, field_data, rngs[rng_idx]) + (field_data, target, tokens_masked_idx, tokens_masked_idx_list) = ret - # no masking for static fields or if masking rate = 0 - if fields[ifield][1][0] > 0 and fields[ifield][5][0] > 0. 
: - - ret = bert_f( cf, ifield, field_data, token_info, rngs[rng_idx]) - (field_data, token_info, target, tokens_masked_idx, tokens_masked_idx_list) = ret - - if target != None : - fields_targets[ifield].append( target) - fields_tokens_masked_idx[ifield].append( tokens_masked_idx) - fields_tokens_masked_idx_list[ifield].append( tokens_masked_idx_list) - - rng_idx += 1 - - sources[ifield].append( field_data.unsqueeze(1) ) - token_infos[ifield].append( token_info ) - - # merge along vertical level - sources[ifield] = torch.cat( sources[ifield], 1) - token_infos[ifield] = torch.cat( token_infos[ifield], 1) - # merge along vertical level, for target we have level, batch, ... ordering - fields_targets[ifield] = torch.cat( fields_targets[ifield],0) \ - if len(fields_targets[ifield]) > 0 else fields_targets[ifield] - - return (sources, token_infos, fields_targets, fields_tokens_masked_idx, + if target is not None : + fields_targets.append( target) + fields_tokens_masked_idx.append( tokens_masked_idx) + fields_tokens_masked_idx_list.append( tokens_masked_idx_list) + + rng_idx += 1 + + fields_data_out.append( field_data ) + fields_infos_out.append( torch.zeros( (*field_data.shape[:5], 8)) ) + + # merge + # sources[ifield] = torch.cat( sources[ifield], 1) + # token_infos[ifield] = torch.cat( token_infos[ifield], 1) + # # merge along vertical level, for target we have level, batch, ... ordering + # fields_targets[ifield] = torch.cat( fields_targets[ifield],0) \ + # if len(fields_targets[ifield]) > 0 else fields_targets[ifield] + + return (fields_data_out, fields_infos_out, fields_targets, fields_tokens_masked_idx, fields_tokens_masked_idx_list) #################################################################################################### -def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : +def prepare_batch_BERT_field( cf, ifield, source, rng) : # shortcuts mr = partial( torch.nn.functional.interpolate, mode='trilinear') @@ -125,7 +129,7 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # collapse token dimensions source_shape0 = source.shape - source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) + source = torch.flatten( torch.flatten( source, 1, 4), 2, 4) # select random token in the selected space-time cube to be masked/deleted BERT_frac = cf.fields[ifield][5][0] @@ -192,16 +196,11 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # unsqueeze(usq()) is required since channel dimension is expected temp = mr( mr( usq( source[ idx[idx_mr_cond] ].reshape( (-1,ts[0],ts[1],ts[2])), 1), mrs), ts) source[ idx[idx_mr_cond] ] = sq( fl( temp, -3, -1)) - # adjust resolution parameter in token_info - token_info_shape = token_info.shape - token_info = token_info.flatten( 0, 1) - token_info[ idx[idx_mr_cond] ][-1] *= (mrs[1] + mrs[2]) / 2. 
#TODO: anisotropic resolution - token_info = token_info.reshape( token_info_shape) # recover batch dimension which was flattend for easier indexing and also token dimensions source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx, tokens_masked_idx_list) + return (source, target, tokens_masked_idx, tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index 9fe5ae6..f600eaf 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -293,9 +293,14 @@ def tokenize( data, token_size = [-1,-1,-1]) : tok_tot_x = int( data_shape[-2] / token_size[1]) tok_tot_y = int( data_shape[-1] / token_size[2]) - if 4 == len(data_shape) : - t2 = torch.reshape( data, (-1, tok_tot_t, token_size[0], tok_tot_x, token_size[1], tok_tot_y, token_size[2])) - data_tokenized = torch.transpose(torch.transpose( torch.transpose( t2, 5, 4), 3, 2), 4, 3) + if 5 == len(data_shape) : + t2 = torch.reshape( data, (data.shape[0], data.shape[1], tok_tot_t, token_size[0], + tok_tot_x, token_size[1], tok_tot_y, token_size[2])) + data_tokenized = t2.permute( [0, 1, 2, 4, 6, 3, 5, 7]) + elif 4 == len(data_shape) : + t2 = torch.reshape( data, (-1, tok_tot_t, token_size[0], + tok_tot_x, token_size[1], tok_tot_y, token_size[2])) + data_tokenized = t2.permute( [0, 1, 3, 5, 4, 3, 6]) elif 3 == len(data_shape) : t2 = torch.reshape( data, (tok_tot_t, token_size[0], tok_tot_x, token_size[1], tok_tot_y, token_size[2])) data_tokenized = torch.transpose(torch.transpose( torch.transpose( t2, 4, 3), 2, 1), 3, 2) From 29dd3e6438a50be14fe8c0a21d45b301b904e382 Mon Sep 17 00:00:00 2001 From: iluise Date: Tue, 19 Mar 2024 17:50:42 +0100 Subject: [PATCH 03/66] normalization and tokenization --- atmorep/core/atmorep_model.py | 4 +- atmorep/core/train.py | 10 +- atmorep/core/train_multi.py | 1 - atmorep/datasets/multifield_data_sampler.py | 113 ++++++++++++++------ atmorep/datasets/normalizer_global.py | 2 + atmorep/datasets/normalizer_local.py | 2 +- atmorep/training/bert.py | 113 +++++++------------- atmorep/utils/utils.py | 1 + 8 files changed, 131 insertions(+), 115 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 010919d..05d927a 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -163,10 +163,10 @@ def normalizer( self, field, vl_idx) : def mode( self, mode : NetMode) : if mode == NetMode.train : - self.data_loader_iter = iter(self.data_loader_train) + self.data_loader_iter = iter(self.dataset_train) #iter(self.data_loader_train) self.net.train() elif mode == NetMode.test : - self.data_loader_iter = iter(self.data_loader_test) + self.data_loader_iter = iter(self.dataset_test) #iter(self.data_loader_test) self.net.eval() else : assert False diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 4ac97d7..f9bdf2f 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -143,7 +143,10 @@ def train() : [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], [ 'velocity_v', [ 1, 1024, [ ], 0 ], [ 114, 123, 137 ], - [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ], #] + [ 'total_precip', [ 1, 1536, [ ], 3 ], + [ 0 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] cf.fields_prediction = [ [cf.fields[0][0], 0.5], 
[cf.fields[1][0], 0.5] ] @@ -169,7 +172,7 @@ def train() : cf.batch_size_max = 32 cf.batch_size_delta = 8 cf.num_epochs = 128 - cf.num_loader_workers = 8 + # cf.num_loader_workers = 1#8 # additional infos cf.size_token_info = 8 cf.size_token_info_net = 16 @@ -216,7 +219,6 @@ def train() : # BERT # strategies: 'BERT', 'forecast', 'temporal_interpolation', 'identity' cf.BERT_strategy = 'BERT' - cf.BERT_window = False # sample sub-region cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution @@ -245,7 +247,7 @@ def train() : cf.n_size = [36, 0.25*9*6, 0.25*9*12] cf.num_samples_per_epoch = 1024 cf.num_samples_validate = 128 - cf.num_loader_workers = 8 + cf.num_loader_workers = 1 #8 cf.years_train = [2021] # list( range( 1980, 2018)) cf.years_test = [2021] #[2018] diff --git a/atmorep/core/train_multi.py b/atmorep/core/train_multi.py index 3f057cd..fc855ac 100644 --- a/atmorep/core/train_multi.py +++ b/atmorep/core/train_multi.py @@ -194,7 +194,6 @@ def train_multi() : cf.lat_sampling_weighted = False # BERT cf.BERT_strategy = 'BERT' # 'BERT', 'forecast', 'identity', 'totalmask' - cf.BERT_window = False # sample sub-region cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 50de481..173c213 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -20,10 +20,10 @@ import itertools import code # code.interact(local=locals()) - import zarr import pandas as pd - +import pdb +import code from atmorep.utils.utils import days_until_month_in_year from atmorep.utils.utils import days_in_month @@ -31,6 +31,7 @@ from atmorep.datasets.normalizer_global import NormalizerGlobal from atmorep.datasets.normalizer_local import NormalizerLocal +from atmorep.utils.utils import tokenize class MultifieldDataSampler( torch.utils.data.IterableDataset): @@ -53,12 +54,12 @@ def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_sa self.with_source_idxs = with_source_idxs self.pre_batch = pre_batch - + # create (source) fields # config.path_data fname_source = '/p/scratch/atmo-rep/era5_res0025_1979.zarr' fname_source = '/p/scratch/atmo-rep/era5_res0025_2021.zarr' - fname_source = '/p/scratch/atmo-rep/era5_res0025_2021_t5.zarr' + fname_source = '/p/scratch/atmo-rep/data/era5_1deg/era5_res0025_2021_final.zarr' # fname_source = '/p/scratch/atmo-rep/era5_res0100_2021_t5.zarr' self.ds = zarr.open( fname_source) self.ds_global = self.ds.attrs['is_global'] @@ -76,9 +77,12 @@ def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_sa print( f'self.ds[\'data\'] : {sh} :: {st}') print( f'self.lats : {self.lats.shape}', flush=True) print( f'self.lons : {self.lons.shape}', flush=True) + self.fields_idxs = [] - self.fields_idxs = np.array( [self.ds.attrs['fields'].index( f[0]) for f in fields]) - self.levels_idxs = np.array( [self.ds.attrs['levels'].index( ll) for ll in levels]) + # for f in fields: + # self.fields_idxs = [self.ds.attrs['fields'].index( f[0]) if f[0] in self.ds.attrs['fields'] + # self.fields_idxs = np.array( [self.ds.attrs['fields'].index( f[0]) for f in fields]) + # self.levels_idxs = np.array( 
[self.ds.attrs['levels'].index( ll) for ll in levels]) # self.fields_idxs = [0, 1, 2] # self.levels_idxs = [0, 1] self.levels = levels #[123, 137] # self.ds['levels'] @@ -94,8 +98,8 @@ def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_sa self.range_lon = np.array( self.lons[ [0,-1] ]) self.res = np.zeros( 2) - self.res[0] = (self.range_lat[1]-self.range_lat[0]) / (self.ds['data'].shape[-2]-1) - self.res[1] = (self.range_lon[1]-self.range_lon[0]) / (self.ds['data'].shape[-1]-1) + self.res[0] = self.ds.attrs['resol'][0] #(self.range_lat[1]-self.range_lat[0]) / (self.ds['data'].shape[-2]-1) + self.res[1] = self.ds.attrs['resol'][1] #(self.range_lon[1]-self.range_lon[0]) / (self.ds['data'].shape[-1]-1) # ensure neighborhood does not exceed domain (either at pole or for finite domains) self.range_lat += np.array([n_size[1] / 2., -n_size[1] / 2.]) @@ -114,10 +118,10 @@ def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_sa self.normalizers.append( []) corr_type = 'global' if len(field_info) <= 6 else field_info[6] ner = NormalizerGlobal if corr_type == 'global' else NormalizerLocal - for vl in self.levels : + for vl in field_info[2]: #self.levels : + data_type = 'data_sfc' if vl == 0 else 'data' #surface field self.normalizers[-1] += [ ner( field_info, vl, - np.array(self.ds['data'].shape)[[0,-2,-1]]) ] - + np.array(self.ds[data_type].shape)[[0,-2,-1]]) ] # extract indices for selected years self.times = pd.DatetimeIndex( self.ds['time']) # idxs = np.zeros( self.ds['time'].shape[0], dtype=np.bool_) @@ -153,7 +157,7 @@ def __iter__(self): self.shuffle() lats, lons = self.lats, self.lons - fields_idxs, levels_idxs = self.fields_idxs, self.levels_idxs + #fields_idxs, levels_idxs = self.fields_idxs, self.levels_idxs ts, n_size = self.time_sampling, self.n_size ns_2 = np.array(self.n_size) / 2. res = self.res @@ -164,46 +168,87 @@ def __iter__(self): idx = self.idxs_perm_t[bidx] idxs_t = list(np.arange( idx-n_size[0]*ts, idx, ts, dtype=np.int64)) - data_t = self.ds['data'].oindex[ idxs_t, fields_idxs , levels_idxs] - + data_t = [] + + for _, field_info in enumerate(self.fields) : + data_lvl = [] + for vl in field_info[2]: + if vl == 0: #surface level + field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) + data_lvl += [self.ds['data_sfc'].oindex[ idxs_t, field_idx]] + else: + field_idx = self.ds.attrs['fields'].index( field_info[0]) + vl_idx = self.ds.attrs['levels'].index(vl) + data_lvl += [self.ds['data'].oindex[ idxs_t, field_idx, vl_idx]] + data_t += [data_lvl] + sources, sources_infos, source_idxs = [], [], [] + lat_ran = [] + lon_ran = [] for sidx in range(self.batch_size) : idx = self.idxs_perm[bidx*self.batch_size+sidx] # slight assymetry with offset by res/2 is required to match desired token count - lat_ran = np.where(np.logical_and(lats > idx[0]-ns_2[1]-res[0]/2.,lats < idx[0]+ns_2[1]))[0] + lat_ran += [np.where(np.logical_and(lats > idx[0]-ns_2[1]-res[0]/2.,lats < idx[0]+ns_2[1]))[0]] # handle periodicity of lon assert not ((idx[1]-ns_2[2]) < 0. and (idx[1]+ns_2[2]) > 360.) il, ir = (idx[1]-ns_2[2]-res[1]/2., idx[1]+ns_2[2]) if il < 0. : - lon_ran = np.concatenate( [np.where( lons > il+360)[0], np.where(lons < ir)[0]], 0) + lon_ran += [np.concatenate( [np.where( lons > il+360)[0], np.where(lons < ir)[0]], 0)] elif ir > 360. 
: - lon_ran = np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0) + lon_ran += [np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0)] else : - lon_ran = np.where(np.logical_and( lons > il, lons < ir))[0] - - # extract data - source = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) - sources += [ np.expand_dims(source, 0) ] + lon_ran += [np.where(np.logical_and( lons > il, lons < ir))[0]] + if self.with_source_idxs : source_idxs += [ (idxs_t, lat_ran, lon_ran) ] + + source = [] + source_infos = [] + # extract data + # TODO: temporal window can span multiple months + year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month + for ifield, field_info in enumerate(self.fields): + source_lvl = [] + source_info_lvl = [] + tok_size = field_info[4] + for ilevel, vl in enumerate(field_info[2]): #self.levels : + nf = self.normalizers[ifield][ilevel].normalize + source_data, info_data = [], [] + + for sidx in range(self.batch_size) : + #normalize and tokenize + source_data += [ tokenize( torch.from_numpy(nf( year, month, np.take( np.take( data_t[ifield][ilevel], + lat_ran[sidx], -2), lon_ran[sidx], -1), (lat_ran[sidx], lon_ran[sidx]))), tok_size ) ] + # info_data += [ torch.Tensor(self.ds['time'][ idxs_t ], vl, + # self.lats[lat_ran[sidx]], self.lons[lon_ran[sidx]], self.res) ] + + info_data += [ [self.ds['time'][ idxs_t ], vl, + self.lats[lat_ran[sidx]], self.lons[lon_ran[sidx]], self.res] ] + source_lvl += [torch.stack(source_data, dim = 0)] + source_info_lvl += [info_data] + source += [source_lvl] + source_infos += [source_info_lvl] + + sources = self.pre_batch(source, #torch.from_numpy( np.concatenate( sources, 0)), + source_infos ) # [ source ] # [ np.expand_dims(source, 0) ] + sources_infos = [sources_infos] - # normalize data - # TODO: temporal window can span multiple months - year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month - for ifield, _ in enumerate(fields_idxs) : - for ilevel, _ in enumerate(levels_idxs) : - nf = self.normalizers[ifield][ilevel].normalize - source[:,ifield,ilevel] = nf( year, month, source[:,ifield,ilevel], (lat_ran, lon_ran)) + #sources += [ self.pre_batch( source, #torch.from_numpy( np.concatenate( sources, 0)), + # source_infos ) ] # [ source ] # [ np.expand_dims(source, 0) ] + #sources_infos += [source_infos] + # breakpoint() + # if self.with_source_idxs : + # source_idxs += [ (idxs_t, lat_ran, lon_ran) ] - # extract batch info - sources_infos += [ [ self.ds['time'][ idxs_t ], self.levels, - self.lats[lat_ran], self.lons[lon_ran], self.res ] ] + # # extract batch info + # sources_infos += [ [ self.ds['time'][ idxs_t ], self.levels, + # self.lats[lat_ran], self.lons[lon_ran], self.res ] ] # swap - sources = self.pre_batch( torch.from_numpy( np.concatenate( sources, 0)), - sources_infos ) + # sources = self.pre_batch( sources, #torch.from_numpy( np.concatenate( sources, 0)), + # sources_infos ) # TODO: implement targets target, target_info = None, None diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py index 8bf997a..c5558b9 100644 --- a/atmorep/datasets/normalizer_global.py +++ b/atmorep/datasets/normalizer_global.py @@ -32,6 +32,8 @@ def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_typ def normalize( self, year, month, data, coords = None) : corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), self.corr_data[:,1] == float(month))) , 2:].flatten() + data_temp = (data - 
corr_data_ym[0]) / corr_data_ym[1] + print(data_temp.mean(), data_temp.std()) return (data - corr_data_ym[0]) / corr_data_ym[1] def denormalize( self, year, month, data, coords = None) : diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py index 7cac3ed..32f1763 100644 --- a/atmorep/datasets/normalizer_local.py +++ b/atmorep/datasets/normalizer_local.py @@ -66,7 +66,7 @@ def normalize( self, year, month, data, coords) : data[i] = (data[i] - mean) / var else : data = (data - mean) / var - + print(data.mean(), data.std()) return data def denormalize( self, year, month, data, coords) : diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 3b2f4a8..d6de908 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -18,8 +18,8 @@ import numpy as np from functools import partial import code - -from atmorep.utils.utils import tokenize +import pdb +# from atmorep.utils.utils import tokenize #################################################################################################### def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, fields_infos) : @@ -48,76 +48,38 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, else : assert False - # # advance randomly to avoid issues with parallel data loaders that naively duplicate rngs - # delta = torch.randint( 0, 1000, (1,)).item() - # [rng.bit_generator.advance( delta) for rng in rngs] - - if cf.BERT_window : - # window size has to be multiple of two due to the variable token sizes (the size is - # however currently restricted to differ by exactly a factor of two only) - size_t = int(rngs[0].integers( 2, fields[0][3][0]+1, 1)[0] / 2.) * 2 - size_lat = int(rngs[0].integers( 2, fields[0][3][1]+1, 1)[0] / 2.) * 2 - size_lon = int(rngs[0].integers( 2, fields[0][3][2]+1, 1)[0] / 2.) * 2 - - # swap fields (idx=2) in first position for iteration and time (idx=1) before spatial coordinates - # fields_data = fields_data.permute( [3, 0, 2, 1, 4, 5]) - fields_data = fields_data.permute( [2, 0, 3, 1, 4, 5]) - fields_data_out, fields_infos_out, fields_targets = [], [], [] - fields_tokens_masked_idx, fields_tokens_masked_idx_list = [], [] - - rng_idx = 0 - for ifield, field_data in enumerate(fields_data) : - - tok_size = fields[ifield][4] - field_data = tokenize( field_data, tok_size ) - - # # cut neighborhood for current batch - # if cf.BERT_window : - # # adjust size based on token size so that one has a fixed size window in physical space - # cur_size_t = int(size_t * fields[ifield][3][0] / fields[0][3][0]) - # cur_size_lat = int(size_lat * fields[ifield][3][1] / fields[0][3][1]) - # cur_size_lon = int(size_lon * fields[ifield][3][2] / fields[0][3][2]) - # # define indices - # idx_t_s = field_data.shape[1] - cur_size_t - # idx_lat_s = field_data.shape[2] - cur_size_lat - # idx_lon_s = field_data.shape[3] - cur_size_lon - # # cut - # field_data = field_data[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] - # field_data = field_data.contiguous() - # # for token info first recover space-time shape - # token_info = token_info.reshape( list(field_data_shape[0:4]) + [token_info.shape[-1]]) - # token_info = token_info[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] - # token_info = torch.flatten( token_info, 1, -2) - # token_info = token_info.contiguous() - - # no masking for static fields or if masking rate = 0 - if fields[ifield][1][0] > 0 and fields[ifield][5][0] > 0. 
: - - ret = bert_f( cf, ifield, field_data, rngs[rng_idx]) - (field_data, target, tokens_masked_idx, tokens_masked_idx_list) = ret - - if target is not None : - fields_targets.append( target) - fields_tokens_masked_idx.append( tokens_masked_idx) - fields_tokens_masked_idx_list.append( tokens_masked_idx_list) - - rng_idx += 1 - - fields_data_out.append( field_data ) - fields_infos_out.append( torch.zeros( (*field_data.shape[:5], 8)) ) - - # merge - # sources[ifield] = torch.cat( sources[ifield], 1) - # token_infos[ifield] = torch.cat( token_infos[ifield], 1) - # # merge along vertical level, for target we have level, batch, ... ordering - # fields_targets[ifield] = torch.cat( fields_targets[ifield],0) \ - # if len(fields_targets[ifield]) > 0 else fields_targets[ifield] - - return (fields_data_out, fields_infos_out, fields_targets, fields_tokens_masked_idx, + rng_idx = 1 + for ifield, (field, infos) in enumerate(zip(fields_data, fields_infos)) : + for ilevel, (field_data, token_info) in enumerate(zip(field, infos)) : + + # no masking for static fields or if masking rate = 0 + if fields[ifield][1][0] > 0 and fields[ifield][5][0] > 0. : + + ret = bert_f( cf, ifield, field_data, token_info, rngs[rng_idx]) + (field_data, token_info, target, tokens_masked_idx, tokens_masked_idx_list) = ret + + if target != None : + fields_targets[ifield].append( target) + fields_tokens_masked_idx[ifield].append( tokens_masked_idx) + fields_tokens_masked_idx_list[ifield].append( tokens_masked_idx_list) + + rng_idx += 1 + + sources[ifield].append( field_data.unsqueeze(1) ) + token_infos[ifield].append( token_info ) + + # merge along vertical level + sources[ifield] = torch.cat( sources[ifield], 1) + token_infos[ifield] = torch.cat( token_infos[ifield], 1) + # merge along vertical level, for target we have level, batch, ... ordering + fields_targets[ifield] = torch.cat( fields_targets[ifield],0) \ + if len(fields_targets[ifield]) > 0 else fields_targets[ifield] + + return (sources, token_infos, fields_targets, fields_tokens_masked_idx, fields_tokens_masked_idx_list) #################################################################################################### -def prepare_batch_BERT_field( cf, ifield, source, rng) : +def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # shortcuts mr = partial( torch.nn.functional.interpolate, mode='trilinear') @@ -126,10 +88,10 @@ def prepare_batch_BERT_field( cf, ifield, source, rng) : sq = torch.squeeze usq = torch.unsqueeze cnt_nz = torch.count_nonzero - + #breakpoint() # collapse token dimensions source_shape0 = source.shape - source = torch.flatten( torch.flatten( source, 1, 4), 2, 4) + source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) # select random token in the selected space-time cube to be masked/deleted BERT_frac = cf.fields[ifield][5][0] @@ -196,11 +158,16 @@ def prepare_batch_BERT_field( cf, ifield, source, rng) : # unsqueeze(usq()) is required since channel dimension is expected temp = mr( mr( usq( source[ idx[idx_mr_cond] ].reshape( (-1,ts[0],ts[1],ts[2])), 1), mrs), ts) source[ idx[idx_mr_cond] ] = sq( fl( temp, -3, -1)) + # adjust resolution parameter in token_info + # token_info_shape = token_info.shape + # token_info = token_info.flatten( 0, 1) + # token_info[ idx[idx_mr_cond] ][-1] *= (mrs[1] + mrs[2]) / 2. 
#TODO: anisotropic resolution + # token_info = token_info.reshape( token_info_shape) # recover batch dimension which was flattend for easier indexing and also token dimensions source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, target, tokens_masked_idx, tokens_masked_idx_list) + return (source, token_info, target, tokens_masked_idx, tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : @@ -209,7 +176,7 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : num_tokens = source.shape[-6:-3] num_tokens_space = num_tokens[1] * num_tokens[2] idxs = (num_tokens[0]-nt) * num_tokens_space + torch.arange(nt * num_tokens_space) - + breakpoint() # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index f600eaf..bf9998e 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -289,6 +289,7 @@ def tokenize( data, token_size = [-1,-1,-1]) : if token_size[0] > -1 : data_shape = data.shape + print("data_shape", data.shape) tok_tot_t = int( data_shape[-3] / token_size[0]) tok_tot_x = int( data_shape[-2] / token_size[1]) tok_tot_y = int( data_shape[-1] / token_size[2]) From dbe1bbc28af50a4d3ff739f4c2fefab7a39f6023 Mon Sep 17 00:00:00 2001 From: iluise Date: Wed, 20 Mar 2024 16:08:58 +0100 Subject: [PATCH 04/66] working bert.py --- atmorep/core/trainer.py | 13 +++-- atmorep/datasets/multifield_data_sampler.py | 62 ++++++++++----------- atmorep/training/bert.py | 2 +- 3 files changed, 37 insertions(+), 40 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index e9b21ec..809d4e6 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -638,12 +638,13 @@ def prepare_batch( self, xin) : self.targets.append( targets[ifield].to( devs[cf.fields[ifield][1][3]], non_blocking=True )) # idxs of masked tokens - # tmi_out = [[] for _ in range(len(fields_tokens_masked_idx))] - # for i,tmi in enumerate(fields_tokens_masked_idx) : - # tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] - # self.tokens_masked_idx = tmi_out - self.tokens_masked_idx = [tmi.to(devs[cf.fields[i][1][3]], non_blocking=True) - for i,tmi in enumerate(fields_tokens_masked_idx)] + tmi_out = [[] for _ in range(len(fields_tokens_masked_idx))] + for i,tmi in enumerate(fields_tokens_masked_idx) : + tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] + self.tokens_masked_idx = tmi_out + #breakpoint() + # self.tokens_masked_idx = [tmi.to(devs[cf.fields[i][1][3]], non_blocking=True) + # for i,tmi in enumerate(fields_tokens_masked_idx)] # idxs of masked tokens per batch entry self.fields_tokens_masked_idx_list = fields_tokens_masked_idx_list diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 173c213..78f0ef8 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -26,6 +26,7 @@ import code from atmorep.utils.utils import days_until_month_in_year from atmorep.utils.utils import days_in_month +from datetime import datetime import atmorep.config.config as config @@ -182,7 +183,7 @@ def __iter__(self): data_lvl += [self.ds['data'].oindex[ idxs_t, field_idx, vl_idx]] data_t += [data_lvl] - sources, 
sources_infos, source_idxs = [], [], [] + sources, sources_infos, source_idxs, token_infos = [], [], [], [] lat_ran = [] lon_ran = [] for sidx in range(self.batch_size) : @@ -204,52 +205,47 @@ def __iter__(self): if self.with_source_idxs : source_idxs += [ (idxs_t, lat_ran, lon_ran) ] - source = [] - source_infos = [] # extract data # TODO: temporal window can span multiple months year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month for ifield, field_info in enumerate(self.fields): - source_lvl = [] - source_info_lvl = [] + + source_lvl, source_info_lvl, tok_info_lvl = [], [], [] tok_size = field_info[4] for ilevel, vl in enumerate(field_info[2]): #self.levels : + nf = self.normalizers[ifield][ilevel].normalize - source_data, info_data = [], [] - + source_data, info_data, tok_info = [], [], [] + for sidx in range(self.batch_size) : #normalize and tokenize source_data += [ tokenize( torch.from_numpy(nf( year, month, np.take( np.take( data_t[ifield][ilevel], lat_ran[sidx], -2), lon_ran[sidx], -1), (lat_ran[sidx], lon_ran[sidx]))), tok_size ) ] - # info_data += [ torch.Tensor(self.ds['time'][ idxs_t ], vl, - # self.lats[lat_ran[sidx]], self.lons[lon_ran[sidx]], self.res) ] - - info_data += [ [self.ds['time'][ idxs_t ], vl, - self.lats[lat_ran[sidx]], self.lons[lon_ran[sidx]], self.res] ] + + dates = self.ds['time'][ idxs_t ].astype(datetime) + dates = [(d.year, d.timetuple().tm_yday, d.hour) for d in dates] + lats = self.lats[lat_ran[sidx]] + lons = self.lons[lon_ran[sidx]] + info_data += [[[[[ year, day, hour, vl, + lat, lon, vl, self.res[0], self.res[1]] for lon in lons] for lat in lats] for (year, day, hour) in dates]] #zip(years, days, hours)]] + #store only center of the token: also in time + tok_info += [[[[[ year, day, hour, vl, + lat, lon, vl, self.res[0], self.res[1]] for lon in lons[int(tok_size[2]/2)::tok_size[2]]] for lat in lats[int(tok_size[1]/2)::tok_size[1]]] for (year, day, hour) in dates[int(tok_size[0]/2)::tok_size[0]]]] + + #level source_lvl += [torch.stack(source_data, dim = 0)] source_info_lvl += [info_data] - source += [source_lvl] - source_infos += [source_info_lvl] - - sources = self.pre_batch(source, #torch.from_numpy( np.concatenate( sources, 0)), - source_infos ) # [ source ] # [ np.expand_dims(source, 0) ] - sources_infos = [sources_infos] - - #sources += [ self.pre_batch( source, #torch.from_numpy( np.concatenate( sources, 0)), - # source_infos ) ] # [ source ] # [ np.expand_dims(source, 0) ] - #sources_infos += [source_infos] - # breakpoint() - # if self.with_source_idxs : - # source_idxs += [ (idxs_t, lat_ran, lon_ran) ] - - # # extract batch info - # sources_infos += [ [ self.ds['time'][ idxs_t ], self.levels, - # self.lats[lat_ran], self.lons[lon_ran], self.res ] ] - - # swap - # sources = self.pre_batch( sources, #torch.from_numpy( np.concatenate( sources, 0)), - # sources_infos ) + tok_info_lvl += [tok_info] + + #field + sources += [torch.stack(source_lvl, dim = 0)] #torch.Size([3, 16, 12, 6, 12, 3, 9, 9]) + sources_infos += [torch.Tensor(np.array(source_info_lvl))] # torch.Size([3, 16, 36, 54, 108, 9]) + #token_infos += [torch.Tensor(np.array(tok_info_lvl))] # torch.Size([3, 16, 12, 6, 12, 9]) + token_infos += [torch.Tensor(np.array(tok_info_lvl)).reshape(len(tok_info_lvl), len(tok_info_lvl[0]), -1, 9)] #torch.Size([3, 16, 864, 9]) + sources = self.pre_batch(sources, + token_infos ) + # TODO: implement targets target, target_info = None, None diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index d6de908..9db6159 
100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -176,7 +176,7 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : num_tokens = source.shape[-6:-3] num_tokens_space = num_tokens[1] * num_tokens[2] idxs = (num_tokens[0]-nt) * num_tokens_space + torch.arange(nt * num_tokens_space) - breakpoint() + # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) From 967013def84a6749ea4ae13b617c74d1f768af41 Mon Sep 17 00:00:00 2001 From: iluise Date: Wed, 20 Mar 2024 17:37:02 +0100 Subject: [PATCH 05/66] first working version with sfc data + source info --- atmorep/core/trainer.py | 24 +++++++++--- atmorep/datasets/multifield_data_sampler.py | 41 ++++++++++++--------- atmorep/datasets/normalizer_global.py | 2 +- atmorep/datasets/normalizer_local.py | 2 +- atmorep/utils/utils.py | 1 - 5 files changed, 44 insertions(+), 26 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 809d4e6..08c5671 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -25,7 +25,7 @@ import datetime from typing import TypeVar import functools - +import pdb import pandas as pd import wandb @@ -172,6 +172,7 @@ def run( self, epoch = -1) : # test at the beginning as reference self.model.load_data( NetMode.test, batch_size=cf.batch_size_test) + if cf.test_initial : cur_test_loss = self.validate( epoch, cf.BERT_strategy).cpu().numpy() test_loss = np.array( [cur_test_loss]) @@ -179,7 +180,7 @@ def run( self, epoch = -1) : # generic value based on data normalization test_loss = np.array( [1.0]) epoch += 1 - + print("after test_loss") batch_size = cf.batch_size_start - cf.batch_size_delta if cf.profile : @@ -621,6 +622,7 @@ def prepare_batch( self, xin) : # xin[0] since BERT does not have targets (sources, token_infos, targets, fields_tokens_masked_idx,fields_tokens_masked_idx_list) = xin[0] self.sources_idxs = xin[2] + self.sources_info = xin[3] # network input batch_data = [ ( sources[i].to( devs[ cf.fields[i][1][3] ], non_blocking=True), @@ -642,7 +644,6 @@ def prepare_batch( self, xin) : for i,tmi in enumerate(fields_tokens_masked_idx) : tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] self.tokens_masked_idx = tmi_out - #breakpoint() # self.tokens_masked_idx = [tmi.to(devs[cf.fields[i][1][3]], non_blocking=True) # for i,tmi in enumerate(fields_tokens_masked_idx)] @@ -677,8 +678,21 @@ def decoder_to_tail( self, idx_pred, pred) : # flatten token dimensions: remove space-time separation pred = torch.flatten( pred, 2, 3).to( dev) # extract masked token level by level - pred_masked = torch.flatten( pred, 0, 2) - pred_masked = pred_masked[ target_idx ] + #pred_masked = torch.flatten( pred, 0, 2) + # extract masked token level by level + pred_masked = [] + for lidx, level in enumerate(self.cf.fields[field_idx][2]) : + # select masked tokens, flattened along batch dimension for easier indexing and processing + pred_l = torch.flatten( pred[:,lidx], 0, 1) + pred_masked_l = pred_l[ target_idx[lidx] ] + #target_idx_l = target_idx[lidx] + pred_masked.append( pred_masked_l) + + # flatten along level dimension, for loss evaluation we effectively have level, batch, ... 
+ # as ordering of dimensions + pred_masked = torch.cat( pred_masked, 0) + + #pred_masked = pred_masked[ target_idx ] return pred_masked diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 78f0ef8..0d81cad 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -184,12 +184,11 @@ def __iter__(self): data_t += [data_lvl] sources, sources_infos, source_idxs, token_infos = [], [], [], [] - lat_ran = [] - lon_ran = [] + lat_ran, lon_ran = [], [] + for sidx in range(self.batch_size) : idx = self.idxs_perm[bidx*self.batch_size+sidx] - # slight assymetry with offset by res/2 is required to match desired token count lat_ran += [np.where(np.logical_and(lats > idx[0]-ns_2[1]-res[0]/2.,lats < idx[0]+ns_2[1]))[0]] # handle periodicity of lon @@ -201,7 +200,7 @@ def __iter__(self): lon_ran += [np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0)] else : lon_ran += [np.where(np.logical_and( lons > il, lons < ir))[0]] - + if self.with_source_idxs : source_idxs += [ (idxs_t, lat_ran, lon_ran) ] @@ -215,7 +214,7 @@ def __iter__(self): for ilevel, vl in enumerate(field_info[2]): #self.levels : nf = self.normalizers[ifield][ilevel].normalize - source_data, info_data, tok_info = [], [], [] + source_data, tok_info = [], [] for sidx in range(self.batch_size) : #normalize and tokenize @@ -223,33 +222,39 @@ def __iter__(self): lat_ran[sidx], -2), lon_ran[sidx], -1), (lat_ran[sidx], lon_ran[sidx]))), tok_size ) ] dates = self.ds['time'][ idxs_t ].astype(datetime) - dates = [(d.year, d.timetuple().tm_yday, d.hour) for d in dates] - lats = self.lats[lat_ran[sidx]] - lons = self.lons[lon_ran[sidx]] - info_data += [[[[[ year, day, hour, vl, - lat, lon, vl, self.res[0], self.res[1]] for lon in lons] for lat in lats] for (year, day, hour) in dates]] #zip(years, days, hours)]] - #store only center of the token: also in time + #store only center of the token: + #in time we store the *last* datetime in the token, not the center + dates = [(d.year, d.timetuple().tm_yday, d.hour) for d in dates][tok_size[0]-1::tok_size[0]] + lats_sidx = self.lats[lat_ran[sidx]][int(tok_size[1]/2)::tok_size[1]] + lons_sidx = self.lons[lon_ran[sidx]][int(tok_size[2]/2)::tok_size[2]] + # info_data += [[[[[ year, day, hour, vl, + # lat, lon, vl, self.res[0]] for lon in lons] for lat in lats] for (year, day, hour) in dates]] #zip(years, days, hours)]] + tok_info += [[[[[ year, day, hour, vl, - lat, lon, vl, self.res[0], self.res[1]] for lon in lons[int(tok_size[2]/2)::tok_size[2]]] for lat in lats[int(tok_size[1]/2)::tok_size[1]]] for (year, day, hour) in dates[int(tok_size[0]/2)::tok_size[0]]]] + lat, lon, vl, self.res[0]] for lon in lons_sidx] for lat in lats_sidx] for (year, day, hour) in dates]] #level source_lvl += [torch.stack(source_data, dim = 0)] - source_info_lvl += [info_data] + # source_info_lvl += [info_data] tok_info_lvl += [tok_info] #field sources += [torch.stack(source_lvl, dim = 0)] #torch.Size([3, 16, 12, 6, 12, 3, 9, 9]) - sources_infos += [torch.Tensor(np.array(source_info_lvl))] # torch.Size([3, 16, 36, 54, 108, 9]) - #token_infos += [torch.Tensor(np.array(tok_info_lvl))] # torch.Size([3, 16, 12, 6, 12, 9]) - token_infos += [torch.Tensor(np.array(tok_info_lvl)).reshape(len(tok_info_lvl), len(tok_info_lvl[0]), -1, 9)] #torch.Size([3, 16, 864, 9]) + # sources_infos += [torch.Tensor(np.array(source_info_lvl))] # torch.Size([3, 16, 36, 54, 108, 8]) + #token_infos += 
[torch.Tensor(np.array(tok_info_lvl))] # torch.Size([3, 16, 12, 6, 12, 8]) + # extract batch info + sources_infos += [ [ self.ds['time'][ idxs_t ], self.levels, + self.lats[lat_ran], self.lons[lon_ran], self.res ] ] + token_infos += [torch.Tensor(np.array(tok_info_lvl)).reshape(len(tok_info_lvl), len(tok_info_lvl[0]), -1, 8)] #torch.Size([3, 16, 864, 8]) sources = self.pre_batch(sources, token_infos ) # TODO: implement targets target, target_info = None, None - - yield ( sources, (target, target_info), source_idxs ) + #this already goes back to trainer.py. + #source_info needed to remove log_validate in trainer.py + yield ( sources, (target, target_info), source_idxs, sources_infos) ################################################### def __len__(self): diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py index c5558b9..c14bd6d 100644 --- a/atmorep/datasets/normalizer_global.py +++ b/atmorep/datasets/normalizer_global.py @@ -33,7 +33,7 @@ def normalize( self, year, month, data, coords = None) : corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), self.corr_data[:,1] == float(month))) , 2:].flatten() data_temp = (data - corr_data_ym[0]) / corr_data_ym[1] - print(data_temp.mean(), data_temp.std()) + #print(data_temp.mean(), data_temp.std()) return (data - corr_data_ym[0]) / corr_data_ym[1] def denormalize( self, year, month, data, coords = None) : diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py index 32f1763..1643795 100644 --- a/atmorep/datasets/normalizer_local.py +++ b/atmorep/datasets/normalizer_local.py @@ -66,7 +66,7 @@ def normalize( self, year, month, data, coords) : data[i] = (data[i] - mean) / var else : data = (data - mean) / var - print(data.mean(), data.std()) + #print(data.mean(), data.std()) return data def denormalize( self, year, month, data, coords) : diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index bf9998e..f600eaf 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -289,7 +289,6 @@ def tokenize( data, token_size = [-1,-1,-1]) : if token_size[0] > -1 : data_shape = data.shape - print("data_shape", data.shape) tok_tot_t = int( data_shape[-3] / token_size[0]) tok_tot_x = int( data_shape[-2] / token_size[1]) tok_tot_y = int( data_shape[-1] / token_size[2]) From 490df643619cddf3cffd1f160214c0f59efaa7da Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 21 Mar 2024 14:21:20 +0100 Subject: [PATCH 06/66] simplify log_validate --- atmorep/core/atmorep_model.py | 7 +- atmorep/core/evaluate.py | 14 +-- atmorep/core/evaluator.py | 10 ++ atmorep/core/train.py | 2 +- atmorep/core/trainer.py | 115 ++++++++----------- atmorep/datasets/data_writer.py | 120 ++++++++++++-------- atmorep/datasets/multifield_data_sampler.py | 23 ++-- 7 files changed, 152 insertions(+), 139 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 05d927a..a3e90ba 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -152,7 +152,8 @@ def normalizer( self, field, vl_idx) : normalizer = self.dataset_train.datasets[fidx].normalizer elif isinstance( field, int) : - normalizer = self.dataset_train.datasets[field][vl_idx].normalizer + normalizer = self.dataset_train.normalizers[field][vl_idx] +# normalizer = self.dataset_train.datasets[field][vl_idx].normalizer else : assert False, 'invalid argument type (has to be index to cf.fields or field name)' @@ -207,11 +208,11 @@ def create( self, pre_batch, devices, 
create_net = True, pre_batch_targets = Non self.pre_batch_targets = pre_batch_targets cf = self.net.cf - self.dataset_train = MultifieldDataSampler( cf.fields, cf.levels, cf.years_train, + self.dataset_train = MultifieldDataSampler( cf.fields, cf.years_train, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_per_epoch ) - self.dataset_test = MultifieldDataSampler( cf.fields, cf.levels, cf.years_test, + self.dataset_test = MultifieldDataSampler( cf.fields, cf.years_test, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_validate, with_source_idxs = True ) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index b8178f2..cc4ce62 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -30,7 +30,7 @@ # model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence - model_id = '1jh2qvrx' # multiformer, velocity + # model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity # model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting @@ -42,16 +42,16 @@ # e.g. global_forecast where a start date can be specified # BERT masked token model - # mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} + mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} # BERT forecast mode # mode, options = 'forecast', {'forecast_num_tokens' : 1} #, 'fields[0][2]' : [123, 137], 'attention' : False } # BERT forecast with patching to obtain global forecast - mode, options = 'global_forecast', { 'fields[0][2]' : [123, 137], - 'dates' : [[2021, 2, 10, 12]], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 1, - 'attention' : False } + # mode, options = 'global_forecast', { 'fields[0][2]' : [123, 137], + # 'dates' : [[2021, 2, 10, 12]], + # 'token_overlap' : [0, 0], + # 'forecast_num_tokens' : 1, + # 'attention' : False } Evaluator.evaluate( mode, model_id, options) diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 9dccc35..083a951 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -92,6 +92,16 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : cf.num_loader_workers = cf.loader_num_workers cf.data_dir = './data/' + #backward compatibility + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = True + func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index f9bdf2f..9c41829 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -239,7 +239,7 @@ def train() : cf.write_json( wandb) cf.print() - cf.levels = [114, 123, 137] + #cf.levels = [114, 123, 137] cf.with_mixed_precision = True # cf.n_size = [36, 1*9*6, 1.*9*12] # in steps x lat_degrees x lon_degrees diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 08c5671..83b3a29 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -180,7 +180,7 @@ def run( self, epoch = -1) : # generic value based on data normalization test_loss = np.array( [1.0]) epoch += 1 - print("after test_loss") + batch_size = cf.batch_size_start - cf.batch_size_delta if cf.profile : @@ -390,13 
+390,18 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] # targets if len(batch_data[1]) > 0 : - targets = [] - for target_field in batch_data[1] : - targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) + if type(batch_data[1][0][0]) is list : + targets = [batch_data[1][i][0][0] for i in range( len(batch_data[1]))] + else : + targets = batch_data[1][0] + # targets = [] + # for target_field in batch_data[1] : + # targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) # store on cpu log_sources = ( [source.detach().clone().cpu() for source in sources ], [ti.detach().clone().cpu() for ti in token_infos], - [target.detach().clone().cpu() for target in targets ], + [], + #[target.detach().clone().cpu() for target in targets ], tmis, tmis_list ) with torch.autocast(device_type='cuda',dtype=torch.float16,enabled=cf.with_mixed_precision): @@ -621,8 +626,7 @@ def prepare_batch( self, xin) : # unpack loader output # xin[0] since BERT does not have targets (sources, token_infos, targets, fields_tokens_masked_idx,fields_tokens_masked_idx_list) = xin[0] - self.sources_idxs = xin[2] - self.sources_info = xin[3] + (self.sources_idxs, self.sources_info) = xin[2] # network input batch_data = [ ( sources[i].to( devs[ cf.fields[i][1][3] ], non_blocking=True), @@ -644,9 +648,7 @@ def prepare_batch( self, xin) : for i,tmi in enumerate(fields_tokens_masked_idx) : tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] self.tokens_masked_idx = tmi_out - # self.tokens_masked_idx = [tmi.to(devs[cf.fields[i][1][3]], non_blocking=True) - # for i,tmi in enumerate(fields_tokens_masked_idx)] - + # idxs of masked tokens per batch entry self.fields_tokens_masked_idx_list = fields_tokens_masked_idx_list @@ -827,8 +829,7 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : (sources, token_infos, targets, tokens_masked_idx, tokens_masked_idx_list) = log_sources sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] - sources_dates_out, sources_lats_out, sources_lons_out = [ ], [ ], [ ] - targets_dates_out, targets_lats_out, targets_lons_out = [ ], [ ], [ ] + coords = [] for fidx, field_info in enumerate(cf.fields) : @@ -840,53 +841,42 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) tinfos = token_infos[fidx].reshape( [-1, num_levels, *num_tokens, cf.size_token_info]) res = tinfos[0,0,0,0,0][-1].item() - batch_size = tinfos.shape[0] + batch_size = len(self.sources_info) #tinfos.shape[0] sources_b = detok( sources[fidx].numpy()) - + if is_predicted : # split according to levels lens_levels = [t.shape[0] for t in tokens_masked_idx[fidx]] - targets_b = torch.split( targets[fidx], lens_levels) + #targets_b = torch.split( targets[fidx], lens_levels) preds_mu_b = torch.split( log_preds[fidx][0], lens_levels) preds_ens_b = torch.split( log_preds[fidx][2], lens_levels) # split according to batch lens_batches = [ [bv.shape[0] for bv in b] for b in tokens_masked_idx_list[fidx] ] - targets_b = [torch.split( targets_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] + #targets_b = [torch.split( targets_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] preds_mu_b = [torch.split(preds_mu_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] preds_ens_b =[torch.split(preds_ens_b[vidx],lens) for vidx,lens 
in enumerate(lens_batches)] # recover token shape - targets_b = [[targets_b[vidx][bidx].reshape([-1, *token_size]) - for bidx in range(batch_size)] - for vidx in range(num_levels)] + #targets_b = [[targets_b[vidx][bidx].reshape([-1, *token_size]) + # for bidx in range(batch_size)] + # for vidx in range(num_levels)] preds_mu_b = [[preds_mu_b[vidx][bidx].reshape([-1, *token_size]) for bidx in range(batch_size)] for vidx in range(num_levels)] preds_ens_b = [[preds_ens_b[vidx][bidx].reshape( [-1, cf.net_tail_num_nets, *token_size]) for bidx in range(batch_size)] for vidx in range(num_levels)] - + # for all batch items coords_b = [] - for bidx, tinfo in enumerate(tinfos) : - - # use first vertical levels since a column is considered - lats = np.arange(tinfo[0,0,0,0,4]-lat_d_h*res, tinfo[0,0,-1,0,4]+lat_d_h*res+0.001,res) - if tinfo[0,0,0,-1,5] < tinfo[0,0,0,0,5] : - lons = np.remainder( np.arange( tinfo[0,0,0,0,5] - lon_d_h*res, - 360. + tinfo[0,0,0,-1,5] + lon_d_h*res + 0.001, res), 360.) - else : - lons = np.arange(tinfo[0,0,0,0,5]-lon_d_h*res, tinfo[0,0,0,-1,5]+lon_d_h*res+0.001,res) - lons = np.remainder( lons, 360.) - - # time stamp in token_infos is at start time so needs to be advanced by token_size[0]-1 - s = utils.token_info_to_time( tinfo[0,0,0,0,:3] ) - pd.Timedelta(hours=token_size[0]-1) - e = utils.token_info_to_time( tinfo[0,-1,0,0,:3] ) - dates = pd.date_range( start=s, end=e, freq='h') - + for bidx in range(batch_size): #, tinfo in enumerate(tinfos) : + dates = self.sources_info[bidx][0] + lats = 90.- self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] + # target etc are aliasing targets_b which simplifies bookkeeping below if is_predicted : - target = [targets_b[vidx][bidx] for vidx in range(num_levels)] + # target = [targets_b[vidx][bidx] for vidx in range(num_levels)] pred_mu = [preds_mu_b[vidx][bidx] for vidx in range(num_levels)] pred_ens = [preds_ens_b[vidx][bidx] for vidx in range(num_levels)] @@ -898,11 +888,15 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : sources_b[bidx,vidx] = normalizer.denormalize( y, m, sources_b[bidx,vidx], [lats, lons]) if is_predicted : + # dates_masked = self.targets_info[bidx][vidx][0] #67,3 + # lats_masked = 90.- self.targets_info[bidx][vidx][1] #67,9 + # lons_masked = self.targets_info[bidx][vidx][2] #67,9 # TODO: make sure normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] tinfo_masked = tinfos[bidx,vidx].flatten( 0,2) tinfo_masked = tinfo_masked[idx] + lad, lod = lat_d_h*res, lon_d_h*res lats_masked, lons_masked, dates_masked = [], [], [] for t in tinfo_masked : @@ -916,50 +910,33 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : lats_masked = np.concatenate( lats_masked, 0) lons_masked = np.remainder( np.concatenate( lons_masked, 0), 360.) 
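# A minimal aside with assumed example values (not taken from the diff above): the
# longitude coordinates of a masked token are rebuilt around the token centre and
# wrapped into [0, 360) with np.remainder, so a neighbourhood crossing the dateline
# keeps its spatial order while the values wrap.
import numpy as np
lo_centre, lo_half, res = 359.5, 1.0, 0.25          # assumed centre, half-width, resolution
lons_token = np.remainder(np.arange(lo_centre - lo_half, lo_centre + lo_half + 0.001, res), 360.)
# -> [358.5, 358.75, ..., 359.75, 0.0, 0.25, 0.5]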
dates_masked = np.concatenate( dates_masked, 0) - - for ii,(t,p,e,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], + #for ii,(t,p,e,la,lo) in enumerate(zip( #target[vidx], + for ii,(p,e,la,lo) in enumerate(zip( #target[vidx], + pred_mu[vidx], pred_ens[vidx], lats_masked, lons_masked)) : - targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) + # targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) preds_mu_b[vidx][bidx][ii] = normalizer.denormalize( y, m, p, [la, lo]) preds_ens_b[vidx][bidx][ii] = normalizer.denormalize( y, m, e, [la, lo]) dates_masked_l += [ dates_masked ] - lats_masked_l += [ [90.-lat for lat in lats_masked] ] + lats_masked_l += [ [90.- lat for lat in lats_masked] ] lons_masked_l += [ lons_masked ] - dates = dates.to_pydatetime().astype( 'datetime64[s]') - - coords_b += [ [dates, 90.-lats, lons, dates_masked_l, lats_masked_l, lons_masked_l] ] - + coords_b += [ [dates, lats, lons, dates_masked_l, lats_masked_l, lons_masked_l] ] + + coords += [ coords_b ] fn = field_info[0] sources_out.append( [fn, sources_b]) - if is_predicted : - targets_out.append([fn, [[t.numpy(force=True) for t in t_v] for t_v in targets_b]]) - preds_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_mu_b]]) - ensembles_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_ens_b]]) - else : - targets_out.append( [fn, []]) - preds_out.append( [fn, []]) - ensembles_out.append( [fn, []]) - sources_dates_out.append( [c[0] for c in coords_b]) - sources_lats_out.append( [c[1] for c in coords_b]) - sources_lons_out.append( [c[2] for c in coords_b]) - if is_predicted : - targets_dates_out.append( [c[3] for c in coords_b]) - targets_lats_out.append( [c[4] for c in coords_b]) - targets_lons_out.append( [c[5] for c in coords_b]) - else : - targets_dates_out.append( [ ]) - targets_lats_out.append( [ ]) - targets_lons_out.append( [ ]) + # targets_out.append([fn, [[t.numpy(force=True) for t in t_v] for t_v in targets_b]] if is_predicted else [fn, []]) + preds_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_mu_b]] if is_predicted else [fn, []] ) + ensembles_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_ens_b]] if is_predicted else [fn, []] ) levels = [[np.array(l) for l in field[2]] for field in cf.fields] - write_BERT( cf.wandb_id, epoch, batch_idx, - levels, sources_out, - [sources_dates_out, sources_lats_out, sources_lons_out], - targets_out, [targets_dates_out, targets_lats_out, targets_lons_out], - preds_out, ensembles_out ) + write_BERT( cf.wandb_id, epoch, batch_idx, + levels, sources_out, targets_out, + preds_out, ensembles_out, coords ) + def log_attention( self, epoch, bidx, log) : '''Hook for logging: output attention maps.''' diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index 350c7b4..5e48837 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -21,6 +21,15 @@ import datetime import atmorep.config.config as config +def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): + ds_batch_item = ds_field.create_group( f'{name}={name_idx:05d}' ) + ds_batch_item.create_dataset( 'data', data=data) + ds_batch_item.create_dataset( 'ml', data=levels) + ds_batch_item.create_dataset( 'datetime', data=coords[0]) + ds_batch_item.create_dataset( 'lat', data=coords[1]) + ds_batch_item.create_dataset( 'lon', data=coords[2]) + return ds_batch_item + 
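# A small, self-contained sketch of the output layout that a helper in the spirit of
# write_item produces: one zarr group per field, one sub-group per sample, and
# 'data', 'ml', 'datetime', 'lat', 'lon' datasets inside. File name, field name and
# shapes below are assumed example values; the zarr 2.x ZipStore API already used in
# this file is assumed.
import numpy as np
import zarr

store = zarr.ZipStore('results_example.zarr', mode='w')      # hypothetical output file
root = zarr.group(store=store)
field_grp = root.require_group('velocity_u')
sample_grp = field_grp.create_group('sample=00000')
sample_grp.create_dataset('data', data=np.zeros((2, 36, 54, 108), dtype=np.float32))
sample_grp.create_dataset('ml', data=np.array([123, 137]))
sample_grp.create_dataset('datetime',
                          data=np.arange('2021-02-10T00', '2021-02-11T12',
                                         dtype='datetime64[h]').astype('datetime64[s]'))
sample_grp.create_dataset('lat', data=np.linspace(30.0, 43.25, 54))
sample_grp.create_dataset('lon', data=np.linspace(0.0, 26.75, 108))
store.close()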
#################################################################################################### def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, targets, targets_coords, @@ -38,17 +47,19 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, store_source = zarr_store( fname.format( 'source')) exp_source = zarr.group(store=store_source) + for fidx, field in enumerate(sources) : ds_field = exp_source.require_group( f'{field[0]}') batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=sources_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=sources_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=sources_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, [t[bidx] for t in sources_coords] ) + # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) + # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) + # ds_batch_item.create_dataset( 'ml', data=levels) + # ds_batch_item.create_dataset( 'datetime', data=sources_coords[0][bidx]) + # ds_batch_item.create_dataset( 'lat', data=sources_coords[1][bidx]) + # ds_batch_item.create_dataset( 'lon', data=sources_coords[2][bidx]) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -58,12 +69,13 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, [t[bidx] for t in targets_coords] ) + # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) + # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) + # ds_batch_item.create_dataset( 'ml', data=levels) + # ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) + # ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) + # ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -73,12 +85,13 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, [t[bidx] for t in targets_coords] ) + # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) + # 
ds_batch_item.create_dataset( 'data', data=field[1][bidx]) + # ds_batch_item.create_dataset( 'ml', data=levels) + # ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) + # ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) + # ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -88,18 +101,19 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, [t[bidx] for t in targets_coords] ) + # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) + # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) + # ds_batch_item.create_dataset( 'ml', data=levels) + # ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) + # ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) + # ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) store_ens.close() #################################################################################################### -def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, - targets, targets_coords, - preds, ensembles, +def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, + targets, # targets_coords, + preds, ensembles, coords, zarr_store_type = 'ZipStore' ) : ''' sources : num_fields x [field name , data] @@ -107,6 +121,10 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, preds, ensemble share coords with targets ''' + breakpoint() + sources_coords = [[c[:3] for c in coord_field ] for coord_field in coords] + targets_coords = [[c[3:] for c in coord_field ] for coord_field in coords] + # fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch}.zarr' fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' @@ -119,12 +137,13 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels[fidx]) - ds_batch_item.create_dataset( 'datetime', data=sources_coords[0][0][bidx]) - ds_batch_item.create_dataset( 'lat', data=sources_coords[1][0][bidx]) - ds_batch_item.create_dataset( 'lon', data=sources_coords[2][0][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] ) + # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) + # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) + # ds_batch_item.create_dataset( 'ml', data=levels[fidx]) + # ds_batch_item.create_dataset( 'datetime', data=sources_coords[fidx][bidx][0]) + # ds_batch_item.create_dataset( 'lat', data=sources_coords[fidx][bidx][1]) + # 
ds_batch_item.create_dataset( 'lon', data=sources_coords[fidx][bidx][2]) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -141,9 +160,9 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, ds_target_b_l = ds_target_b.require_group( f'ml={levels[fidx][vidx]}') ds_target_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) ds_target_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_target_b_l.create_dataset( 'datetime', data=targets_coords[0][fidx][bidx][vidx]) - ds_target_b_l.create_dataset( 'lat', data=targets_coords[1][fidx][bidx][vidx]) - ds_target_b_l.create_dataset( 'lon', data=targets_coords[2][fidx][bidx][vidx]) + ds_target_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) + ds_target_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) + ds_target_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -157,13 +176,14 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, sample = batch_idx * batch_size + bidx ds_pred_b = ds_pred.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_pred_b_l = ds_pred_b.create_group( f'ml={levels[fidx][vidx]}') - ds_pred_b_l.create_dataset( 'data', data - =field[1][vidx][bidx]) - ds_pred_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_pred_b_l.create_dataset( 'datetime', data=targets_coords[0][fidx][bidx][vidx]) - ds_pred_b_l.create_dataset( 'lat', data=targets_coords[1][fidx][bidx][vidx]) - ds_pred_b_l.create_dataset( 'lon', data=targets_coords[2][fidx][bidx][vidx]) + ds_pred_b_l = write_item(ds_pred_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], + [t[vidx] for t in targets_coords[fidx][bidx]], name = 'ml' ) + # ds_pred_b_l = ds_pred_b.create_group( f'ml={levels[fidx][vidx]}') + # ds_pred_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) + # ds_pred_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) + # ds_pred_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) + # ds_pred_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) + # ds_pred_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -177,12 +197,14 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, sample = batch_idx * batch_size + bidx ds_ens_b = ds_ens.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_ens_b_l = ds_ens_b.create_group( f'ml={levels[fidx][vidx]}') - ds_ens_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - ds_ens_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_ens_b_l.create_dataset( 'datetime', data=targets_coords[0][fidx][bidx][vidx]) - ds_ens_b_l.create_dataset( 'lat', data=targets_coords[1][fidx][bidx][vidx]) - ds_ens_b_l.create_dataset( 'lon', data=targets_coords[2][fidx][bidx][vidx]) + ds_ens_b_l = write_item(ds_ens_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], + [t[vidx] for t in targets_coords[fidx][bidx]], name = 'ml' ) + # ds_ens_b_l = ds_ens_b.create_group( f'ml={levels[fidx][vidx]}') + # ds_ens_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) + # ds_ens_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) + # ds_ens_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) + # ds_ens_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) + # 
ds_ens_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_ens.close() #################################################################################################### diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 0d81cad..9837ac0 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -38,7 +38,7 @@ class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### - def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_samples_per_epoch, + def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_per_epoch, rng_seed = None, time_sampling = 1, with_source_idxs = False, fields_targets = None, pre_batch_targets = None ) : ''' @@ -86,7 +86,7 @@ def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_sa # self.levels_idxs = np.array( [self.ds.attrs['levels'].index( ll) for ll in levels]) # self.fields_idxs = [0, 1, 2] # self.levels_idxs = [0, 1] - self.levels = levels #[123, 137] # self.ds['levels'] + # self.levels = levels #[123, 137] # self.ds['levels'] # TODO # # create (target) fields @@ -119,7 +119,7 @@ def __init__( self, fields, levels, years, batch_size, pre_batch, n_size, num_sa self.normalizers.append( []) corr_type = 'global' if len(field_info) <= 6 else field_info[6] ner = NormalizerGlobal if corr_type == 'global' else NormalizerLocal - for vl in field_info[2]: #self.levels : + for vl in field_info[2]: data_type = 'data_sfc' if vl == 0 else 'data' #surface field self.normalizers[-1] += [ ner( field_info, vl, np.array(self.ds[data_type].shape)[[0,-2,-1]]) ] @@ -200,9 +200,12 @@ def __iter__(self): lon_ran += [np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0)] else : lon_ran += [np.where(np.logical_and( lons > il, lons < ir))[0]] + + sources_infos += [ [ self.ds['time'][ idxs_t ].astype(datetime), + self.lats[lat_ran][-1], self.lons[lon_ran][-1], self.res ] ] if self.with_source_idxs : - source_idxs += [ (idxs_t, lat_ran, lon_ran) ] + source_idxs += [ (idxs_t, lat_ran[-1], lon_ran[-1]) ] # extract data # TODO: temporal window can span multiple months @@ -211,7 +214,7 @@ def __iter__(self): source_lvl, source_info_lvl, tok_info_lvl = [], [], [] tok_size = field_info[4] - for ilevel, vl in enumerate(field_info[2]): #self.levels : + for ilevel, vl in enumerate(field_info[2]): nf = self.normalizers[ifield][ilevel].normalize source_data, tok_info = [], [] @@ -242,19 +245,19 @@ def __iter__(self): sources += [torch.stack(source_lvl, dim = 0)] #torch.Size([3, 16, 12, 6, 12, 3, 9, 9]) # sources_infos += [torch.Tensor(np.array(source_info_lvl))] # torch.Size([3, 16, 36, 54, 108, 8]) #token_infos += [torch.Tensor(np.array(tok_info_lvl))] # torch.Size([3, 16, 12, 6, 12, 8]) - # extract batch info - sources_infos += [ [ self.ds['time'][ idxs_t ], self.levels, - self.lats[lat_ran], self.lons[lon_ran], self.res ] ] + # extract batch info. level info stored in cf.fields. not stored here. + token_infos += [torch.Tensor(np.array(tok_info_lvl)).reshape(len(tok_info_lvl), len(tok_info_lvl[0]), -1, 8)] #torch.Size([3, 16, 864, 8]) sources = self.pre_batch(sources, token_infos ) # TODO: implement targets - target, target_info = None, None + targets, target_info = sources, sources_infos + target_idxs = None #this already goes back to trainer.py. 
#source_info needed to remove log_validate in trainer.py - yield ( sources, (target, target_info), source_idxs, sources_infos) + yield ( sources, targets, (source_idxs, sources_infos), (target_idxs, target_info)) ################################################### def __len__(self): From 2f7ecc3487b12c347c0cf152d09b703268b09d5c Mon Sep 17 00:00:00 2001 From: iluise Date: Fri, 22 Mar 2024 19:09:21 +0100 Subject: [PATCH 07/66] restructure data_writer --- atmorep/core/train.py | 25 +++- atmorep/core/train_multi.py | 8 + atmorep/core/trainer.py | 158 +++++++++----------- atmorep/datasets/data_writer.py | 32 ++-- atmorep/datasets/multifield_data_sampler.py | 2 +- atmorep/training/bert.py | 7 +- 6 files changed, 121 insertions(+), 111 deletions(-) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 9c41829..0b5fe4b 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -47,7 +47,18 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : # name has changed but ensure backward compatibility if hasattr( cf, 'loader_num_workers') : cf.num_loader_workers = cf.loader_num_workers - + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = True + + cf.years_train = [2021] # list( range( 1980, 2018)) + cf.years_test = [2021] #[2018] + # any parameter in cf can be overwritten when training is continued, e.g. we can increase the # masking rate # cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ], @@ -258,10 +269,10 @@ def train() : #################################################################################################### if __name__ == '__main__': - train() + # train() -# wandb_id, epoch = '1jh2qvrx', -2 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 -# epoch_continue = epoch -# -# Trainer = Trainer_BERT -# train_continue( wandb_id, epoch, Trainer, epoch_continue) + wandb_id, epoch = '1jh2qvrx', -2 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 + epoch_continue = epoch + + Trainer = Trainer_BERT + train_continue( wandb_id, epoch, Trainer, epoch_continue) diff --git a/atmorep/core/train_multi.py b/atmorep/core/train_multi.py index fc855ac..069c425 100644 --- a/atmorep/core/train_multi.py +++ b/atmorep/core/train_multi.py @@ -45,6 +45,14 @@ def train_continue( model_id, model_epoch, Trainer, model_epoch_continue = -1) : cf.optimizer_zero = False if hasattr( cf, 'loader_num_workers') : cf.num_loader_workers = cf.loader_num_workers + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = True setup_wandb( cf.with_wandb, cf, par_rank, 'train-multi', mode='offline') diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 83b3a29..6646fdd 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -50,7 +50,7 @@ from atmorep.utils.utils import CRPS from atmorep.utils.utils import NetMode from atmorep.utils.utils import sgn_exp - +from atmorep.utils.utils import tokenize, detokenize from atmorep.datasets.data_writer import write_forecast, write_BERT, write_attention #################################################################################################### @@ 
-380,7 +380,6 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): self.mode_test = True # run test set evaluation - with torch.no_grad() : for it in range( self.model.len( NetMode.test)) : @@ -388,20 +387,21 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] + # targets - if len(batch_data[1]) > 0 : - if type(batch_data[1][0][0]) is list : - targets = [batch_data[1][i][0][0] for i in range( len(batch_data[1]))] - else : - targets = batch_data[1][0] + #TO-DO: implement target + # if len(batch_data[1]) > 0 : + # print("len(batch_data[1])", len(batch_data[1])) + # if type(batch_data[1][0][0]) is list : + # targets = [batch_data[1][i][0][0] for i in range( len(batch_data[1]))] + # else : + # targets = batch_data[1][0] # targets = [] # for target_field in batch_data[1] : # targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) # store on cpu log_sources = ( [source.detach().clone().cpu() for source in sources ], - [ti.detach().clone().cpu() for ti in token_infos], - [], - #[target.detach().clone().cpu() for target in targets ], + [target.detach().clone().cpu() for target in targets ], tmis, tmis_list ) with torch.autocast(device_type='cuda',dtype=torch.float16,enabled=cf.with_mixed_precision): @@ -476,7 +476,6 @@ def evaluate( self, data_idx = 0, log = True): test_len = 0 # evaluate - loss = torch.tensor( 0.) with torch.no_grad() : @@ -488,14 +487,15 @@ def evaluate( self, data_idx = 0, log = True): # keep on cpu since it will otherwise clog up GPU memory (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] # targets + print("len(batch_data[1])", len(batch_data[1])) if len(batch_data[1]) > 0 : targets = [] for target_field in batch_data[1] : targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) + # store on cpu # TODO: is this still all needed with self.sources_idx log_sources = ( [source.detach().clone().cpu() for source in sources ], - [ti.detach().clone().cpu() for ti in token_infos], [target.detach().clone().cpu() for target in targets ], tmis, tmis_list ) @@ -519,7 +519,7 @@ def evaluate( self, data_idx = 0, log = True): if cf.attention: self.log_attention( data_idx , it, [atts, - [ti.detach().clone().cpu() for ti in token_infos]]) + [ti.detach().clone().cpu() for ti in token_infos]]) # average over all nodes loss /= test_len * len(self.cf.fields_prediction) @@ -719,12 +719,11 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : # TODO, TODO: use sources_idx cf = self.cf - detok = utils.detokenize # TODO, TODO: for 6h forecast we need to iterate over predicted token slices # save source: remains identical so just save ones - (sources, token_infos, targets, _, _) = log_sources + (sources, targets, _, _) = log_sources sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] @@ -763,9 +762,9 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : for fidx, field_info in enumerate(cf.fields) : # reshape from tokens to contiguous physical field num_levels = len(field_info[2]) - source = detok( sources[fidx].cpu().detach().numpy()) + source = detokenize( sources[fidx].cpu().detach().numpy()) # recover tokenized shape - target = detok( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, -1, + target = detokenize( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, 
-1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) # TODO: check that geo-coords match to general ones that have been pre-determined for bidx in range(token_infos[fidx].shape[0]) : @@ -785,11 +784,11 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : num_levels = len(field_info[2]) # predictions pred = log_preds[fidx][0].cpu().detach().numpy() - pred = detok( pred.reshape( [ num_levels, -1, + pred = detokenize( pred.reshape( [ num_levels, -1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) # ensemble ensemble = log_preds[fidx][2].cpu().detach().numpy().swapaxes(0,1) - ensemble = detok( ensemble.reshape( [ cf.net_tail_num_nets, num_levels, -1, + ensemble = detokenize( ensemble.reshape( [ cf.net_tail_num_nets, num_levels, -1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(1, 2)).swapaxes(0,1) # denormalize @@ -817,16 +816,38 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : levels, sources_out, [dates_sources, lats, lons], targets_out, [dates_targets, lats, lons], preds_out, ensembles_out ) - + ################################################### + + #helpers for BERT + def split_data(self, data, idx, idx_list, token_size): + lens_levels = [t.shape[0] for t in idx] + data_b = torch.split( data, lens_levels) + # split according to batch + lens_batches = [ [bv.shape[0] for bv in b] for b in idx_list ] + return [torch.split( data_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] + + def get_masked_data(self, data, idx, idx_list, token_size, num_levels, ensemble = False): + cf = self.cf + batch_size = len(self.sources_info) + data_b = self.split_data(data, idx, idx_list, token_size) + + # recover token shape + if ensemble: + return [[data_b[vidx][bidx].reshape([-1, cf.net_tail_num_nets, *token_size]) for bidx in range(batch_size)] + for vidx in range(num_levels)] + else: + return [[data_b[vidx][bidx].reshape([-1, *token_size]) for bidx in range(batch_size)] + for vidx in range(num_levels)] + ################################################### def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : '''Logging for BERT_strategy=BERT.''' cf = self.cf - detok = utils.detokenize + batch_size = len(self.sources_info) # save source: remains identical so just save ones - (sources, token_infos, targets, tokens_masked_idx, tokens_masked_idx_list) = log_sources + (sources, targets, tokens_masked_idx, tokens_masked_idx_list) = log_sources sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] coords = [] @@ -838,49 +859,27 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : num_levels = len(field_info[2]) num_tokens = field_info[3] token_size = field_info[4] - lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) - tinfos = token_infos[fidx].reshape( [-1, num_levels, *num_tokens, cf.size_token_info]) - res = tinfos[0,0,0,0,0][-1].item() - batch_size = len(self.sources_info) #tinfos.shape[0] - - sources_b = detok( sources[fidx].numpy()) + sources_b = detokenize( sources[fidx].numpy()) if is_predicted : - # split according to levels - lens_levels = [t.shape[0] for t in tokens_masked_idx[fidx]] - #targets_b = torch.split( targets[fidx], lens_levels) - preds_mu_b = torch.split( log_preds[fidx][0], lens_levels) - preds_ens_b = torch.split( log_preds[fidx][2], lens_levels) - # split according to batch - lens_batches = [ [bv.shape[0] for bv in b] for b in tokens_masked_idx_list[fidx] ] - 
#targets_b = [torch.split( targets_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] - preds_mu_b = [torch.split(preds_mu_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] - preds_ens_b =[torch.split(preds_ens_b[vidx],lens) for vidx,lens in enumerate(lens_batches)] - # recover token shape - #targets_b = [[targets_b[vidx][bidx].reshape([-1, *token_size]) - # for bidx in range(batch_size)] - # for vidx in range(num_levels)] - preds_mu_b = [[preds_mu_b[vidx][bidx].reshape([-1, *token_size]) - for bidx in range(batch_size)] - for vidx in range(num_levels)] - preds_ens_b = [[preds_ens_b[vidx][bidx].reshape( [-1, cf.net_tail_num_nets, *token_size]) - for bidx in range(batch_size)] - for vidx in range(num_levels)] - + targets_b = self.get_masked_data(targets[fidx], tokens_masked_idx[fidx], tokens_masked_idx_list[fidx], token_size, num_levels) + preds_mu_b = self.get_masked_data(log_preds[fidx][0], tokens_masked_idx[fidx], tokens_masked_idx_list[fidx], token_size, num_levels) + preds_ens_b = self.get_masked_data(log_preds[fidx][2], tokens_masked_idx[fidx], tokens_masked_idx_list[fidx], token_size, num_levels, ensemble = True) + # for all batch items coords_b = [] - for bidx in range(batch_size): #, tinfo in enumerate(tinfos) : + for bidx in range(batch_size): dates = self.sources_info[bidx][0] - lats = 90.- self.sources_info[bidx][1] + lats = 90. - self.sources_info[bidx][1] lons = self.sources_info[bidx][2] # target etc are aliasing targets_b which simplifies bookkeeping below if is_predicted : - # target = [targets_b[vidx][bidx] for vidx in range(num_levels)] - pred_mu = [preds_mu_b[vidx][bidx] for vidx in range(num_levels)] + target = [targets_b[vidx][bidx] for vidx in range(num_levels)] + pred_mu = [preds_mu_b[vidx][bidx] for vidx in range(num_levels)] pred_ens = [preds_ens_b[vidx][bidx] for vidx in range(num_levels)] - dates_masked_l, lats_masked_l, lons_masked_l = [], [], [] + coords_mskd_l = [] for vidx, _ in enumerate(field_info[2]) : normalizer = self.model.normalizer( fidx, vidx) @@ -888,47 +887,36 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : sources_b[bidx,vidx] = normalizer.denormalize( y, m, sources_b[bidx,vidx], [lats, lons]) if is_predicted : - # dates_masked = self.targets_info[bidx][vidx][0] #67,3 - # lats_masked = 90.- self.targets_info[bidx][vidx][1] #67,9 - # lons_masked = self.targets_info[bidx][vidx][2] #67,9 - # TODO: make sure normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] - tinfo_masked = tinfos[bidx,vidx].flatten( 0,2) - tinfo_masked = tinfo_masked[idx] + grid = np.array(np.meshgrid(self.sources_info[bidx][2], self.sources_info[bidx][1])) + grid = np.array(np.broadcast_to(grid, shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1) #add time dimension + + lats_mskd = tokenize(torch.Tensor(grid[0]), token_size) #treat separate to avoid errors in tokenize + lons_mskd = tokenize(torch.Tensor(grid[1]), token_size) + lats_mskd = 90.- torch.flatten( lats_mskd, 0, 2)[idx] + lons_mskd = torch.flatten( lons_mskd, 0, 2)[idx] - lad, lod = lat_d_h*res, lon_d_h*res - lats_masked, lons_masked, dates_masked = [], [], [] - for t in tinfo_masked : - - lats_masked.append( np.expand_dims( np.arange(t[4]-lad, t[4]+lad+0.001,res), 0)) - lons_masked.append( np.expand_dims( np.arange(t[5]-lod, t[5]+lod+0.001,res), 0)) - - r = pd.date_range( start=utils.token_info_to_time(t), periods=token_size[0], freq='h') - dates_masked.append( np.expand_dims(r.to_pydatetime().astype( 
'datetime64[s]'), 0) ) - - lats_masked = np.concatenate( lats_masked, 0) - lons_masked = np.remainder( np.concatenate( lons_masked, 0), 360.) - dates_masked = np.concatenate( dates_masked, 0) - #for ii,(t,p,e,la,lo) in enumerate(zip( #target[vidx], - for ii,(p,e,la,lo) in enumerate(zip( #target[vidx], - pred_mu[vidx], pred_ens[vidx], - lats_masked, lons_masked)) : - # targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) + #time: idx ranges from 0->863 12x6x12 + t_idx = (idx % (token_size[0]*num_tokens[0])) #* num_tokens[0] + t_idx = np.array([np.arange(t - token_size[0], t) for t in t_idx]) #create range from t_idx-2 to t_idx + dates_mskd = self.sources_info[bidx][0][t_idx] + + for ii,(t,p,e,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], + lats_mskd, lons_mskd)) : + targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) preds_mu_b[vidx][bidx][ii] = normalizer.denormalize( y, m, p, [la, lo]) preds_ens_b[vidx][bidx][ii] = normalizer.denormalize( y, m, e, [la, lo]) - - dates_masked_l += [ dates_masked ] - lats_masked_l += [ [90.- lat for lat in lats_masked] ] - lons_masked_l += [ lons_masked ] - - coords_b += [ [dates, lats, lons, dates_masked_l, lats_masked_l, lons_masked_l] ] + + coords_mskd_l += [[dates_mskd, lats_mskd.numpy(), lons_mskd.numpy()] ] + + coords_b += [ [dates, lats, lons] + coords_mskd_l ] coords += [ coords_b ] fn = field_info[0] sources_out.append( [fn, sources_b]) - # targets_out.append([fn, [[t.numpy(force=True) for t in t_v] for t_v in targets_b]] if is_predicted else [fn, []]) + targets_out.append([fn, [[t.numpy(force=True) for t in t_v] for t_v in targets_b]] if is_predicted else [fn, []]) preds_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_mu_b]] if is_predicted else [fn, []] ) ensembles_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_ens_b]] if is_predicted else [fn, []] ) diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index 5e48837..80364a7 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -22,12 +22,12 @@ import atmorep.config.config as config def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): - ds_batch_item = ds_field.create_group( f'{name}={name_idx:05d}' ) - ds_batch_item.create_dataset( 'data', data=data) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=coords[0]) - ds_batch_item.create_dataset( 'lat', data=coords[1]) - ds_batch_item.create_dataset( 'lon', data=coords[2]) + ds_batch_item = ds_field.create_group( f'{name}={name_idx:05d}' ) + ds_batch_item.create_dataset( 'data', data=data) + ds_batch_item.create_dataset( 'ml', data=levels) + ds_batch_item.create_dataset( 'datetime', data=coords[0].astype(np.datetime64)) + ds_batch_item.create_dataset( 'lat', data=coords[1].astype(np.float32)) + ds_batch_item.create_dataset( 'lon', data=coords[2].astype(np.float32)) return ds_batch_item #################################################################################################### @@ -121,7 +121,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, preds, ensemble share coords with targets ''' - breakpoint() + # breakpoint() sources_coords = [[c[:3] for c in coord_field ] for coord_field in coords] targets_coords = [[c[3:] for c in coord_field ] for coord_field in coords] @@ -157,12 +157,13 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, sample = 
batch_idx * batch_size + bidx ds_target_b = ds_field.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_target_b_l = ds_target_b.require_group( f'ml={levels[fidx][vidx]}') - ds_target_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - ds_target_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_target_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) - ds_target_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) - ds_target_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) + ds_target_b_l = write_item(ds_target_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) + # ds_target_b_l = ds_target_b.require_group( f'ml={levels[fidx][vidx]}') + # ds_target_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) + # ds_target_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) + # ds_target_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) + # ds_target_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) + # ds_target_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -177,7 +178,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, ds_pred_b = ds_pred.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : ds_pred_b_l = write_item(ds_pred_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], - [t[vidx] for t in targets_coords[fidx][bidx]], name = 'ml' ) + targets_coords[fidx][bidx][vidx], name = 'ml' ) # ds_pred_b_l = ds_pred_b.create_group( f'ml={levels[fidx][vidx]}') # ds_pred_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) # ds_pred_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) @@ -198,7 +199,8 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, ds_ens_b = ds_ens.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : ds_ens_b_l = write_item(ds_ens_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], - [t[vidx] for t in targets_coords[fidx][bidx]], name = 'ml' ) + targets_coords[fidx][bidx][vidx], name = 'ml' ) + # ds_ens_b_l = ds_ens_b.create_group( f'ml={levels[fidx][vidx]}') # ds_ens_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) # ds_ens_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 9837ac0..5cda67c 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -253,7 +253,7 @@ def __iter__(self): token_infos ) # TODO: implement targets - targets, target_info = sources, sources_infos + targets, target_info = None, None target_idxs = None #this already goes back to trainer.py. 
#source_info needed to remove log_validate in trainer.py diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 9db6159..ed4f5dd 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -88,11 +88,11 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : sq = torch.squeeze usq = torch.unsqueeze cnt_nz = torch.count_nonzero - #breakpoint() + # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) - + # select random token in the selected space-time cube to be masked/deleted BERT_frac = cf.fields[ifield][5][0] BERT_frac_mask = cf.fields[ifield][5][1] @@ -102,6 +102,7 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : token_size = cf.fields[ifield][4] batch_dim = source.shape[0] num_tokens = source.shape[1] + # masking_ratios = rng.random( batch_dim) * BERT_frac # number of tokens masked per batch entry @@ -119,7 +120,7 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # keep masked tokens for loss computation target = source[idx].clone() - + # climatological mean of normalized data global_mean = 0. * torch.mean(source, 0) global_std = torch.std(source, 0) From 36cf3e9bbcf96dd9a1d6a2d177f7d4d2ac8c2c04 Mon Sep 17 00:00:00 2001 From: iluise Date: Fri, 22 Mar 2024 19:18:45 +0100 Subject: [PATCH 08/66] add one comment --- atmorep/core/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 6646fdd..53e20f2 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -890,7 +890,7 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : # TODO: make sure normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] grid = np.array(np.meshgrid(self.sources_info[bidx][2], self.sources_info[bidx][1])) - grid = np.array(np.broadcast_to(grid, shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1) #add time dimension + grid = np.array(np.broadcast_to(grid, shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1) #add time dimension. only way to make tokenize work lats_mskd = tokenize(torch.Tensor(grid[0]), token_size) #treat separate to avoid errors in tokenize lons_mskd = tokenize(torch.Tensor(grid[1]), token_size) From 16e160ae7b8beb0de4a72492b77fd6536fe965ac Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 25 Mar 2024 11:40:43 +0100 Subject: [PATCH 09/66] Removed duplicate fields_tokens_masked_idx. 
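(A brief illustrative note on this change; names and shapes below are examples only, not taken from the patch.) The flat masked-token index tensor that used to be stored alongside the per-batch-entry lists is fully determined by those lists, mirroring the existing torch.cat pattern in bert.py and in prepare_batch(), so only the list form needs to be carried through the pipeline:

    import torch

    num_tokens = 12 * 6 * 12   # tokens per batch entry (space-time cube), illustrative
    # hypothetical masked-token indices for two batch entries of one vertical level
    idx_list = [torch.tensor([3, 17, 101]), torch.tensor([42])]
    # flat indices over the batch-flattened token dimension, rebuilt on demand
    idx_flat = torch.cat([idxs + num_tokens * i for i, idxs in enumerate(idx_list)])
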
--- atmorep/core/trainer.py | 33 +++++++++++++-------------------- atmorep/training/bert.py | 15 +++++---------- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 53e20f2..36026d8 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -386,7 +386,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): batch_data = self.model.next() if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory - (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] + (sources, token_infos, targets, tmis_list) = batch_data[0] # targets #TO-DO: implement target @@ -402,7 +402,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): # store on cpu log_sources = ( [source.detach().clone().cpu() for source in sources ], [target.detach().clone().cpu() for target in targets ], - tmis, tmis_list ) + tmis_list ) with torch.autocast(device_type='cuda',dtype=torch.float16,enabled=cf.with_mixed_precision): batch_data = self.prepare_batch( batch_data) @@ -485,7 +485,7 @@ def evaluate( self, data_idx = 0, log = True): if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory - (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] + (sources, token_infos, targets, tmis_list) = batch_data[0] # targets print("len(batch_data[1])", len(batch_data[1])) if len(batch_data[1]) > 0 : @@ -497,7 +497,7 @@ def evaluate( self, data_idx = 0, log = True): # TODO: is this still all needed with self.sources_idx log_sources = ( [source.detach().clone().cpu() for source in sources ], [target.detach().clone().cpu() for target in targets ], - tmis, tmis_list ) + tmis_list ) batch_data = self.prepare_batch( batch_data) @@ -625,7 +625,7 @@ def prepare_batch( self, xin) : # unpack loader output # xin[0] since BERT does not have targets - (sources, token_infos, targets, fields_tokens_masked_idx,fields_tokens_masked_idx_list) = xin[0] + (sources, token_infos, targets, fields_tokens_masked_idx_list) = xin[0] (self.sources_idxs, self.sources_info) = xin[2] # network input @@ -644,14 +644,12 @@ def prepare_batch( self, xin) : self.targets.append( targets[ifield].to( devs[cf.fields[ifield][1][3]], non_blocking=True )) # idxs of masked tokens - tmi_out = [[] for _ in range(len(fields_tokens_masked_idx))] - for i,tmi in enumerate(fields_tokens_masked_idx) : - tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] + tmi_out = [[] for _ in range(len(fields_tokens_masked_idx_list))] + for i,tmi in enumerate(fields_tokens_masked_idx_list) : + cdev = devs[cf.fields[i][1][3]] + tmi_out[i] = [torch.cat(tmi_l,0).to( cdev, non_blocking=True) for tmi_l in tmi] self.tokens_masked_idx = tmi_out - # idxs of masked tokens per batch entry - self.fields_tokens_masked_idx_list = fields_tokens_masked_idx_list - # learnable class token (cannot be done in the data loader since this is running in parallel) if cf.learnable_mask : for ifield, (source, _) in enumerate(batch_data) : @@ -680,22 +678,16 @@ def decoder_to_tail( self, idx_pred, pred) : # flatten token dimensions: remove space-time separation pred = torch.flatten( pred, 2, 3).to( dev) # extract masked token level by level - #pred_masked = torch.flatten( pred, 0, 2) - # extract masked token level by level pred_masked = [] for lidx, level in enumerate(self.cf.fields[field_idx][2]) : # select masked tokens, flattened along batch dimension for easier indexing and processing pred_l = torch.flatten( 
pred[:,lidx], 0, 1) - pred_masked_l = pred_l[ target_idx[lidx] ] - #target_idx_l = target_idx[lidx] - pred_masked.append( pred_masked_l) + pred_masked.append( pred_l[ target_idx[lidx] ]) # flatten along level dimension, for loss evaluation we effectively have level, batch, ... # as ordering of dimensions pred_masked = torch.cat( pred_masked, 0) - #pred_masked = pred_masked[ target_idx ] - return pred_masked ################################################### @@ -847,7 +839,8 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : batch_size = len(self.sources_info) # save source: remains identical so just save ones - (sources, targets, tokens_masked_idx, tokens_masked_idx_list) = log_sources + (sources, targets, tokens_masked_idx_list) = log_sources + tokens_masked_idx = torch.cat( tokens_masked_idx_list) sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] coords = [] @@ -925,7 +918,7 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : levels, sources_out, targets_out, preds_out, ensembles_out, coords ) - + ################################################### def log_attention( self, epoch, bidx, log) : '''Hook for logging: output attention maps.''' cf = self.cf diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index ed4f5dd..42f41c0 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -24,7 +24,6 @@ #################################################################################################### def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, fields_infos) : - fields_tokens_masked_idx = [[] for _ in fields_data] fields_tokens_masked_idx_list = [[] for _ in fields_data] fields_targets = [[] for _ in fields_data] sources = [[] for _ in fields_data] @@ -56,11 +55,10 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, if fields[ifield][1][0] > 0 and fields[ifield][5][0] > 0. 
: ret = bert_f( cf, ifield, field_data, token_info, rngs[rng_idx]) - (field_data, token_info, target, tokens_masked_idx, tokens_masked_idx_list) = ret + (field_data, token_info, target, tokens_masked_idx_list) = ret - if target != None : + if target is not None : fields_targets[ifield].append( target) - fields_tokens_masked_idx[ifield].append( tokens_masked_idx) fields_tokens_masked_idx_list[ifield].append( tokens_masked_idx_list) rng_idx += 1 @@ -75,8 +73,7 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, fields_targets[ifield] = torch.cat( fields_targets[ifield],0) \ if len(fields_targets[ifield]) > 0 else fields_targets[ifield] - return (sources, token_infos, fields_targets, fields_tokens_masked_idx, - fields_tokens_masked_idx_list) + return (sources, token_infos, fields_targets, fields_tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : @@ -111,7 +108,6 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # linear indices for masking idx = torch.cat( [tokens_masked_idx_list[i] + num_tokens * i for i in range(batch_dim)] ) - tokens_masked_idx = idx # flatten along first two dimension to simplify linear indexing (which then requires an # easily computable row offset) @@ -168,7 +164,7 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # recover batch dimension which was flattend for easier indexing and also token dimensions source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx, tokens_masked_idx_list) + return (source, token_info, target, tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : @@ -185,7 +181,6 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : # linear indices for masking num_tokens = source.shape[1] idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) - tokens_masked_idx = idx source_shape = source.shape # flatten along first two dimension to simplify linear indexing (which then requires an @@ -202,7 +197,7 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : # recover batch dimension which was flattend for easier indexing source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx, idxs) + return (source, token_info, target, idxs) #################################################################################################### def prepare_batch_BERT_temporal_field( cf, ifield, source, token_info, rng) : From 27e725b11abbd82de818850cc89288feee07ab3b Mon Sep 17 00:00:00 2001 From: iluise Date: Mon, 25 Mar 2024 11:43:21 +0100 Subject: [PATCH 10/66] modify log_attention --- atmorep/core/evaluate.py | 4 +- atmorep/core/trainer.py | 91 +++++++++++++++------------------------- 2 files changed, 36 insertions(+), 59 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index cc4ce62..12b7dc9 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -42,10 +42,10 @@ # e.g. 
global_forecast where a start date can be specified # BERT masked token model - mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} + #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} # BERT forecast mode - # mode, options = 'forecast', {'forecast_num_tokens' : 1} #, 'fields[0][2]' : [123, 137], 'attention' : False } + mode, options = 'forecast', {'forecast_num_tokens' : 1} #, 'fields[0][2]' : [123, 137], 'attention' : False } # BERT forecast with patching to obtain global forecast # mode, options = 'global_forecast', { 'fields[0][2]' : [123, 137], diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 53e20f2..eb5c505 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -431,8 +431,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): self.log_validate( epoch, it, log_sources, log_preds) if cf.attention: - self.log_attention( epoch, it, [atts, - [ti.detach().clone().cpu() for ti in token_infos]]) + self.log_attention( epoch, it, atts) # average over all nodes total_loss /= test_len * len(self.cf.fields_prediction) @@ -518,8 +517,7 @@ def evaluate( self, data_idx = 0, log = True): self.log_validate( data_idx, it, log_sources, preds) if cf.attention: - self.log_attention( data_idx , it, [atts, - [ti.detach().clone().cpu() for ti in token_infos]]) + self.log_attention( data_idx , it, atts) # average over all nodes loss /= test_len * len(self.cf.fields_prediction) @@ -734,27 +732,28 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : num_tokens = cf.fields[0][3] token_size = cf.fields[0][4] - lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) - lats, lons = [ ], [ ] - for tinfo in token_infos[0] : - lat_min, lat_max = tinfo[0][4], tinfo[ num_tokens[1]*num_tokens[2]-1 ][4] - lon_min, lon_max = tinfo[0][5], tinfo[ num_tokens[1]*num_tokens[2]-1 ][5] - res = tinfo[0][-1] - lat = torch.arange( lat_min - lat_d_h*res, lat_max + lat_d_h*res + 0.001, res) - if lon_max < lon_min : - lon = torch.arange( lon_min - lon_d_h*res, 360. + lon_max + lon_d_h*res + 0.001, res) - else : - lon = torch.arange( lon_min - lon_d_h*res, lon_max + lon_d_h*res + 0.001, res) - lats.append( lat.numpy()) - lons.append( torch.remainder( lon, 360.).numpy()) + # lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) + # lats, lons = [ ], [ ] + + # for tinfo in token_infos[0] : + # lat_min, lat_max = tinfo[0][4], tinfo[ num_tokens[1]*num_tokens[2]-1 ][4] + # lon_min, lon_max = tinfo[0][5], tinfo[ num_tokens[1]*num_tokens[2]-1 ][5] + # res = tinfo[0][-1] + # lat = torch.arange( lat_min - lat_d_h*res, lat_max + lat_d_h*res + 0.001, res) + # if lon_max < lon_min : + # lon = torch.arange( lon_min - lon_d_h*res, 360. 
+ lon_max + lon_d_h*res + 0.001, res) + # else : + # lon = torch.arange( lon_min - lon_d_h*res, lon_max + lon_d_h*res + 0.001, res) + # lats.append( lat.numpy()) + # lons.append( torch.remainder( lon, 360.).numpy()) # check that last token (bottom right corner) has the expected coords # assert np.allclose( ) # extract dates for each token entry, constant for each batch and field - dates_t = [] - for b_token_infos in token_infos[0] : - dates_t.append(utils.token_info_to_time(b_token_infos[0])-pd.Timedelta(hours=token_size[0]-1)) + # dates_t = [] + # for b_token_infos in token_infos[0] : + # dates_t.append(utils.token_info_to_time(b_token_infos[0])-pd.Timedelta(hours=token_size[0]-1)) # TODO: check that last token matches first one @@ -767,7 +766,10 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : target = detokenize( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, -1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) # TODO: check that geo-coords match to general ones that have been pre-determined - for bidx in range(token_infos[fidx].shape[0]) : + for bidx in range(batch_size): + dates = self.sources_info[bidx][0] + lats = 90. - self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] for vidx, _ in enumerate(field_info[2]) : denormalize = self.model.normalizer( fidx, vidx).denormalize date, coords = dates_t[bidx], [lats[bidx], lons[bidx]] @@ -926,49 +928,24 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : preds_out, ensembles_out, coords ) - def log_attention( self, epoch, bidx, log) : + def log_attention( self, epoch, bidx, attention) : '''Hook for logging: output attention maps.''' cf = self.cf - attention, token_infos = log - attn_dates_out, attn_lats_out, attn_lons_out = [ ], [ ], [ ] attn_out = [] for fidx, field_info in enumerate(cf.fields) : - # reconstruct coordinates - is_predicted = fidx in self.fields_prediction_idx - num_levels = len(field_info[2]) - num_tokens = field_info[3] - token_size = field_info[4] - lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) - tinfos = token_infos[fidx].reshape( [-1, num_levels, *num_tokens, cf.size_token_info]) + + # coordinates coords_b = [] - - for tinfo in tinfos : - # use first vertical levels since a column is considered - res = tinfo[0,0,0,0,-1] - lats = np.arange(tinfo[0,0,0,0,4]-lat_d_h*res, tinfo[0,0,-1,0,4]+lat_d_h*res+0.001,res*token_size[1]) - if tinfo[0,0,0,-1,5] < tinfo[0,0,0,0,5] : - lons = np.remainder( np.arange( tinfo[0,0,0,0,5] - lon_d_h*res, - 360. + tinfo[0,0,0,-1,5] + lon_d_h*res + 0.001, res*token_size[2]), 360.) - else : - lons = np.arange(tinfo[0,0,0,0,5]-lon_d_h*res, tinfo[0,0,0,-1,5]+lon_d_h*res+0.001,res*token_size[2]) - - lats = [90.-lat for lat in lats] - lons = np.remainder( lons, 360.) - - dates = np.array([(utils.token_info_to_time(tinfo[0,t,0,0,:3])) for t in range(tinfo.shape[1])], dtype='datetime64[s]') + for bidx in range(batch_size): + dates = self.sources_info[bidx][0] + lats = 90. 
- self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] coords_b += [ [dates, lats, lons] ] - if is_predicted: - attn_out.append([field_info[0], attention[fidx]]) - attn_dates_out.append([c[0] for c in coords_b]) - attn_lats_out.append( [c[1] for c in coords_b]) - attn_lons_out.append( [c[2] for c in coords_b]) - else: - attn_dates_out.append( [] ) - attn_lats_out.append( [] ) - attn_lons_out.append( [] ) - + is_predicted = fidx in self.fields_prediction_idx + attn_out.append([field_info[0], attention[fidx]] if is_predicted else [fn, []]) + levels = [[np.array(l) for l in field[2]] for field in cf.fields] write_attention(cf.wandb_id, epoch, - bidx, levels, attn_out, [attn_dates_out,attn_lats_out,attn_lons_out]) + bidx, levels, attn_out, coords_b ) \ No newline at end of file From 5886b3c64b65ca956d76f14bce5e5fae184ff8e3 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 25 Mar 2024 13:16:18 +0100 Subject: [PATCH 11/66] Fixed tokenize for 2D input (but not tested or used). --- atmorep/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index f600eaf..9fc4dd9 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -305,7 +305,7 @@ def tokenize( data, token_size = [-1,-1,-1]) : t2 = torch.reshape( data, (tok_tot_t, token_size[0], tok_tot_x, token_size[1], tok_tot_y, token_size[2])) data_tokenized = torch.transpose(torch.transpose( torch.transpose( t2, 4, 3), 2, 1), 3, 2) elif 2 == len(data_shape) : - t2 = torch.reshape( t1, (tok_tot_x, token_size[0], tok_tot_y, token_size[1])) + t2 = torch.reshape( data, (tok_tot_x, token_size[0], tok_tot_y, token_size[1])) data_tokenized = torch.transpose( t2, 1, 2) else : assert False From 1c65c7a4bc4d7ddd5c0c818703ed0e08a9fa87e2 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 25 Mar 2024 13:17:15 +0100 Subject: [PATCH 12/66] Cleaned up some details in log_validate_BERT(). 
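(Illustrative sketch of the splitting logic introduced below; all shapes and values are made up for the example.) Masked-token predictions arrive as one flat tensor ordered first by vertical level and then by batch entry, so they can be recovered with two nested torch.split calls driven by the lengths of the masked-index lists:

    import torch

    # hypothetical: one field with 2 vertical levels and 3 batch entries
    idx_list = [[torch.arange(2), torch.arange(1), torch.arange(3)],
                [torch.arange(2), torch.arange(2), torch.arange(1)]]
    lens_batches = [[len(t) for t in tt] for tt in idx_list]   # [[2, 1, 3], [2, 2, 1]]
    lens_levels = [sum(tt) for tt in lens_batches]             # [6, 5]

    data = torch.randn(sum(lens_levels), 3 * 9 * 9)            # flat masked tokens
    per_level = torch.split(data, lens_levels)                 # split by level
    per_batch = [torch.split(per_level[v], lens)               # then by batch entry
                 for v, lens in enumerate(lens_batches)]
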
--- atmorep/core/trainer.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 36026d8..81e8c75 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -809,27 +809,29 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : targets_out, [dates_targets, lats, lons], preds_out, ensembles_out ) ################################################### - #helpers for BERT - def split_data(self, data, idx, idx_list, token_size): - lens_levels = [t.shape[0] for t in idx] + + def split_data(self, data, idx_list, token_size) : + lens_batches = [[len(t) for t in tt] for tt in idx_list] + lens_levels = [torch.tensor( tt).sum() for tt in lens_batches] data_b = torch.split( data, lens_levels) # split according to batch - lens_batches = [ [bv.shape[0] for bv in b] for b in idx_list ] return [torch.split( data_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] - def get_masked_data(self, data, idx, idx_list, token_size, num_levels, ensemble = False): + def get_masked_data(self, data, idx_list, token_size, num_levels, ensemble = False) : + cf = self.cf batch_size = len(self.sources_info) - data_b = self.split_data(data, idx, idx_list, token_size) + data_b = self.split_data(data, idx_list, token_size) # recover token shape if ensemble: - return [[data_b[vidx][bidx].reshape([-1, cf.net_tail_num_nets, *token_size]) for bidx in range(batch_size)] + return [[data_b[vidx][bidx].reshape([-1, cf.net_tail_num_nets, *token_size]) + for bidx in range(batch_size)] for vidx in range(num_levels)] else: return [[data_b[vidx][bidx].reshape([-1, *token_size]) for bidx in range(batch_size)] - for vidx in range(num_levels)] + for vidx in range(num_levels)] ################################################### def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : @@ -840,7 +842,6 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : # save source: remains identical so just save ones (sources, targets, tokens_masked_idx_list) = log_sources - tokens_masked_idx = torch.cat( tokens_masked_idx_list) sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] coords = [] @@ -855,9 +856,12 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : sources_b = detokenize( sources[fidx].numpy()) if is_predicted : - targets_b = self.get_masked_data(targets[fidx], tokens_masked_idx[fidx], tokens_masked_idx_list[fidx], token_size, num_levels) - preds_mu_b = self.get_masked_data(log_preds[fidx][0], tokens_masked_idx[fidx], tokens_masked_idx_list[fidx], token_size, num_levels) - preds_ens_b = self.get_masked_data(log_preds[fidx][2], tokens_masked_idx[fidx], tokens_masked_idx_list[fidx], token_size, num_levels, ensemble = True) + targets_b = self.get_masked_data( targets[fidx], tokens_masked_idx_list[fidx], + token_size, num_levels) + preds_mu_b = self.get_masked_data( log_preds[fidx][0], tokens_masked_idx_list[fidx], + token_size, num_levels) + preds_ens_b = self.get_masked_data( log_preds[fidx][2], tokens_masked_idx_list[fidx], + token_size, num_levels, ensemble = True) # for all batch items coords_b = [] @@ -882,13 +886,14 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : if is_predicted : # TODO: make sure normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] - grid = np.array(np.meshgrid(self.sources_info[bidx][2], 
self.sources_info[bidx][1])) - grid = np.array(np.broadcast_to(grid, shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1) #add time dimension. only way to make tokenize work + + grid = np.array( np.meshgrid( self.sources_info[bidx][2], self.sources_info[bidx][1])) + # recover time dimension since idx assumes the full space-time cube + grid = torch.from_numpy( np.array( np.broadcast_to( grid, + shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) - lats_mskd = tokenize(torch.Tensor(grid[0]), token_size) #treat separate to avoid errors in tokenize - lons_mskd = tokenize(torch.Tensor(grid[1]), token_size) - lats_mskd = 90.- torch.flatten( lats_mskd, 0, 2)[idx] - lons_mskd = torch.flatten( lons_mskd, 0, 2)[idx] + lats_mskd = 90. - tokenize( grid[0], token_size).flatten( 0, 2)[ idx ] + lons_mskd = tokenize( grid[1], token_size).flatten( 0, 2)[ idx ] #time: idx ranges from 0->863 12x6x12 t_idx = (idx % (token_size[0]*num_tokens[0])) #* num_tokens[0] From 67e9a7eed8911087ca2427dbbfee68bcd4c3bad2 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 28 Mar 2024 16:18:46 +0100 Subject: [PATCH 13/66] fix loss --- atmorep/core/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 81e8c75..f99ff87 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -584,11 +584,13 @@ def loss( self, preds, batch_idx = 0) : losses['crps'].append( crps_loss) loss = torch.tensor( 0., device=self.device_out) + tot_weight = torch.tensor( 0., device=self.device_out) for key in losses : # print( 'LOSS : {} :: {}'.format( key, losses[key])) for ifield, val in enumerate(losses[key]) : loss += self.loss_weights[ifield] * val.to( self.device_out) - loss /= len(self.cf.fields_prediction) * len( self.cf.losses) + tot_weight += self.loss_weights[ifield] + loss /= tot_weight mse_loss = mse_loss_total / len(self.cf.fields_prediction) return loss, mse_loss, losses From 092478c531132ee0c342ba18fd368dd785e4a201 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Apr 2024 14:06:47 +0200 Subject: [PATCH 14/66] Re-enabled standard data loaders. --- atmorep/core/atmorep_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index a3e90ba..6cab75c 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -164,10 +164,12 @@ def normalizer( self, field, vl_idx) : def mode( self, mode : NetMode) : if mode == NetMode.train : - self.data_loader_iter = iter(self.dataset_train) #iter(self.data_loader_train) + self.data_loader_iter = iter(self.data_loader_train) + # self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : - self.data_loader_iter = iter(self.dataset_test) #iter(self.data_loader_test) + self.data_loader_iter = iter(self.data_loader_test) + # self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False From 5b6641e30d8993142bfcdfde5d7347a1fd18523d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Apr 2024 14:07:31 +0200 Subject: [PATCH 15/66] Removed stale code. 
--- atmorep/core/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index f99ff87..53569f5 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -246,6 +246,7 @@ def train( self, epoch): self.optimizer.zero_grad() for batch_idx in range( model.len( NetMode.train)) : + batch_data = self.model.next() with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=cf.with_mixed_precision): @@ -561,7 +562,6 @@ def loss( self, preds, batch_idx = 0) : loss_en = torch.tensor( 0., device=target.device) for en in torch.transpose( pred[2], 1, 0) : loss_en += self.MSELoss( en, target = target) - # losses['mse_ensemble'].append( 50. * loss_en / pred[2].shape[1]) losses['mse_ensemble'].append( loss_en / pred[2].shape[1]) # Generalized cross entroy loss for continuous distributions From 6a2c25ca60e9bdb61c1b54150a1794b35edd756f Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Apr 2024 14:10:26 +0200 Subject: [PATCH 16/66] Adapted forecast to only having tokens_mask_idx_list --- atmorep/training/bert.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 42f41c0..55813eb 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -180,7 +180,8 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : # linear indices for masking num_tokens = source.shape[1] - idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) + tokens_masked_idx_list = [idxs + num_tokens * i for i in range( source.shape[0] )] + idx = torch.cat( tokens_masked_idx_list) source_shape = source.shape # flatten along first two dimension to simplify linear indexing (which then requires an @@ -197,7 +198,7 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : # recover batch dimension which was flattend for easier indexing source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, idxs) + return (source, token_info, target, tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_temporal_field( cf, ifield, source, token_info, rng) : From b7f380af9af8b57a937649cc67602f74d481f7d5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Apr 2024 14:11:13 +0200 Subject: [PATCH 17/66] Various fixes and changing, in particular enabling again global forecast. 
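(Illustrative sketch of the tiling arithmetic behind the re-enabled global forecast; the resolution and token layout below are example values, not read from the configuration.) Each neighbourhood covers num_tokens * token_size grid points per spatial dimension, so its angular side length follows from the grid resolution, and neighbouring patches are shifted by side_len - overlap:

    import numpy as np

    res = np.array([0.25, 0.25])                     # lat/lon grid spacing in degrees
    num_tokens, token_size = [12, 6, 12], [3, 9, 9]  # [time, lat, lon]
    token_overlap = np.array([0, 0])

    side_len = np.array([num_tokens[1] * token_size[1] * res[0],
                         num_tokens[2] * token_size[2] * res[1]])   # [13.5, 27.0] degrees
    overlap = np.array([token_overlap[0] * token_size[1] * res[0],
                        token_overlap[1] * token_size[2] * res[1]])
    stride = side_len - overlap                      # shift between neighbouring patch centres
    # roughly the number of longitude tiles generated by the while-loop in set_global()
    num_tiles_lon = int(np.ceil(360. / stride[1]))
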
--- atmorep/datasets/multifield_data_sampler.py | 207 +++++++++++++------- 1 file changed, 136 insertions(+), 71 deletions(-) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 5cda67c..35266a3 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -138,7 +138,7 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe def shuffle( self) : rng = self.rng - self.idxs_perm_t = rng.permutation( self.idxs_years)[:(self.num_samples // self.batch_size)] + self.idxs_perm_t = rng.permutation( self.idxs_years)[ : self.num_samples] lats = rng.random(self.num_samples) * (self.range_lat[1] - self.range_lat[0]) +self.range_lat[0] lons = rng.random(self.num_samples) * (self.range_lon[1] - self.range_lon[0]) +self.range_lon[0] @@ -167,98 +167,163 @@ def __iter__(self): for bidx in range( iter_start, iter_end) : - idx = self.idxs_perm_t[bidx] - idxs_t = list(np.arange( idx-n_size[0]*ts, idx, ts, dtype=np.int64)) - data_t = [] - - for _, field_info in enumerate(self.fields) : - data_lvl = [] - for vl in field_info[2]: - if vl == 0: #surface level - field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) - data_lvl += [self.ds['data_sfc'].oindex[ idxs_t, field_idx]] - else: - field_idx = self.ds.attrs['fields'].index( field_info[0]) - vl_idx = self.ds.attrs['levels'].index(vl) - data_lvl += [self.ds['data'].oindex[ idxs_t, field_idx, vl_idx]] - data_t += [data_lvl] - - sources, sources_infos, source_idxs, token_infos = [], [], [], [] - lat_ran, lon_ran = [], [] + sources, token_infos = [[] for _ in self.fields], [[] for _ in self.fields] + sources_infos, source_idxs = [], [] for sidx in range(self.batch_size) : + i_bidx = self.idxs_perm_t[bidx] + idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) + idx = self.idxs_perm[bidx*self.batch_size+sidx] - # slight assymetry with offset by res/2 is required to match desired token count - lat_ran += [np.where(np.logical_and(lats > idx[0]-ns_2[1]-res[0]/2.,lats < idx[0]+ns_2[1]))[0]] + # slight asymetry with offset by res/2 is required to match desired token count + lat_ran = np.where(np.logical_and(lats>idx[0]-ns_2[1]-res[0]/2.,lats 360.) il, ir = (idx[1]-ns_2[2]-res[1]/2., idx[1]+ns_2[2]) if il < 0. : - lon_ran += [np.concatenate( [np.where( lons > il+360)[0], np.where(lons < ir)[0]], 0)] + lon_ran = np.concatenate( [np.where( lons > il+360)[0], np.where(lons < ir)[0]], 0) elif ir > 360. 
: - lon_ran += [np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0)] + lon_ran = np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0) else : - lon_ran += [np.where(np.logical_and( lons > il, lons < ir))[0]] + lon_ran = np.where(np.logical_and( lons > il, lons < ir))[0] sources_infos += [ [ self.ds['time'][ idxs_t ].astype(datetime), - self.lats[lat_ran][-1], self.lons[lon_ran][-1], self.res ] ] + self.lats[lat_ran], self.lons[lon_ran], self.res ] ] if self.with_source_idxs : - source_idxs += [ (idxs_t, lat_ran[-1], lon_ran[-1]) ] - - # extract data - # TODO: temporal window can span multiple months - year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month - for ifield, field_info in enumerate(self.fields): - - source_lvl, source_info_lvl, tok_info_lvl = [], [], [] - tok_size = field_info[4] - for ilevel, vl in enumerate(field_info[2]): - - nf = self.normalizers[ifield][ilevel].normalize - source_data, tok_info = [], [] - - for sidx in range(self.batch_size) : - #normalize and tokenize - source_data += [ tokenize( torch.from_numpy(nf( year, month, np.take( np.take( data_t[ifield][ilevel], - lat_ran[sidx], -2), lon_ran[sidx], -1), (lat_ran[sidx], lon_ran[sidx]))), tok_size ) ] + source_idxs += [ (idxs_t, lat_ran, lon_ran) ] + + # extract data + year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month + for ifield, field_info in enumerate(self.fields): + + source_lvl, source_info_lvl, tok_info_lvl = [], [], [] + tok_size = field_info[4] + for ilevel, vl in enumerate(field_info[2]): + + if vl == 0 : #surface level + field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) + data_t = self.ds['data_sfc'].oindex[ idxs_t, field_idx] + else : + field_idx = self.ds.attrs['fields'].index( field_info[0]) + vl_idx = self.ds.attrs['levels'].index(vl) + data_t = self.ds['data'].oindex[ idxs_t, field_idx, vl_idx] + + nf = self.normalizers[ifield][ilevel].normalize + source_data, tok_info = [], [] + + # extract data, normalize and tokenize + cdata = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) + cdata = nf( year, month, cdata, (lat_ran, lon_ran) ) + source_data = tokenize( torch.from_numpy( cdata), tok_size ) + # token_infos uses center of the token: *last* datetime and center in space dates = self.ds['time'][ idxs_t ].astype(datetime) - #store only center of the token: - #in time we store the *last* datetime in the token, not the center - dates = [(d.year, d.timetuple().tm_yday, d.hour) for d in dates][tok_size[0]-1::tok_size[0]] - lats_sidx = self.lats[lat_ran[sidx]][int(tok_size[1]/2)::tok_size[1]] - lons_sidx = self.lons[lon_ran[sidx]][int(tok_size[2]/2)::tok_size[2]] - # info_data += [[[[[ year, day, hour, vl, - # lat, lon, vl, self.res[0]] for lon in lons] for lat in lats] for (year, day, hour) in dates]] #zip(years, days, hours)]] - - tok_info += [[[[[ year, day, hour, vl, - lat, lon, vl, self.res[0]] for lon in lons_sidx] for lat in lats_sidx] for (year, day, hour) in dates]] - - #level - source_lvl += [torch.stack(source_data, dim = 0)] - # source_info_lvl += [info_data] - tok_info_lvl += [tok_info] - - #field - sources += [torch.stack(source_lvl, dim = 0)] #torch.Size([3, 16, 12, 6, 12, 3, 9, 9]) - # sources_infos += [torch.Tensor(np.array(source_info_lvl))] # torch.Size([3, 16, 36, 54, 108, 8]) - #token_infos += [torch.Tensor(np.array(tok_info_lvl))] # torch.Size([3, 16, 12, 6, 12, 8]) - # extract batch info. level info stored in cf.fields. not stored here. 
- - token_infos += [torch.Tensor(np.array(tok_info_lvl)).reshape(len(tok_info_lvl), len(tok_info_lvl[0]), -1, 8)] #torch.Size([3, 16, 864, 8]) + cdates = dates[tok_size[0]-1::tok_size[0]] + dates = [(d.year, d.timetuple().tm_yday, d.hour) for d in cdates] + lats_sidx = self.lats[lat_ran][ tok_size[1]//2 :: tok_size[1] ] + lons_sidx = self.lons[lon_ran][ tok_size[2]//2 :: tok_size[2] ] + # tensor product for token_infos + tok_info += [[[[[ year, day, hour, vl, lat, lon, vl, self.res[0]] for lon in lons_sidx] + for lat in lats_sidx] + for (year, day, hour) in dates]] + + source_lvl += [ source_data ] + tok_info_lvl += [ torch.tensor(tok_info, dtype=torch.float32).flatten( 1, -2)] + + sources[ifield] += [ torch.stack(source_lvl, 0) ] + token_infos[ifield] += [ torch.stack(tok_info_lvl, 0) ] - sources = self.pre_batch(sources, - token_infos ) + # concatenate batches + sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources] + token_infos = [torch.stack(tis_field).transpose(1,0) for tis_field in token_infos] + sources = self.pre_batch( sources, token_infos ) - # TODO: implement targets + # TODO: implement (only required when prediction target comes from different data stream) targets, target_info = None, None target_idxs = None - #this already goes back to trainer.py. - #source_info needed to remove log_validate in trainer.py + yield ( sources, targets, (source_idxs, sources_infos), (target_idxs, target_info)) + ################################################### + def set_data( self, times_pos, batch_size = None) : + ''' + times_pos = np.array( [ [year, month, day, hour, lat, lon], ...] ) + - lat \in [90,-90] = [90N, 90S] + - lon \in [0,360] + - (year,month) pairs should be a limited number since all data for these is loaded + ''' + + # generate all the data + self.idxs_perm = np.zeros( (len(times_pos), 2)) + self.idxs_perm_t = [] + for idx, item in enumerate( times_pos) : + + assert item[2] >= 1 and item[2] <= 31 + assert item[3] >= 0 and item[3] < int(24 / self.time_sampling) + assert item[4] >= -90. and item[4] <= 90. + + tstamp = pd.to_datetime( f'{item[0]}-{item[1]}-{item[2]}-{item[3]}', format='%Y-%m-%d-%H') + self.idxs_perm_t += [ np.where( self.times == tstamp) ] + + # work with mathematical lat coordinates from here on + self.idxs_perm[idx] = np.array( [90. - item[4], item[5]]) + + ################################################### + def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : + ''' generate patch/token positions for global grid ''' + + token_overlap = torch.tensor( token_overlap).to(torch.int64) + + # assumed that sanity checking that field data is consistent has been done + ifield = 0 + field = self.fields[ifield] + + res = self.res + side_len = torch.tensor( [field[3][1] * field[4][1]*res[0], field[3][2] * field[4][2]*res[1]] ) + overlap =torch.tensor([token_overlap[0]*field[4][1]*res[0],token_overlap[1]*field[4][2]*res[1]]) + side_len_2 = side_len / 2. + assert all( overlap <= side_len_2), 'token_overlap too large for #tokens, reduce if possible' + + # generate tiles + times_pos = [] + for ctime in times : + + lat = side_len_2[0].item() + num_tiles_lat = 0 + while (lat + side_len_2[0].item()) < 180. : + num_tiles_lat += 1 + lon = side_len_2[1].item() - overlap[1].item()/2. + num_tiles_lon = 0 + while (lon - side_len_2[1]) < 360. : + times_pos += [[*ctime, -lat + 90., np.mod(lon,360.) 
]] + lon += side_len[1].item() - overlap[1].item() + num_tiles_lon += 1 + lat += side_len[0].item() - overlap[0].item() + + # add one additional row if no perfect tiling (sphere is toric in longitude so no special + # handling necessary but not in latitude) + # the added row is such that it goes exaclty down to the South pole and the offset North-wards + # is computed based on this + lat -= side_len[0] - overlap[0] + if lat - side_len_2[0] < 180. : + num_tiles_lat += 1 + lat = 180. - side_len_2[0].item() + res[0] + lon = side_len_2[1].item() - overlap[1].item()/2. + while (lon - side_len_2[1]) < 360. : + times_pos += [[*ctime, -lat + 90., np.mod(lon,360.) ]] + lon += side_len[1].item() - overlap[1].item() + + # adjust batch size if necessary so that the evaluations split up across batches of equal size + batch_size = num_tiles_lon + + print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) + + print( f'\n\n{times_pos[-1]}\n\n', flush=True) + + self.set_data( times_pos, batch_size) + ################################################### def __len__(self): return self.num_samples // self.batch_size From 4c1f8ca16a4fb06fde8e890a9d201c3cc38fa73e Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 8 Apr 2024 14:46:19 +0200 Subject: [PATCH 18/66] Changed set_global to numpy arrays for consistency. --- atmorep/datasets/multifield_data_sampler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 68b5c28..b2faceb 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -130,6 +130,7 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe ################################################### def shuffle( self) : + rng = self.rng self.idxs_perm_t = rng.permutation( self.idxs_years)[ : self.num_samples] @@ -149,7 +150,7 @@ def __iter__(self): # TODO: if we keep this then we should remove the rng_seed argument for the constuctor #self.rng = np.random.default_rng() #TODO: move shuffle outside iter to avoid param overwriting in global_forecast!!! NB. BERT does not work without shuffle!! - #self.shuffle() + self.shuffle() lats, lons = self.lats, self.lons #fields_idxs, levels_idxs = self.fields_idxs, self.levels_idxs @@ -269,15 +270,16 @@ def set_data( self, times_pos, batch_size = None) : ################################################### def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : ''' generate patch/token positions for global grid ''' - token_overlap = torch.tensor( token_overlap).to(torch.int64) + + token_overlap = np.array( token_overlap).astype(np.int64) # assumed that sanity checking that field data is consistent has been done ifield = 0 field = self.fields[ifield] res = self.res - side_len = torch.tensor( [field[3][1] * field[4][1]*res[0], field[3][2] * field[4][2]*res[1]] ) - overlap =torch.tensor([token_overlap[0]*field[4][1]*res[0],token_overlap[1]*field[4][2]*res[1]]) + side_len = np.array( [field[3][1] * field[4][1]*res[0], field[3][2] * field[4][2]*res[1]] ) + overlap = np.array([token_overlap[0]*field[4][1]*res[0],token_overlap[1]*field[4][2]*res[1]]) side_len_2 = side_len / 2. 
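To make the tiling arithmetic in set_global() concrete, here is a small standalone sketch of how the globe is covered with overlapping patches. The config numbers (token counts, token sizes, resolution) are made up for illustration and the variable names are not the repository's.

import numpy as np

# Illustrative values only: 6 x 12 tokens of 9 x 9 grid points at 0.25 deg resolution and
# no token overlap, i.e. patches of 13.5 x 27.0 degrees.
num_tokens, token_size, res = (6, 12), (9, 9), (0.25, 0.25)
overlap_deg = np.array([0., 0.])

side_len = np.array([num_tokens[0] * token_size[0] * res[0],
                     num_tokens[1] * token_size[1] * res[1]])
step = side_len - overlap_deg

# rows of patch centres marching from the North pole towards the South pole
lat, num_rows = side_len[0] / 2., 0
while lat + side_len[0] / 2. < 180.:
    num_rows += 1
    lat += step[0]
# one extra pole-flush row when the last row does not reach the South pole exactly
if (lat - step[0]) + side_len[0] / 2. < 180.:
    num_rows += 1

# longitude is periodic, so the columns simply wrap around
num_cols = int(np.ceil(360. / step[1]))

print(num_rows, num_cols)  # 14 rows x 14 columns, i.e. 14 batches of 14 patches per time step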
assert all( overlap <= side_len_2), 'token_overlap too large for #tokens, reduce if possible' @@ -315,8 +317,6 @@ def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) - print( f'\n\n{times_pos[-1]}\n\n', flush=True) - self.set_data( times_pos, batch_size) ################################################### From 1b9e20dce8bd066b45fd4eb47ea204b2ad293619 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 8 Apr 2024 14:46:56 +0200 Subject: [PATCH 19/66] Fixed bug in returned idx_masked for BERT. --- atmorep/training/bert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 55813eb..0f47223 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -107,7 +107,8 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : tokens_masked_idx_list = [ torch.tensor(rng.permutation(num_tokens)[:nms]) for nms in nums_masked] # linear indices for masking - idx = torch.cat( [tokens_masked_idx_list[i] + num_tokens * i for i in range(batch_dim)] ) + tokens_masked_idx_list = [tokens_masked_idx_list[i] + num_tokens * i for i in range(batch_dim)] + idx = torch.cat( tokens_masked_idx_list) # flatten along first two dimension to simplify linear indexing (which then requires an # easily computable row offset) From 7c6413f2abf4e3a6a9189cfa860f81d5fe9d03dc Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 8 Apr 2024 14:47:17 +0200 Subject: [PATCH 20/66] - Fixed issues in log_BERT due to bug fixing fo masked_idx. - Minor clean up here and there. --- atmorep/core/trainer.py | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 897ca48..a567501 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -18,7 +18,6 @@ import torchinfo import numpy as np import code -# code.interact(local=locals()) from pathlib import Path import os @@ -640,10 +639,10 @@ def prepare_batch( self, xin) : self.targets.append( targets[ifield].to( devs[cf.fields[ifield][1][3]], non_blocking=True )) # idxs of masked tokens - tmi_out = [[] for _ in range(len(fields_tokens_masked_idx_list))] + tmi_out = [ ] for i,tmi in enumerate(fields_tokens_masked_idx_list) : cdev = devs[cf.fields[i][1][3]] - tmi_out[i] = [torch.cat(tmi_l,0).to( cdev, non_blocking=True) for tmi_l in tmi] + tmi_out += [ [torch.cat(tmi_l).to( cdev, non_blocking=True) for tmi_l in tmi] ] self.tokens_masked_idx = tmi_out # learnable class token (cannot be done in the data loader since this is running in parallel) @@ -851,8 +850,8 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : coords_b = [] for bidx in range(batch_size): dates = self.sources_info[bidx][0] - lats = self.sources_info[bidx][1].numpy() - lons = self.sources_info[bidx][2].numpy() + lats = self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] # target etc are aliasing targets_b which simplifies bookkeeping below if is_predicted : @@ -865,36 +864,33 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : normalizer = self.model.normalizer( fidx, vidx) y, m = dates[0].year, dates[0].month - #breakpoint() - # print("-----lats source---") - # print(lats) - # print("-----lons source---") - # print(lons) sources_b[bidx,vidx] = normalizer.denormalize( y, m, sources_b[bidx,vidx], [lats, lons]) if is_predicted : + # TODO: make sure 
normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] grid = np.flip(np.array( np.meshgrid( self.sources_info[bidx][2], self.sources_info[bidx][1])), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + # recover time dimension since idx assumes the full space-time cube grid = torch.from_numpy( np.array( np.broadcast_to( grid, shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) - #breakpoint() + grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) + grid_lons_toked = tokenize( grid[0], token_size).flatten( 0, 2) + + idx_loc = idx - np.prod(num_tokens) * bidx #save only useful info for each bidx. shape e.g. [n_bidx, lat_token_size*lat_num_tokens] - lats_mskd = np.array([np.unique(t) for t in tokenize( grid[0], token_size).flatten( 0, 2)[ idx ].numpy()]) - lons_mskd = np.array([np.unique(t) for t in tokenize( grid[1], token_size).flatten( 0, 2)[ idx ].numpy()]) + lats_mskd = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + lons_mskd = np.array([np.unique(t) for t in grid_lons_toked[ idx_loc ].numpy()]) #time: idx ranges from 0->863 12x6x12 - t_idx = (np.floor(idx / (num_tokens[1]*num_tokens[2])) * token_size[0]).int() - t_idx = np.array([np.arange(t, t + token_size[0]) for t in t_idx]) #create range from t_idx-2 to t_idx + t_idx = (idx_loc // (num_tokens[1]*num_tokens[2])) * token_size[0] + #create range from t_idx-2 to t_idx + t_idx = np.array([np.arange(t, t + token_size[0]) for t in t_idx]) dates_mskd = self.sources_info[bidx][0][t_idx] for ii,(t,p,e,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], lats_mskd, lons_mskd)) : - # print("----la ----") - # print(la) - # print("----lo ----") - # print(lo) targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) preds_mu_b[vidx][bidx][ii] = normalizer.denormalize( y, m, p, [la, lo]) preds_ens_b[vidx][bidx][ii] = normalizer.denormalize( y, m, e, [la, lo]) From e0f72d1576a2622c6ca83078080eaf81b5959552 Mon Sep 17 00:00:00 2001 From: iluise Date: Tue, 9 Apr 2024 12:26:37 +0200 Subject: [PATCH 21/66] validated version against main --- atmorep/core/atmorep_model.py | 1 + atmorep/core/evaluate.py | 24 ++++++++++---------- atmorep/core/evaluator.py | 25 +++++++++++++-------- atmorep/core/trainer.py | 12 +++++----- atmorep/datasets/multifield_data_sampler.py | 4 ++-- atmorep/datasets/normalizer_global.py | 2 ++ atmorep/datasets/normalizer_local.py | 1 + 7 files changed, 40 insertions(+), 29 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index f8c56ab..d301607 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -450,6 +450,7 @@ def forward( self, xin) : # embedding cf = self.cf + fields_embed = self.get_fields_embed(xin) # attention maps (if requested) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 76ba9ee..2a25d7d 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -19,14 +19,14 @@ if __name__ == '__main__': # models for individual fields - # model_id = '4nvwbetz' # vorticity - # model_id = 'oxpycr7w' # divergence + #model_id = '4nvwbetz' # vorticity + #model_id = 'oxpycr7w' # divergence # model_id = '1565pb1f' # specific_humidity # model_id = '3kdutwqb' # total precip - # model_id = 'dys79lgw' # velocity_u - # model_id = '22j6gysw' # velocity_v + #model_id = 'dys79lgw' # velocity_u + model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z - model_id = '3qou60es' # temperature (also 
2147fkco) + #model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence @@ -43,17 +43,17 @@ # BERT masked token model #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} - #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} + mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} # BERT forecast mode - #mode, options = 'forecast', {'forecast_num_tokens' : 1} #, 'fields[0][2]' : [123, 137], 'attention' : False } + #mode, options = 'forecast', {'forecast_num_tokens' : 1, 'fields[0][2]' : [123], 'attention' : False } # BERT forecast with patching to obtain global forecast - mode, options = 'global_forecast', { 'fields[0][2]' : [123], - 'dates' : [[2021, 2, 10, 12]], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 1, - 'attention' : False } + # mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], + # 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], + # 'token_overlap' : [0, 0], + # 'forecast_num_tokens' : 1, + # 'attention' : False } now = time.time() Evaluator.evaluate( mode, model_id, options) print("time", time.time() - now) \ No newline at end of file diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 08aac24..dc62403 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -91,17 +91,17 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : cf.num_loader_workers = cf.loader_num_workers cf.data_dir = './data/' - cf.rng_seed = 0 #None + cf.rng_seed = None + #backward compatibility if not hasattr( cf, 'n_size'): - cf.n_size = [36, 0.25*9*6, 0.25*9*12] + #cf.n_size = [36, 0.25*9*6, 0.25*9*12] + cf.n_size = [36, 0.25*27*2, 0.25*27*4] if not hasattr(cf, 'num_samples_per_epoch'): cf.num_samples_per_epoch = 1024 - if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 196 #128 if not hasattr(cf, 'with_mixed_precision'): cf.with_mixed_precision = False - cf.batch_size_start = 14 + # cf.batch_size_start = 14 func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) @@ -112,7 +112,8 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : cf.lat_sampling_weighted = False cf.BERT_strategy = 'BERT' cf.log_test_num_ranks = 4 - + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) @@ -125,7 +126,8 @@ def forecast( cf, model_id, model_epoch, devices, args = {}) : cf.BERT_strategy = 'forecast' cf.log_test_num_ranks = 4 cf.forecast_num_tokens = 1 # will be overwritten when user specified - + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) @@ -138,7 +140,9 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : cf.batch_size_test = 24 cf.num_loader_workers = 1 cf.log_test_num_ranks = 1 - + cf.batch_size_start = 14 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 196 Evaluator.parse_args( cf, args) dates = args['dates'] @@ -161,7 +165,10 @@ def global_forecast_range( cf, model_id, model_epoch, devices, args = {}) : cf.batch_size_test = 24 cf.num_loader_workers = 1 cf.log_test_num_ranks = 1 - + 
cf.batch_size_start = 14 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 196 + Evaluator.parse_args( cf, args) if 0 == cf.par_rank : diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index a567501..7c268da 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -245,9 +245,9 @@ def train( self, epoch): self.optimizer.zero_grad() for batch_idx in range( model.len( NetMode.train)) : - + batch_data = self.model.next() - + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=cf.with_mixed_precision): batch_data = self.prepare_batch( batch_data) preds, _ = self.model_ddp( batch_data) @@ -386,7 +386,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory (sources, token_infos, targets, tmis_list) = batch_data[0] - + # breakpoint() # targets #TO-DO: implement target # if len(batch_data[1]) > 0 : @@ -617,12 +617,12 @@ def prepare_batch( self, xin) : cf = self.cf devs = self.devices - + # unpack loader output # xin[0] since BERT does not have targets (sources, token_infos, targets, fields_tokens_masked_idx_list) = xin[0] (self.sources_idxs, self.sources_info) = xin[2] - + # network input batch_data = [ ( sources[i].to( devs[ cf.fields[i][1][3] ], non_blocking=True), self.tok_infos_trans(token_infos[i]).to( self.devices[0], non_blocking=True)) @@ -652,7 +652,7 @@ def prepare_batch( self, xin) : assert len(cf.fields[ifield][2]) == 1 tmidx = self.tokens_masked_idx[ifield][0] source[ tmidx ] = self.model.net.masks[ifield].to( source.device) - +# breakpoint() return batch_data ################################################### diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index b2faceb..08186e1 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -209,7 +209,7 @@ def __iter__(self): # extract data, normalize and tokenize cdata = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) - # breakpoint() + #breakpoint() cdata = nf( year, month, cdata, (lats[lat_ran], lons[lon_ran]) ) source_data = tokenize( torch.from_numpy( cdata), tok_size ) @@ -238,7 +238,7 @@ def __iter__(self): # TODO: implement (only required when prediction target comes from different data stream) targets, target_info = None, None target_idxs = None - + #breakpoint() yield ( sources, targets, (source_idxs, sources_infos), (target_idxs, target_info)) ################################################### diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py index 6afe576..4336b6b 100644 --- a/atmorep/datasets/normalizer_global.py +++ b/atmorep/datasets/normalizer_global.py @@ -18,6 +18,7 @@ import atmorep.config.config as config import pdb + class NormalizerGlobal() : def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : @@ -34,6 +35,7 @@ def normalize( self, year, month, data, coords = None) : #breakpoint() corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), self.corr_data[:,1] == float(month))) , 2:].flatten() + #breakpoint() data_temp = (data - corr_data_ym[0]) / corr_data_ym[1] # print(data_temp.mean(), data_temp.std()) return (data - corr_data_ym[0]) / corr_data_ym[1] diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py index 3669e7c..63cc62d 100644 --- a/atmorep/datasets/normalizer_local.py +++ 
b/atmorep/datasets/normalizer_local.py @@ -61,6 +61,7 @@ def normalize( self, year, month, data, coords) : print( f'var == 0 :: ym : {year} / {month}') assert False # print("before", data.mean(), data.std()) + #breakpoint() if len(data.shape) > 2 : for i in range( data.shape[0]) : data[i] = (data[i] - mean) / var From 6ba652172f20ac76a0e2bd5da232c150b7dd004e Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 9 Apr 2024 13:38:04 +0200 Subject: [PATCH 22/66] - Fixed handling of shuffle() - Cleaned up code in various places (e.g. removed load_data from atmorep_model which was no longer needed). --- atmorep/core/atmorep_model.py | 58 +++++++-------------- atmorep/core/evaluator.py | 1 - atmorep/datasets/data_writer.py | 4 +- atmorep/datasets/multifield_data_sampler.py | 41 +++++---------- 4 files changed, 34 insertions(+), 70 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index d301607..a835d55 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -54,35 +54,6 @@ def __init__( self, net) : self.rng_seed = net.cf.rng_seed if not self.rng_seed : self.rng_seed = int(torch.randint( 100000000, (1,))) - - ################################################### - def load_data( self, mode : NetMode, batch_size = -1, num_loader_workers = -1) : - '''Load data''' - - cf = self.net.cf - - if batch_size < 0 : - batch_size = cf.batch_size_max - if num_loader_workers < 0 : - num_loader_workers = cf.num_loader_workers - - if mode == NetMode.train : - self.data_loader_train = self._load_data( self.dataset_train, batch_size, num_loader_workers) - elif mode == NetMode.test : - batch_size = cf.batch_size_test - self.data_loader_test = self._load_data( self.dataset_test, batch_size, num_loader_workers) - else : - assert False - - ################################################### - def _load_data( self, dataset, batch_size, num_loader_workers) : - '''Private implementation for load''' - - loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False, - 'num_workers': num_loader_workers, 'pin_memory': True} - data_loader = torch.utils.data.DataLoader( dataset, **loader_params, sampler = None) - - return data_loader ################################################### def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_workers = -1) : @@ -92,9 +63,8 @@ def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_worke batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test dataset = self.dataset_train if mode == NetMode.train else self.dataset_test - print("ueh") dataset.set_data( times_pos, batch_size) - print("probably I should not be here..") + self._set_data( dataset, mode, batch_size, num_loader_workers) ################################################### @@ -154,7 +124,6 @@ def normalizer( self, field, vl_idx) : elif isinstance( field, int) : normalizer = self.dataset_train.normalizers[field][vl_idx] -# normalizer = self.dataset_train.datasets[field][vl_idx].normalizer else : assert False, 'invalid argument type (has to be index to cf.fields or field name)' @@ -165,12 +134,12 @@ def normalizer( self, field, vl_idx) : def mode( self, mode : NetMode) : if mode == NetMode.train : - # self.data_loader_iter = iter(self.data_loader_train) - self.data_loader_iter = iter(self.dataset_train) + self.data_loader_iter = iter(self.data_loader_train) + # self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : - # self.data_loader_iter = 
iter(self.data_loader_test) - self.data_loader_iter = iter(self.dataset_test) + self.data_loader_iter = iter(self.data_loader_test) + # self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -196,8 +165,8 @@ def forward( self, xin) : return pred ################################################### - def get_attention( self, xin): #, field_idx) : - attn = self.net.get_attention( xin) #, field_idx) + def get_attention( self, xin) : + attn = self.net.get_attention( xin) return attn ################################################### @@ -211,14 +180,23 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.pre_batch_targets = pre_batch_targets cf = self.net.cf + loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False, + 'num_workers': cf.num_loader_workers, 'pin_memory': True} + self.dataset_train = MultifieldDataSampler( cf.fields, cf.years_train, cf.batch_size_start, - pre_batch, cf.n_size, cf.num_samples_per_epoch ) - + pre_batch, cf.n_size, cf.num_samples_per_epoch, + with_shuffle = cf.BERT_strategy != 'global_forecast' ) + self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, + sampler = None) + self.dataset_test = MultifieldDataSampler( cf.fields, cf.years_test, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_validate, + with_shuffle = cf.BERT_strategy != 'global_forecast', with_source_idxs = True ) + self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params, + sampler = None) return self diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index dc62403..f37e444 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -57,7 +57,6 @@ def run( cf, model_id, model_epoch, devices) : # set/over-write options as desired evaluator = Evaluator.load( cf, model_id, model_epoch, devices) - evaluator.model.load_data( NetMode.test) if 0 == cf.par_rank : cf.print() cf.write_json( wandb) diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index 5e3365d..cab73ce 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -26,8 +26,8 @@ def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): ds_batch_item.create_dataset( 'data', data=data) ds_batch_item.create_dataset( 'ml', data=levels) ds_batch_item.create_dataset( 'datetime', data=coords[0].astype(np.datetime64)) - ds_batch_item.create_dataset( 'lat', data=coords[1].astype(np.float32)) - ds_batch_item.create_dataset( 'lon', data=coords[2].astype(np.float32)) + ds_batch_item.create_dataset( 'lat', data=np.array(coords[1]).astype(np.float32)) + ds_batch_item.create_dataset( 'lon', data=np.array(coords[2]).astype(np.float32)) return ds_batch_item #################################################################################################### diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 08186e1..6a87809 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -16,19 +16,10 @@ import torch import numpy as np -import math -import itertools -import code -# code.interact(local=locals()) import zarr import pandas as pd -import pdb -import code -from atmorep.utils.utils import days_until_month_in_year -from atmorep.utils.utils import days_in_month from datetime import datetime - -import atmorep.config.config as config +import time from atmorep.datasets.normalizer_global import 
NormalizerGlobal from atmorep.datasets.normalizer_local import NormalizerLocal @@ -38,8 +29,8 @@ class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### - def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_per_epoch, - rng_seed = None, time_sampling = 1, with_source_idxs = False, + def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_per_epoch, + with_shuffle = False, time_sampling = 1, with_source_idxs = False, fields_targets = None, pre_batch_targets = None ) : ''' Data set for single dynamic field at an arbitrary number of vertical levels @@ -53,6 +44,7 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.n_size = n_size self.num_samples = num_samples_per_epoch self.with_source_idxs = with_source_idxs + self.with_shuffle = with_shuffle self.pre_batch = pre_batch @@ -91,7 +83,7 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.range_lon = np.array( self.lons[ [0,-1] ]) self.res = np.zeros( 2) - self.res[0] = self.ds.attrs['resol'][0] + self.res[0] = [0] self.res[1] = self.ds.attrs['resol'][1] # ensure neighborhood does not exceed domain (either at pole or for finite domains) @@ -100,13 +92,6 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe if self.ds_global < 1.: self.range_lon += np.array([n_size[2]/2., -n_size[2]/2.]) - # ensure all data loaders use same rng_seed and hence generate consistent data - if not rng_seed : - rng_seed = np.random.randint( 0, 100000, 1)[0] - self.rng = np.random.default_rng( rng_seed) - else: - self.rng = rng_seed - # data normalizers self.normalizers = [] for _, field_info in enumerate(fields) : @@ -131,7 +116,12 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe ################################################### def shuffle( self) : - rng = self.rng + worker_info = torch.utils.data.get_worker_info() + rng_seed = None + if worker_info is not None : + rng_seed = int(time.time()) // (worker_info.id+1) + worker_info.id + + rng = np.random.default_rng( rng_seed) self.idxs_perm_t = rng.permutation( self.idxs_years)[ : self.num_samples] lats = rng.random(self.num_samples) * (self.range_lat[1] - self.range_lat[0]) +self.range_lat[0] @@ -147,13 +137,10 @@ def shuffle( self) : ################################################### def __iter__(self): - # TODO: if we keep this then we should remove the rng_seed argument for the constuctor - #self.rng = np.random.default_rng() - #TODO: move shuffle outside iter to avoid param overwriting in global_forecast!!! NB. BERT does not work without shuffle!! - self.shuffle() + if self.with_shuffle : + self.shuffle() lats, lons = self.lats, self.lons - #fields_idxs, levels_idxs = self.fields_idxs, self.levels_idxs ts, n_size = self.time_sampling, self.n_size ns_2 = np.array(self.n_size) / 2. 
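The reworked shuffle() above derives its seed from the wall clock and the DataLoader worker id. For comparison, a common alternative is to reuse the per-worker seed that PyTorch already assigns to each worker process; a minimal sketch, illustrative only and not what the patch implements:

import numpy as np
import torch

def make_worker_rng():
    # torch.initial_seed() differs per DataLoader worker (base_seed + worker_id),
    # so it can seed an independent numpy generator without touching time.time().
    if torch.utils.data.get_worker_info() is None:
        return np.random.default_rng()                      # main process, e.g. num_workers = 0
    return np.random.default_rng(torch.initial_seed() % 2**32)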
res = self.res @@ -192,7 +179,7 @@ def __iter__(self): year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month for ifield, field_info in enumerate(self.fields): - source_lvl, source_info_lvl, tok_info_lvl = [], [], [] + source_lvl, tok_info_lvl = [], [] tok_size = field_info[4] for ilevel, vl in enumerate(field_info[2]): From 296595a6733cc1e89d8010aaf55bbde04e9e3e65 Mon Sep 17 00:00:00 2001 From: iluise Date: Tue, 9 Apr 2024 17:56:15 +0200 Subject: [PATCH 23/66] fix bert_strategy --- atmorep/config/config.py | 4 +- atmorep/core/atmorep_model.py | 8 +- atmorep/core/evaluate.py | 16 ++-- atmorep/core/evaluator.py | 6 +- atmorep/core/trainer.py | 24 ++---- atmorep/datasets/data_writer.py | 85 +++++---------------- atmorep/datasets/multifield_data_sampler.py | 27 +------ atmorep/datasets/normalizer_global.py | 11 +-- atmorep/datasets/normalizer_local.py | 14 +--- atmorep/training/bert.py | 2 + 10 files changed, 53 insertions(+), 144 deletions(-) diff --git a/atmorep/config/config.py b/atmorep/config/config.py index 03c03f7..087dc33 100644 --- a/atmorep/config/config.py +++ b/atmorep/config/config.py @@ -3,8 +3,8 @@ fpath = os.path.dirname(os.path.realpath(__file__)) -year_base = 1979 -year_last = 2022 +# year_base = 1979 +# year_last = 2022 path_models = Path( fpath, '../../models/') path_results = Path( fpath, '../../results') diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index a835d55..97f03de 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -135,11 +135,11 @@ def mode( self, mode : NetMode) : if mode == NetMode.train : self.data_loader_iter = iter(self.data_loader_train) - # self.data_loader_iter = iter(self.dataset_train) + #self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : self.data_loader_iter = iter(self.data_loader_test) - # self.data_loader_iter = iter(self.dataset_test) + #self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -186,14 +186,14 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.dataset_train = MultifieldDataSampler( cf.fields, cf.years_train, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_per_epoch, - with_shuffle = cf.BERT_strategy != 'global_forecast' ) + with_shuffle = (cf.BERT_strategy != 'global_forecast') ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) self.dataset_test = MultifieldDataSampler( cf.fields, cf.years_test, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_validate, - with_shuffle = cf.BERT_strategy != 'global_forecast', + with_shuffle = (cf.BERT_strategy != 'global_forecast'), with_source_idxs = True ) self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params, sampler = None) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 2a25d7d..7738d24 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -24,9 +24,9 @@ # model_id = '1565pb1f' # specific_humidity # model_id = '3kdutwqb' # total precip #model_id = 'dys79lgw' # velocity_u - model_id = '22j6gysw' # velocity_v + #model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z - #model_id = '3qou60es' # temperature (also 2147fkco) + model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence @@ -43,17 +43,17 @@ # BERT masked token 
model #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} - mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} + #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} # BERT forecast mode #mode, options = 'forecast', {'forecast_num_tokens' : 1, 'fields[0][2]' : [123], 'attention' : False } # BERT forecast with patching to obtain global forecast - # mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], - # 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], - # 'token_overlap' : [0, 0], - # 'forecast_num_tokens' : 1, - # 'attention' : False } + mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], + 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], + 'token_overlap' : [0, 0], + 'forecast_num_tokens' : 1, + 'attention' : False } now = time.time() Evaluator.evaluate( mode, model_id, options) print("time", time.time() - now) \ No newline at end of file diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index f37e444..537f079 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -112,7 +112,7 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : cf.BERT_strategy = 'BERT' cf.log_test_num_ranks = 4 if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 128 + cf.num_samples_validate = 1472 #128 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) @@ -135,7 +135,7 @@ def forecast( cf, model_id, model_epoch, devices, args = {}) : @staticmethod def global_forecast( cf, model_id, model_epoch, devices, args = {}) : - cf.BERT_strategy = 'forecast' + cf.BERT_strategy = 'global_forecast' cf.batch_size_test = 24 cf.num_loader_workers = 1 cf.log_test_num_ranks = 1 @@ -158,7 +158,7 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : def global_forecast_range( cf, model_id, model_epoch, devices, args = {}) : cf.forecast_num_tokens = 2 - cf.BERT_strategy = 'forecast' + cf.BERT_strategy = 'global_forecast' cf.token_overlap = [2, 6] cf.batch_size_test = 24 diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 7c268da..3a65f9e 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -692,9 +692,9 @@ def log_validate( self, epoch, bidx, log_sources, log_preds) : if not hasattr( self.cf, 'wandb_id') : return - if 'forecast' == self.cf.BERT_strategy : + if 'forecast' in self.cf.BERT_strategy : self.log_validate_forecast( epoch, bidx, log_sources, log_preds) - elif 'BERT' == self.cf.BERT_strategy : + elif 'BERT' in self.cf.BERT_strategy : self.log_validate_BERT( epoch, bidx, log_sources, log_preds) else : assert False @@ -719,11 +719,7 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : if hasattr( cf, 'forecast_num_tokens') : forecast_num_tokens = cf.forecast_num_tokens - # TODO: check that last token matches first one - # process input fields - sources_coords = [] - targets_coords = [] - + coords = [] for fidx, field_info in enumerate(cf.fields) : # reshape from tokens to contiguous physical field num_levels = len(field_info[2]) @@ -746,14 +742,12 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : source[bidx,vidx] = denormalize( dates[0].year, dates[0].month, source[bidx,vidx], [lats, lons]) target[bidx,vidx] = denormalize( dates_t[0].year, 
dates_t[0].month, target[bidx,vidx], [lats, lons]) - coords_b += [[dates, 90.-lats, lons]] - targ_coords_b += [[dates_t, 90.-lats, lons]] + coords_b += [[dates, 90.-lats, lons, dates_t]] # append sources_out.append( [field_info[0], source]) targets_out.append( [field_info[0], target]) - sources_coords.append(coords_b) - targets_coords.append(targ_coords_b) + coords.append(coords_b) # process predicted fields for fidx, fn in enumerate(cf.fields_prediction) : @@ -789,9 +783,9 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : levels = np.array(cf.fields[0][2]) write_forecast( cf.wandb_id, epoch, batch_idx, - levels, sources_out, sources_coords , - targets_out, targets_coords, #[dates_targets, lats, lons], - preds_out, ensembles_out ) + levels, sources_out, + targets_out, preds_out, + ensembles_out, coords) ################################################### @@ -867,8 +861,6 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : sources_b[bidx,vidx] = normalizer.denormalize( y, m, sources_b[bidx,vidx], [lats, lons]) if is_predicted : - - # TODO: make sure normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] grid = np.flip(np.array( np.meshgrid( self.sources_info[bidx][2], self.sources_info[bidx][1])), axis = 0) #flip to have lat on pos 0 and lon on pos 1 diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index cab73ce..941bfa0 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -31,16 +31,16 @@ def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): return ds_batch_item #################################################################################################### -def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, - targets, targets_coords, - preds, ensembles, - zarr_store_type = 'ZipStore' ) : +def write_forecast( model_id, epoch, batch_idx, levels, sources, + targets, preds, ensembles, coords, + zarr_store_type = 'ZipStore' ) : ''' sources : num_fields x [field name , data] targets : preds, ensemble share coords with targets ''' - + sources_coords = [[c[:3] for c in coord_field ] for coord_field in coords] + targets_coords = [[[c[-1], c[1], c[2]] for c in coord_field ] for coord_field in coords] fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' zarr_store = getattr( zarr, zarr_store_type) @@ -53,13 +53,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, sources_coords[fidx][bidx]) #[t[bidx] for t in sources_coords] ) - # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - # ds_batch_item.create_dataset( 'ml', data=levels) - # ds_batch_item.create_dataset( 'datetime', data=sources_coords[0][bidx]) - # ds_batch_item.create_dataset( 'lat', data=sources_coords[1][bidx]) - # ds_batch_item.create_dataset( 'lon', data=sources_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, sources_coords[fidx][bidx]) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -69,13 +63,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for 
bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) #[t[bidx] for t in targets_coords] ) - # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - # ds_batch_item.create_dataset( 'ml', data=levels) - # ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - # ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - # ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -85,13 +73,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) #[t[bidx] for t in targets_coords] ) - # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - # ds_batch_item.create_dataset( 'ml', data=levels) - # ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - # ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - # ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -101,31 +83,23 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) # [t[bidx] for t in targets_coords] ) - # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - # ds_batch_item.create_dataset( 'ml', data=levels) - # ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - # ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - # ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_ens.close() #################################################################################################### -def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, - targets, # targets_coords, - preds, ensembles, coords, - zarr_store_type = 'ZipStore' ) : +def write_BERT( model_id, epoch, batch_idx, levels, sources, + targets, preds, ensembles, coords, + zarr_store_type = 'ZipStore' ) : + ''' sources : num_fields x [field name , data] targets : preds, ensemble share coords with targets ''' - # breakpoint() sources_coords = [[c[:3] for c in coord_field ] for coord_field in coords] targets_coords = [[c[3:] for c in coord_field ] for coord_field in coords] - # fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch}.zarr' fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' zarr_store = getattr( zarr, zarr_store_type) @@ -138,12 +112,6 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, 
#sources_coords, for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] ) - # ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - # ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - # ds_batch_item.create_dataset( 'ml', data=levels[fidx]) - # ds_batch_item.create_dataset( 'datetime', data=sources_coords[fidx][bidx][0]) - # ds_batch_item.create_dataset( 'lat', data=sources_coords[fidx][bidx][1]) - # ds_batch_item.create_dataset( 'lon', data=sources_coords[fidx][bidx][2]) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -158,12 +126,6 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, ds_target_b = ds_field.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : ds_target_b_l = write_item(ds_target_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) - # ds_target_b_l = ds_target_b.require_group( f'ml={levels[fidx][vidx]}') - # ds_target_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - # ds_target_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - # ds_target_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) - # ds_target_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) - # ds_target_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -179,12 +141,6 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, for vidx in range(len(levels[fidx])) : ds_pred_b_l = write_item(ds_pred_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) - # ds_pred_b_l = ds_pred_b.create_group( f'ml={levels[fidx][vidx]}') - # ds_pred_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - # ds_pred_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - # ds_pred_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) - # ds_pred_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) - # ds_pred_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -200,17 +156,10 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, #sources_coords, for vidx in range(len(levels[fidx])) : ds_ens_b_l = write_item(ds_ens_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) - - # ds_ens_b_l = ds_ens_b.create_group( f'ml={levels[fidx][vidx]}') - # ds_ens_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - # ds_ens_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - # ds_ens_b_l.create_dataset( 'datetime', data=targets_coords[fidx][bidx][0][vidx]) - # ds_ens_b_l.create_dataset( 'lat', data=targets_coords[fidx][bidx][1][vidx]) - # ds_ens_b_l.create_dataset( 'lon', data=targets_coords[fidx][bidx][2][vidx]) store_ens.close() #################################################################################################### -def write_attention(model_id, epoch, batch_idx, levels, attn, attn_coords, zarr_store_type = 'ZipStore' ) : +def write_attention(model_id, epoch, batch_idx, levels, attn, coords, zarr_store_type = 'ZipStore' ) : fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' zarr_store = getattr( 
zarr, zarr_store_type) @@ -224,9 +173,9 @@ def write_attention(model_id, epoch, batch_idx, levels, attn, attn_coords, zarr_ for lidx, atts_f_l in enumerate(atts_f[1]) : # layer in the network ds_f_l = ds_field_b.require_group( f'layer={lidx:05d}') ds_f_l.create_dataset( 'ml', data=levels[fidx]) - ds_f_l.create_dataset( 'datetime', data=attn_coords[0][fidx]) - ds_f_l.create_dataset( 'lat', data=attn_coords[1][fidx]) - ds_f_l.create_dataset( 'lon', data=attn_coords[2][fidx]) + ds_f_l.create_dataset( 'datetime', data=coords[0][fidx]) + ds_f_l.create_dataset( 'lat', data=coords[1][fidx]) + ds_f_l.create_dataset( 'lon', data=coords[2][fidx]) ds_f_l_h = ds_f_l.require_group('heads') for hidx, atts_f_l_head in enumerate(atts_f_l) : # number of attention head if atts_f_l_head != None : diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 6a87809..399757e 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -25,7 +25,6 @@ from atmorep.datasets.normalizer_local import NormalizerLocal from atmorep.utils.utils import tokenize - class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### @@ -48,20 +47,11 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.pre_batch = pre_batch - # create (source) fields - # config.path_data - fname_source = '/p/scratch/atmo-rep/era5_res0025_1979.zarr' - fname_source = '/p/scratch/atmo-rep/era5_res0025_2021.zarr' fname_source = '/p/scratch/atmo-rep/data/era5_1deg/era5_res0025_2021_final.zarr' - # fname_source = '/p/scratch/atmo-rep/era5_res0100_2021_t5.zarr' self.ds = zarr.open( fname_source) self.ds_global = self.ds.attrs['is_global'] self.ds_len = self.ds['data'].shape[0] - # sanity checking - # assert self.ds['data'].shape[0] == self.ds['time'].shape[0] - # assert self.ds_len >= num_samples_per_epoch - self.lats = np.array( self.ds['lats']) self.lons = np.array( self.ds['lons']) @@ -81,11 +71,8 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.time_sampling = time_sampling self.range_lat = np.array( self.lats[ [0,-1] ]) self.range_lon = np.array( self.lons[ [0,-1] ]) + self.res = np.array(self.ds.attrs['resol']) - self.res = np.zeros( 2) - self.res[0] = [0] - self.res[1] = self.ds.attrs['resol'][1] - # ensure neighborhood does not exceed domain (either at pole or for finite domains) self.range_lat += np.array([n_size[1] / 2., -n_size[1] / 2.]) # lon: no change for periodic case @@ -100,17 +87,9 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe ner = NormalizerGlobal if corr_type == 'global' else NormalizerLocal for vl in field_info[2]: data_type = 'data_sfc' if vl == 0 else 'data' #surface field - self.normalizers[-1] += [ ner( field_info, vl, - np.array(self.ds[data_type].shape)[[0,-2,-1]]) ] + self.normalizers[-1] += [ ner( field_info, vl ) ] # extract indices for selected years self.times = pd.DatetimeIndex( self.ds['time']) - # idxs = np.zeros( self.ds['time'].shape[0], dtype=np.bool_) - # self.idxs_years = np.array( []) - # for year in years : - # idxs = np.where( (self.times >= f'{year}-1-1') & (self.times <= f'{year}-12-31'))[0] - # assert idxs.shape[0] > 0, f'Requested year is not in dataset {fname_source}. Aborting.' 
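The year-based index selection that is commented out here (and superseded by the np.arange placeholder marked TODO) can be written compactly; a hypothetical standalone version, with an illustrative helper name:

import numpy as np
import pandas as pd

def year_indices(times: pd.DatetimeIndex, years, time_sampling=1):
    # hypothetical helper: indices of all timestamps in the requested years,
    # thinned by the time_sampling stride
    idxs = np.concatenate([np.where(times.year == y)[0][::time_sampling] for y in years])
    return np.sort(idxs)

times = pd.date_range('2020-01-01', '2021-12-31 23:00', freq='h')
print(year_indices(times, [2021], time_sampling=3).shape)   # (2920,)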
- # self.idxs_years = np.append( self.idxs_years, idxs[::self.time_sampling]) - # TODO, TODO, TODO: self.idxs_years = np.arange( self.ds_len) ################################################### @@ -196,7 +175,6 @@ def __iter__(self): # extract data, normalize and tokenize cdata = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) - #breakpoint() cdata = nf( year, month, cdata, (lats[lat_ran], lons[lon_ran]) ) source_data = tokenize( torch.from_numpy( cdata), tok_size ) @@ -225,7 +203,6 @@ def __iter__(self): # TODO: implement (only required when prediction target comes from different data stream) targets, target_info = None, None target_idxs = None - #breakpoint() yield ( sources, targets, (source_idxs, sources_infos), (target_idxs, target_info)) ################################################### diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py index 4336b6b..622f672 100644 --- a/atmorep/datasets/normalizer_global.py +++ b/atmorep/datasets/normalizer_global.py @@ -21,9 +21,8 @@ class NormalizerGlobal() : - def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : + def __init__(self, field_info, vlevel, data_type = 'era5', level_type = 'ml') : - # TODO: use path from config and pathlib.Path() fname_base = '{}/{}/normalization/{}/global_normalization_mean_var_{}_{}{}.bin' fn = field_info[0] @@ -32,18 +31,14 @@ def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_typ def normalize( self, year, month, data, coords = None) : - #breakpoint() corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), - self.corr_data[:,1] == float(month))) , 2:].flatten() - #breakpoint() + self.corr_data[:,1] == float(month))) , 2:].flatten() data_temp = (data - corr_data_ym[0]) / corr_data_ym[1] - # print(data_temp.mean(), data_temp.std()) return (data - corr_data_ym[0]) / corr_data_ym[1] def denormalize( self, year, month, data, coords = None) : corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), self.corr_data[:,1] == float(month))) , 2:].flatten() - data_temp = (data * corr_data_ym[1]) + corr_data_ym[0] - #print("after denorm", data_temp.mean(), data_temp.std()) + data_temp = (data * corr_data_ym[1]) + corr_data_ym[0] return (data * corr_data_ym[1]) + corr_data_ym[0] \ No newline at end of file diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py index 63cc62d..743cf63 100644 --- a/atmorep/datasets/normalizer_local.py +++ b/atmorep/datasets/normalizer_local.py @@ -17,12 +17,11 @@ import code import numpy as np import xarray as xr -import pdb import atmorep.config.config as config class NormalizerLocal() : - def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : + def __init__(self, field_info, vlevel, data_type = 'era5', level_type = 'ml') : fname_base = './data/{}/normalization/{}/normalization_mean_var_{}_y{}_m{:02d}_{}{}.bin' self.year_base = config.datasets[data_type]['extent'][0][0] @@ -41,8 +40,8 @@ def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_typ year, month, level_type, vlevel) ns_lat = int( (lat_max-lat_min) / res + 1) ns_lon = int( (lon_max-lon_min) / res + (0 if is_global else 1) ) - # TODO: remove file_shape (ns_lat, ns_lon contains same information) - x = np.fromfile( corr_fname, dtype=np.float32).reshape( (file_shape[1], file_shape[2], 2)) + + x = np.fromfile( corr_fname, dtype=np.float32).reshape( (ns_lat, ns_lon, 2)) # 
TODO, TODO, TODO: remove once recomputed if 'cerra' == data_type : x[:,:,0] = 340. @@ -60,28 +59,23 @@ def normalize( self, year, month, data, coords) : if (var == 0.).all() : print( f'var == 0 :: ym : {year} / {month}') assert False - # print("before", data.mean(), data.std()) - #breakpoint() + if len(data.shape) > 2 : for i in range( data.shape[0]) : data[i] = (data[i] - mean) / var else : data = (data - mean) / var - # print("after", data.mean(), data.std()) return data def denormalize( self, year, month, data, coords) : - corr_data_ym = self.corr_data[ (year - self.year_base) * 12 + (month-1) ] mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values - #print("before", data.mean(), data.std()) if len(data.shape) > 2 : for i in range( data.shape[0]) : data[i] = (data[i] * var) + mean else : data = (data * var) + mean - #print("after", data.mean(), data.std()) return data \ No newline at end of file diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 0f47223..0f5b4f5 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -34,6 +34,8 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, if BERT_strategy == 'BERT' : bert_f = prepare_batch_BERT_field + elif BERT_strategy == 'global_forecast' : + bert_f = prepare_batch_BERT_forecast_field elif BERT_strategy == 'forecast' : bert_f = prepare_batch_BERT_forecast_field elif BERT_strategy == 'temporal_interpolation' : From f28e1c94293767aa705ebfcd49e459ba8cbcfd27 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 11 Apr 2024 15:50:33 +0200 Subject: [PATCH 24/66] fix temporal interpol, fix overlap, remove evaluate --- atmorep/core/evaluate.py | 19 ++-- atmorep/core/evaluator.py | 14 +-- atmorep/core/train.py | 6 +- atmorep/core/trainer.py | 97 ++------------------- atmorep/datasets/multifield_data_sampler.py | 11 +-- atmorep/datasets/normalizer_global.py | 2 - atmorep/datasets/normalizer_local.py | 5 +- atmorep/training/bert.py | 93 +------------------- 8 files changed, 34 insertions(+), 213 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 7738d24..d785e86 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -26,11 +26,11 @@ #model_id = 'dys79lgw' # velocity_u #model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z - model_id = '3qou60es' # temperature (also 2147fkco) + #model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence - #model_id = '1jh2qvrx' # multiformer, velocity + model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity # model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting @@ -46,14 +46,17 @@ #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} # BERT forecast mode - #mode, options = 'forecast', {'forecast_num_tokens' : 1, 'fields[0][2]' : [123], 'attention' : False } + mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } + #temporal interpolation + #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'attention' : False } + # BERT forecast with patching to obtain global forecast - mode, options = 'global_forecast', { 'fields[0][2]' : 
[114], #[123, 137], #[105, 137], - 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 1, - 'attention' : False } + # mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], + # 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], + # 'token_overlap' : [2, 6], + # 'forecast_num_tokens' : 2, + # 'attention' : False } now = time.time() Evaluator.evaluate( mode, model_id, options) print("time", time.time() - now) \ No newline at end of file diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 537f079..c3588e2 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -94,13 +94,12 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : #backward compatibility if not hasattr( cf, 'n_size'): - #cf.n_size = [36, 0.25*9*6, 0.25*9*12] - cf.n_size = [36, 0.25*27*2, 0.25*27*4] + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + #cf.n_size = [36, 0.25*27*2, 0.25*27*4] if not hasattr(cf, 'num_samples_per_epoch'): cf.num_samples_per_epoch = 1024 if not hasattr(cf, 'with_mixed_precision'): cf.with_mixed_precision = False - # cf.batch_size_start = 14 func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) @@ -112,7 +111,7 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : cf.BERT_strategy = 'BERT' cf.log_test_num_ranks = 4 if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 1472 #128 + cf.num_samples_validate = 128 #1472 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) @@ -145,7 +144,6 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : Evaluator.parse_args( cf, args) dates = args['dates'] - print("inside global forecast") evaluator = Evaluator.load( cf, model_id, model_epoch, devices) evaluator.model.set_global( NetMode.test, np.array( dates)) if 0 == cf.par_rank : @@ -159,7 +157,7 @@ def global_forecast_range( cf, model_id, model_epoch, devices, args = {}) : cf.forecast_num_tokens = 2 cf.BERT_strategy = 'global_forecast' - cf.token_overlap = [2, 6] + cf.token_overlap = [0, 0] cf.batch_size_test = 24 cf.num_loader_workers = 1 @@ -195,7 +193,9 @@ def temporal_interpolation( cf, model_id, model_epoch, devices, args = {}) : # set/over-write options cf.BERT_strategy = 'temporal_interpolation' cf.log_test_num_ranks = 4 - + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 + Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 4657bde..91cf03b 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -185,7 +185,7 @@ def train() : cf.batch_size_max = 32 cf.batch_size_delta = 8 cf.num_epochs = 128 - # cf.num_loader_workers = 1#8 + # additional infos cf.size_token_info = 8 cf.size_token_info_net = 16 @@ -252,7 +252,6 @@ def train() : cf.write_json( wandb) cf.print() - #cf.levels = [114, 123, 137] cf.with_mixed_precision = True # cf.n_size = [36, 1*9*6, 1.*9*12] # in steps x lat_degrees x lon_degrees @@ -262,9 +261,6 @@ def train() : cf.num_samples_validate = 128 cf.num_loader_workers = 1 #8 - cf.years_train = [2021] # list( range( 1980, 2018)) - cf.years_test = [2021] #[2018] - trainer = Trainer_BERT( cf, device).create() trainer.run() diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 3a65f9e..2192e75 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -24,7 +24,6 @@ import 
datetime from typing import TypeVar import functools -import pdb import pandas as pd import wandb @@ -168,9 +167,6 @@ def run( self, epoch = -1) : model_parameters = filter(lambda p: p.requires_grad, self.model_ddp.parameters()) num_params = sum([np.prod(p.size()) for p in model_parameters]) print( f'Number of trainable parameters: {num_params:,}') - - # test at the beginning as reference - self.model.load_data( NetMode.test, batch_size=cf.batch_size_test) if cf.test_initial : cur_test_loss = self.validate( epoch, cf.BERT_strategy).cpu().numpy() @@ -186,7 +182,7 @@ def run( self, epoch = -1) : lr = learn_rates[epoch] for g in self.optimizer.param_groups: g['lr'] = lr - self.model.load_data( NetMode.train, batch_size = cf.batch_size_max) + self.profile() # training loop @@ -204,7 +200,6 @@ def run( self, epoch = -1) : tstr = datetime.datetime.now().strftime("%H:%M:%S") print( '{} : {} :: batch_size = {}, lr = {}'.format( epoch, tstr, batch_size, lr) ) - self.model.load_data( NetMode.train, batch_size = batch_size) self.train( epoch) if cf.with_wandb and 0 == cf.par_rank : @@ -368,7 +363,7 @@ def profile( self): ################################################### def validate( self, epoch, BERT_test_strategy = 'BERT'): - + cf = self.cf BERT_strategy_train = cf.BERT_strategy cf.BERT_strategy = BERT_test_strategy @@ -386,19 +381,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory (sources, token_infos, targets, tmis_list) = batch_data[0] - # breakpoint() - # targets - #TO-DO: implement target - # if len(batch_data[1]) > 0 : - # print("len(batch_data[1])", len(batch_data[1])) - # if type(batch_data[1][0][0]) is list : - # targets = [batch_data[1][i][0][0] for i in range( len(batch_data[1]))] - # else : - # targets = batch_data[1][0] - # targets = [] - # for target_field in batch_data[1] : - # targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) - # store on cpu + log_sources = ( [source.detach().clone().cpu() for source in sources ], [target.detach().clone().cpu() for target in targets ], tmis_list) @@ -464,68 +447,6 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): return total_loss - ################################################### - def evaluate( self, data_idx = 0, log = True): - - cf = self.cf - self.model.mode( NetMode.test) - - log_sources = [] - test_len = 0 - - # evaluate - loss = torch.tensor( 0.) 
- - with torch.no_grad() : - for it in range( self.model.len( NetMode.test)) : - - batch_data = self.model.next() - - if cf.par_rank < cf.log_test_num_ranks : - # keep on cpu since it will otherwise clog up GPU memory - (sources, token_infos, targets, tmis_list) = batch_data[0] - # targets - if len(batch_data[1]) > 0 : - targets = [] - for target_field in batch_data[1] : - targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) - - # store on cpu - # TODO: is this still all needed with self.sources_idx - log_sources = ( [source.detach().clone().cpu() for source in sources ], - [target.detach().clone().cpu() for target in targets ], - tmis_list) - - batch_data = self.prepare_batch( batch_data) - - preds, atts = self.model( batch_data) - ifield = 0 - for pred, idx in zip( preds, self.fields_prediction_idx) : - - target = self.targets[idx] - cur_loss = self.MSELoss( pred[0], target = target ).cpu() - loss += cur_loss - ifield += 1 - - test_len += 1 - - # logging - if cf.par_rank < cf.log_test_num_ranks : - self.log_validate( data_idx, it, log_sources, preds) - - if cf.attention: - self.log_attention( data_idx , it, atts) - - # average over all nodes - loss /= test_len * len(self.cf.fields_prediction) - if cf.with_ddp : - loss_cuda = loss.cuda() - dist.all_reduce( loss_cuda, op=torch.distributed.ReduceOp.AVG ) - loss = loss_cuda.cpu() - - if 0 == cf.par_rank : - print( 'Loss {}'.format( loss)) - ################################################### def test_loss( self, pred, target) : '''Hook for custom test loss''' @@ -652,7 +573,7 @@ def prepare_batch( self, xin) : assert len(cf.fields[ifield][2]) == 1 tmidx = self.tokens_masked_idx[ifield][0] source[ tmidx ] = self.model.net.masks[ifield].to( source.device) -# breakpoint() + return batch_data ################################################### @@ -694,7 +615,7 @@ def log_validate( self, epoch, bidx, log_sources, log_preds) : if 'forecast' in self.cf.BERT_strategy : self.log_validate_forecast( epoch, bidx, log_sources, log_preds) - elif 'BERT' in self.cf.BERT_strategy : + elif 'BERT' in self.cf.BERT_strategy or 'temporal_interpolation' == self.cf.BERT_strategy : self.log_validate_BERT( epoch, bidx, log_sources, log_preds) else : assert False @@ -702,13 +623,9 @@ def log_validate( self, epoch, bidx, log_sources, log_preds) : ################################################### def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : '''Logging for BERT_strategy=forecast.''' - - # TODO, TODO: use sources_idx cf = self.cf - # TODO, TODO: for 6h forecast we need to iterate over predicted token slices - # save source: remains identical so just save ones (sources, targets, _) = log_sources @@ -726,7 +643,7 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : source = detokenize( sources[fidx].cpu().detach().numpy()) # recover tokenized shape target = detokenize( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, -1, - forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) + forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) coords_b, targ_coords_b = [], [] for bidx in range(batch_size): @@ -738,7 +655,6 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : #TODO: add support for multiple months for vidx, _ in enumerate(field_info[2]) : denormalize = self.model.normalizer( fidx, vidx).denormalize - # breakpoint() source[bidx,vidx] = denormalize( dates[0].year, dates[0].month, source[bidx,vidx], 
[lats, lons]) target[bidx,vidx] = denormalize( dates_t[0].year, dates_t[0].month, target[bidx,vidx], [lats, lons]) @@ -772,7 +688,6 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : #TODO: add support for multiple months for vidx, vl in enumerate(field_info[2]) : denormalize = self.model.normalizer( self.fields_prediction_idx[fidx], vidx).denormalize - # breakpoint() pred[bidx,vidx] = denormalize( dates_t[0].year, dates_t[0].month, pred[bidx,vidx], [lats, lons]) ensemble[bidx,:,vidx] = denormalize(dates_t[0].year, dates_t[0].month, ensemble[bidx,:,vidx], [lats, lons]) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 399757e..a67f46d 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -62,12 +62,6 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe print( f'self.lons : {self.lons.shape}', flush=True) self.fields_idxs = [] - # TODO - # # create (target) fields - # self.datasets_targets = self.create_loaders( fields_targets) - # self.fields_targets = fields_targets - # self.pre_batch_targets = pre_batch_targets - self.time_sampling = time_sampling self.range_lat = np.array( self.lats[ [0,-1] ]) self.range_lon = np.array( self.lons[ [0,-1] ]) @@ -87,7 +81,7 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe ner = NormalizerGlobal if corr_type == 'global' else NormalizerLocal for vl in field_info[2]: data_type = 'data_sfc' if vl == 0 else 'data' #surface field - self.normalizers[-1] += [ ner( field_info, vl ) ] + self.normalizers[-1] += [ ner( field_info, vl)] #, self.range_lat, self.range_lon, self.res, self.ds_global ) ] # extract indices for selected years self.times = pd.DatetimeIndex( self.ds['time']) self.idxs_years = np.arange( self.ds_len) @@ -216,6 +210,7 @@ def set_data( self, times_pos, batch_size = None) : # generate all the data self.idxs_perm = np.zeros( (len(times_pos), 2)) self.idxs_perm_t = [] + self.num_samples = len(times_pos) for idx, item in enumerate( times_pos) : assert item[2] >= 1 and item[2] <= 31 @@ -278,7 +273,7 @@ def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : # adjust batch size if necessary so that the evaluations split up across batches of equal size batch_size = num_tiles_lon - + print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) self.set_data( times_pos, batch_size) diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py index 622f672..3837a02 100644 --- a/atmorep/datasets/normalizer_global.py +++ b/atmorep/datasets/normalizer_global.py @@ -15,9 +15,7 @@ #################################################################################################### import numpy as np - import atmorep.config.config as config -import pdb class NormalizerGlobal() : diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py index 743cf63..92e58c5 100644 --- a/atmorep/datasets/normalizer_local.py +++ b/atmorep/datasets/normalizer_local.py @@ -21,8 +21,9 @@ class NormalizerLocal() : - def __init__(self, field_info, vlevel, data_type = 'era5', level_type = 'ml') : - + # def __init__(self, field_info, vlevel, range_lat, range_lon, res, is_global, data_type= 'era5', level_type = 'ml') : + def __init__(self, field_info, vlevel, data_type = 'era5', level_type = 'ml') : + fname_base = './data/{}/normalization/{}/normalization_mean_var_{}_y{}_m{:02d}_{}{}.bin' 
self.year_base = config.datasets[data_type]['extent'][0][0] self.year_last = config.datasets[data_type]['extent'][0][1] diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 0f5b4f5..275af5d 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -18,8 +18,6 @@ import numpy as np from functools import partial import code -import pdb -# from atmorep.utils.utils import tokenize #################################################################################################### def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, fields_infos) : @@ -40,12 +38,6 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, bert_f = prepare_batch_BERT_forecast_field elif BERT_strategy == 'temporal_interpolation' : bert_f = prepare_batch_BERT_temporal_field - elif BERT_strategy == 'forecast_1shot' : - bert_f = prepare_batch_BERT_forecast_field_1shot - elif BERT_strategy == 'identity' : - bert_f = prepare_batch_BERT_identity_field - elif BERT_strategy == 'totalmask' : - bert_f = prepare_batch_BERT_totalmask_field else : assert False @@ -158,11 +150,6 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # unsqueeze(usq()) is required since channel dimension is expected temp = mr( mr( usq( source[ idx[idx_mr_cond] ].reshape( (-1,ts[0],ts[1],ts[2])), 1), mrs), ts) source[ idx[idx_mr_cond] ] = sq( fl( temp, -3, -1)) - # adjust resolution parameter in token_info - # token_info_shape = token_info.shape - # token_info = token_info.flatten( 0, 1) - # token_info[ idx[idx_mr_cond] ][-1] *= (mrs[1] + mrs[2]) / 2. #TODO: anisotropic resolution - # token_info = token_info.reshape( token_info_shape) # recover batch dimension which was flattend for easier indexing and also token dimensions source = torch.reshape( torch.reshape( source, source_shape), source_shape0) @@ -210,75 +197,7 @@ def prepare_batch_BERT_temporal_field( cf, ifield, source, token_info, rng) : num_tokens_space = num_tokens[1] * num_tokens[2] idx_time_mask = int( np.floor(num_tokens[0] / 2.)) # TODO: masking of multiple time steps idxs = idx_time_mask * num_tokens_space + torch.arange(num_tokens_space) - - # collapse token dimensions - source_shape0 = source.shape - source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) - - # linear indices for masking - num_tokens = source.shape[1] - idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) - tokens_masked_idx = idx - - source_shape = source.shape - # flatten along first two dimension to simplify linear indexing (which then requires an - # easily computable row offset) - source = torch.flatten( source, 0, 1) - - # keep masked tokens for loss computation - target = source[idx].clone() - - # masking - global_mean = 0. 
* torch.mean(source, 0) - source[ idx ] = global_mean - - # recover batch dimension which was flattend for easier indexing - source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - - return (source, token_info, target, tokens_masked_idx, idxs) - -#################################################################################################### -def prepare_batch_BERT_forecast_field_1shot( cf, ifield, source, token_info, rng) : - - nt = 1 # TODO: specify this in config - num_tokens = source.shape[-6:-3] - num_tokens_space = num_tokens[1] * num_tokens[2] - idxs = (num_tokens[0]-nt) * num_tokens_space + torch.arange(num_tokens_space) - - # collapse token dimensions - source_shape0 = source.shape - source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) - - # linear indices for masking - num_tokens = source.shape[1] - # mask only every second neighborhood: 1 shot setting - idx = torch.cat( [idxs + num_tokens * i for i in range( 1, source.shape[0], 2 )] ) - tokens_masked_idx = idx - - source_shape = source.shape - # flatten along first two dimension to simplify linear indexing (which then requires an - # easily computable row offset) - source = torch.flatten( source, 0, 1) - - # keep masked tokens for loss computation - target = source[idx].clone() - - # masking - global_mean = 0. * torch.mean(source, 0) - source[ idx ] = global_mean - - # recover batch dimension which was flattend for easier indexing - source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - - return (source, token_info, target, tokens_masked_idx, idxs) - -#################################################################################################### -def prepare_batch_BERT_totalmask_field( cf, ifield, source, token_info, rng) : - - num_tokens = source.shape[-6:-3] - num_tokens_space = num_tokens[1] * num_tokens[2] - idxs = torch.arange(num_tokens[0] * num_tokens_space) - + # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) @@ -286,8 +205,7 @@ def prepare_batch_BERT_totalmask_field( cf, ifield, source, token_info, rng) : # linear indices for masking num_tokens = source.shape[1] idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) - tokens_masked_idx = idx - + tokens_masked_idx_list = [idxs + num_tokens * i for i in range( source.shape[0] )] source_shape = source.shape # flatten along first two dimension to simplify linear indexing (which then requires an # easily computable row offset) @@ -303,9 +221,4 @@ def prepare_batch_BERT_totalmask_field( cf, ifield, source, token_info, rng) : # recover batch dimension which was flattend for easier indexing source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx) - -#################################################################################################### -def prepare_batch_BERT_identity_field( cf, ifield, source, token_info, rng) : - - return (source, token_info, None, None, None) + return (source, token_info, target, tokens_masked_idx_list) From f2b064678ea99334aeccfb3135fe4a6f5bfb3a83 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 13 Apr 2024 11:28:18 +0200 Subject: [PATCH 25/66] Fixed hard coded path. 
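
The evaluator no longer overrides the data directory with the literal './data/' after loading a run's stored config; it takes the path from the shared config module instead, the same place train.py reads it from. For illustration only (the assignment below is the actual change; the import alias follows the one used elsewhere in the package):

    import atmorep.config.config as config

    # single source of truth for the data location instead of a hard-coded string
    cf.data_dir = config.path_data
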
--- atmorep/core/evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index c3588e2..681e3d6 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -89,7 +89,7 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : print( 'Running Evaluate.evaluate with mode =', mode) cf.num_loader_workers = cf.loader_num_workers - cf.data_dir = './data/' + cf.data_dir = config.path_data cf.rng_seed = None #backward compatibility From 008031b8b83ad4fca2062952ff0b784c5660928c Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sat, 13 Apr 2024 12:36:03 +0200 Subject: [PATCH 26/66] Removed load_data for training. --- atmorep/core/trainer.py | 4 ---- atmorep/datasets/multifield_data_sampler.py | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 7c268da..ee36e47 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -168,9 +168,6 @@ def run( self, epoch = -1) : model_parameters = filter(lambda p: p.requires_grad, self.model_ddp.parameters()) num_params = sum([np.prod(p.size()) for p in model_parameters]) print( f'Number of trainable parameters: {num_params:,}') - - # test at the beginning as reference - self.model.load_data( NetMode.test, batch_size=cf.batch_size_test) if cf.test_initial : cur_test_loss = self.validate( epoch, cf.BERT_strategy).cpu().numpy() @@ -204,7 +201,6 @@ def run( self, epoch = -1) : tstr = datetime.datetime.now().strftime("%H:%M:%S") print( '{} : {} :: batch_size = {}, lr = {}'.format( epoch, tstr, batch_size, lr) ) - self.model.load_data( NetMode.train, batch_size = batch_size) self.train( epoch) if cf.with_wandb and 0 == cf.par_rank : diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 6a87809..24510dd 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -82,9 +82,7 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.range_lat = np.array( self.lats[ [0,-1] ]) self.range_lon = np.array( self.lons[ [0,-1] ]) - self.res = np.zeros( 2) - self.res[0] = [0] - self.res[1] = self.ds.attrs['resol'][1] + self.res = np.array(self.ds.attrs['resol']) # ensure neighborhood does not exceed domain (either at pole or for finite domains) self.range_lat += np.array([n_size[1] / 2., -n_size[1] / 2.]) From f9b609a802c58e12e8a461ba95710f47a3780a01 Mon Sep 17 00:00:00 2001 From: iluise Date: Tue, 16 Apr 2024 17:07:14 +0200 Subject: [PATCH 27/66] new normalization from zarr --- atmorep/config/config.py | 26 +------ atmorep/core/atmorep_model.py | 21 +++--- atmorep/core/evaluate.py | 14 ++-- atmorep/core/evaluator.py | 4 +- atmorep/core/train.py | 2 +- atmorep/core/train_multi.py | 2 +- atmorep/core/trainer.py | 61 ++++++++------- atmorep/datasets/multifield_data_sampler.py | 52 ++++++++----- atmorep/datasets/normalizer.py | 82 +++++++++++++++++++++ 9 files changed, 174 insertions(+), 90 deletions(-) create mode 100644 atmorep/datasets/normalizer.py diff --git a/atmorep/config/config.py b/atmorep/config/config.py index 087dc33..6c6e20d 100644 --- a/atmorep/config/config.py +++ b/atmorep/config/config.py @@ -3,12 +3,9 @@ fpath = os.path.dirname(os.path.realpath(__file__)) -# year_base = 1979 -# year_last = 2022 - +path_data = Path(fpath, '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_final.zarr') path_models = Path( fpath, '../../models/') path_results = Path( 
fpath, '../../results') -path_data = Path( fpath, '../../data/') path_plots = Path( fpath, '../results/plots/') grib_index = { 'vorticity' : 'vo', 'divergence' : 'd', 'geopotential' : 'z', @@ -17,24 +14,3 @@ 'velocity_u' : 'u', 'velocity_v': 'v', 'velocity_z' : 'w', 'total_precip' : 'tp', 'radar_precip' : 'yw_hourly', 't2m' : 't_2m', 'u_10m' : 'u_10m', 'v_10m' : 'v_10m', } - -# TODO: extract this info from the datasets -datasets = {} -# -datasets['era5'] = {} -datasets['era5']['resolution'] = [1, 0.25, 0.25] -datasets['era5']['extent'] = [ [1979, 2022], [90., -90], [0.0, 360] ] -datasets['era5']['is_global'] = True -datasets['era5']['file_size'] = [ -1, 721, 1440] -# -datasets['cosmo_rea6'] = {} -datasets['cosmo_rea6']['resolution'] = [1, 0.0625, 0.0625] -datasets['cosmo_rea6']['extent'] = [ [1997, 2017], [27.5,70.25], [-12.5,37.0] ] -datasets['cosmo_rea6']['is_global'] = False -datasets['cosmo_rea6']['file_size'] = [ -1, 685, 793] -# -datasets['cerra'] = {} -datasets['cerra']['resolution'] = [3, 0.25, 0.25] -datasets['cerra']['extent'] = [ [1985, 2001], [75.25,20.5], [-58.0,74.0] ] -datasets['cerra']['is_global'] = False -datasets['cerra']['file_size'] = [ -1, 220, 529] diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 97f03de..13a7248 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -113,7 +113,7 @@ def _set_data( self, dataset, mode : NetMode, batch_size = -1, loader_workers = assert False ################################################### - def normalizer( self, field, vl_idx) : + def normalizer( self, field, vl_idx, lats_idx, lons_idx ) : if isinstance( field, str) : for fidx, field_info in enumerate(self.cf.fields) : @@ -124,22 +124,25 @@ def normalizer( self, field, vl_idx) : elif isinstance( field, int) : normalizer = self.dataset_train.normalizers[field][vl_idx] - + if len(normalizer.shape) > 2: + normalizer = np.take( np.take( normalizer, lats_idx, -2), lons_idx, -1) else : assert False, 'invalid argument type (has to be index to cf.fields or field name)' + + year_base = self.dataset_train.year_base - return normalizer + return normalizer, year_base ################################################### def mode( self, mode : NetMode) : if mode == NetMode.train : - self.data_loader_iter = iter(self.data_loader_train) - #self.data_loader_iter = iter(self.dataset_train) + #self.data_loader_iter = iter(self.data_loader_train) + self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : - self.data_loader_iter = iter(self.data_loader_test) - #self.data_loader_iter = iter(self.dataset_test) + #self.data_loader_iter = iter(self.data_loader_test) + self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -183,14 +186,14 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False, 'num_workers': cf.num_loader_workers, 'pin_memory': True} - self.dataset_train = MultifieldDataSampler( cf.fields, cf.years_train, + self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_per_epoch, with_shuffle = (cf.BERT_strategy != 'global_forecast') ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) - self.dataset_test = MultifieldDataSampler( cf.fields, cf.years_test, + self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, 
cf.years_test, cf.batch_size_start, pre_batch, cf.n_size, cf.num_samples_validate, with_shuffle = (cf.BERT_strategy != 'global_forecast'), diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index d785e86..bf596b8 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -19,10 +19,10 @@ if __name__ == '__main__': # models for individual fields - #model_id = '4nvwbetz' # vorticity + model_id = '4nvwbetz' # vorticity #model_id = 'oxpycr7w' # divergence # model_id = '1565pb1f' # specific_humidity - # model_id = '3kdutwqb' # total precip + #model_id = '3kdutwqb' # total precip #model_id = 'dys79lgw' # velocity_u #model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z @@ -30,7 +30,7 @@ #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence - model_id = '1jh2qvrx' # multiformer, velocity + # model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity # model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting @@ -46,15 +46,15 @@ #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} # BERT forecast mode - mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } - + #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } + #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } #temporal interpolation - #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'attention' : False } + mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'attention' : False } # BERT forecast with patching to obtain global forecast # mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], # 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], - # 'token_overlap' : [2, 6], + # 'token_overlap' : [0, 0], # 'forecast_num_tokens' : 2, # 'attention' : False } now = time.time() diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index c3588e2..021edb0 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -82,14 +82,14 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : cf.par_rank = par_rank cf.par_size = par_size # overwrite old config - cf.data_dir = str(config.path_data) + cf.file_path = str(config.path_data) cf.attention = False setup_wandb( cf.with_wandb, cf, par_rank, '', mode='offline') if 0 == cf.par_rank : print( 'Running Evaluate.evaluate with mode =', mode) cf.num_loader_workers = cf.loader_num_workers - cf.data_dir = './data/' + #cf.data_dir = './data/' cf.rng_seed = None #backward compatibility diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 91cf03b..1a1ab49 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -102,7 +102,7 @@ def train() : # general cf.comment = '' cf.file_format = 'grib' - cf.data_dir = str(config.path_data) + cf.file_path = str(config.path_data) cf.level_type = 'ml' # format: list of fields where for each field the list is diff --git a/atmorep/core/train_multi.py b/atmorep/core/train_multi.py index 069c425..1bb8fcf 100644 --- a/atmorep/core/train_multi.py +++ b/atmorep/core/train_multi.py @@ -86,7 +86,7 @@ def train_multi() : # general cf.comment = '' cf.file_format = 'grib' - cf.data_dir = str(config.path_data) + cf.file_path = str(config.path_data) 
cf.level_type = 'ml' cf.fields = [ [ 'vorticity', [ 1, 2048, ['divergence', 'temperature'], 0 ], diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 2192e75..80e0f42 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -50,6 +50,7 @@ from atmorep.utils.utils import sgn_exp from atmorep.utils.utils import tokenize, detokenize from atmorep.datasets.data_writer import write_forecast, write_BERT, write_attention +from atmorep.datasets.normalizer import denormalize #################################################################################################### class Trainer_Base() : @@ -651,12 +652,15 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : lats = self.sources_info[bidx][1] lons = self.sources_info[bidx][2] dates_t = self.sources_info[bidx][0][ -forecast_num_tokens*field_info[4][0] : ] - + + lats_idx = self.sources_idxs[bidx][1] + lons_idx = self.sources_idxs[bidx][2] + #TODO: add support for multiple months for vidx, _ in enumerate(field_info[2]) : - denormalize = self.model.normalizer( fidx, vidx).denormalize - source[bidx,vidx] = denormalize( dates[0].year, dates[0].month, source[bidx,vidx], [lats, lons]) - target[bidx,vidx] = denormalize( dates_t[0].year, dates_t[0].month, target[bidx,vidx], [lats, lons]) + normalizer, year_base = self.model.normalizer( fidx, vidx, lats_idx, lons_idx) + source[bidx,vidx] = denormalize( source[bidx,vidx], normalizer, dates, year_base) + target[bidx,vidx] = denormalize( target[bidx,vidx], normalizer, dates_t, year_base) coords_b += [[dates, 90.-lats, lons, dates_t]] @@ -667,7 +671,6 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : # process predicted fields for fidx, fn in enumerate(cf.fields_prediction) : - # field_info = cf.fields[ self.fields_prediction_idx[fidx] ] num_levels = len(field_info[2]) # predictions @@ -685,11 +688,10 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : lons = self.sources_info[bidx][2] dates_t = self.sources_info[bidx][0][ -forecast_num_tokens*field_info[4][0] : ] - #TODO: add support for multiple months for vidx, vl in enumerate(field_info[2]) : - denormalize = self.model.normalizer( self.fields_prediction_idx[fidx], vidx).denormalize - pred[bidx,vidx] = denormalize( dates_t[0].year, dates_t[0].month, pred[bidx,vidx], [lats, lons]) - ensemble[bidx,:,vidx] = denormalize(dates_t[0].year, dates_t[0].month, ensemble[bidx,:,vidx], [lats, lons]) + normalizer, year_base = self.model.normalizer( self.fields_prediction_idx[fidx], vidx, lats_idx, lons_idx) + pred[bidx,vidx] = denormalize( pred[bidx,vidx], normalizer, dates_t, year_base) + ensemble[bidx,:,vidx] = denormalize(ensemble[bidx,:,vidx].swapaxes(0,1), normalizer, dates_t, year_base).swapaxes(0,1) # append preds_out.append( [fn[0], pred]) @@ -761,7 +763,10 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : dates = self.sources_info[bidx][0] lats = self.sources_info[bidx][1] lons = self.sources_info[bidx][2] - + + lats_idx = self.sources_idxs[bidx][1] + lons_idx = self.sources_idxs[bidx][2] + # target etc are aliasing targets_b which simplifies bookkeeping below if is_predicted : target = [targets_b[vidx][bidx] for vidx in range(num_levels)] @@ -771,36 +776,42 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : coords_mskd_l = [] for vidx, _ in enumerate(field_info[2]) : - normalizer = self.model.normalizer( fidx, vidx) - y, m = dates[0].year, dates[0].month - sources_b[bidx,vidx] = 
normalizer.denormalize( y, m, sources_b[bidx,vidx], [lats, lons]) + normalizer, year_base = self.model.normalizer( fidx, vidx, lats_idx, lons_idx) + sources_b[bidx,vidx] = denormalize(sources_b[bidx,vidx], normalizer, dates, year_base = 2021) if is_predicted : idx = tokens_masked_idx_list[fidx][vidx][bidx] - grid = np.flip(np.array( np.meshgrid( self.sources_info[bidx][2], self.sources_info[bidx][1])), axis = 0) #flip to have lat on pos 0 and lon on pos 1 - + grid = np.flip(np.array( np.meshgrid( lons, lats)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + grid_idx = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + # recover time dimension since idx assumes the full space-time cube grid = torch.from_numpy( np.array( np.broadcast_to( grid, shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) - grid_lons_toked = tokenize( grid[0], token_size).flatten( 0, 2) - + grid_lons_toked = tokenize( grid[1], token_size).flatten( 0, 2) + idx_loc = idx - np.prod(num_tokens) * bidx #save only useful info for each bidx. shape e.g. [n_bidx, lat_token_size*lat_num_tokens] lats_mskd = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) lons_mskd = np.array([np.unique(t) for t in grid_lons_toked[ idx_loc ].numpy()]) - + #time: idx ranges from 0->863 12x6x12 t_idx = (idx_loc // (num_tokens[1]*num_tokens[2])) * token_size[0] #create range from t_idx-2 to t_idx t_idx = np.array([np.arange(t, t + token_size[0]) for t in t_idx]) - dates_mskd = self.sources_info[bidx][0][t_idx] - - for ii,(t,p,e,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], - lats_mskd, lons_mskd)) : - targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) - preds_mu_b[vidx][bidx][ii] = normalizer.denormalize( y, m, p, [la, lo]) - preds_ens_b[vidx][bidx][ii] = normalizer.denormalize( y, m, e, [la, lo]) + dates_mskd = dates[t_idx] + + for ii,(t,p,e,da,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], + dates_mskd, lats_mskd, lons_mskd)) : + normalizer_ii = normalizer + if len(normalizer.shape) > 2: #local normalization + lats_mskd_idx = np.where(np.isin(lats,la))[0] + lons_mskd_idx = np.where(np.isin(lons,lo))[0] + normalizer_ii = normalizer[:, :, lats_mskd_idx, lons_mskd_idx] + + targets_b[vidx][bidx][ii] = denormalize(t, normalizer_ii, da, year_base) + preds_mu_b[vidx][bidx][ii] = denormalize(p, normalizer_ii, da, year_base) + preds_ens_b[vidx][bidx][ii] = denormalize(e, normalizer_ii, da, year_base) coords_mskd_l += [[dates_mskd, 90.-lats_mskd, lons_mskd] ] diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index a67f46d..eccb75e 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -21,14 +21,15 @@ from datetime import datetime import time -from atmorep.datasets.normalizer_global import NormalizerGlobal -from atmorep.datasets.normalizer_local import NormalizerLocal +# from atmorep.datasets.normalizer_global import NormalizerGlobal +# from atmorep.datasets.normalizer_local import NormalizerLocal +from atmorep.datasets.normalizer import normalize from atmorep.utils.utils import tokenize class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### - def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_per_epoch, + def __init__( self, file_path, 
fields, years, batch_size, pre_batch, n_size, num_samples_per_epoch, with_shuffle = False, time_sampling = 1, with_source_idxs = False, fields_targets = None, pre_batch_targets = None ) : ''' @@ -47,16 +48,16 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.pre_batch = pre_batch - fname_source = '/p/scratch/atmo-rep/data/era5_1deg/era5_res0025_2021_final.zarr' - self.ds = zarr.open( fname_source) + #fname_source = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_final.zarr' + self.ds = zarr.open( file_path) self.ds_global = self.ds.attrs['is_global'] - self.ds_len = self.ds['data'].shape[0] self.lats = np.array( self.ds['lats']) self.lons = np.array( self.ds['lons']) sh = self.ds['data'].shape st = self.ds['time'].shape + self.ds_len = st[0] #self.ds['data'].shape[2] print( f'self.ds[\'data\'] : {sh} :: {st}') print( f'self.lats : {self.lats.shape}', flush=True) print( f'self.lons : {self.lons.shape}', flush=True) @@ -65,8 +66,8 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe self.time_sampling = time_sampling self.range_lat = np.array( self.lats[ [0,-1] ]) self.range_lon = np.array( self.lons[ [0,-1] ]) - self.res = np.array(self.ds.attrs['resol']) - + self.res = np.array(self.ds.attrs['res']) + self.year_base = self.ds['time'][0].astype(datetime).year # ensure neighborhood does not exceed domain (either at pole or for finite domains) self.range_lat += np.array([n_size[1] / 2., -n_size[1] / 2.]) # lon: no change for periodic case @@ -75,13 +76,19 @@ def __init__( self, fields, years, batch_size, pre_batch, n_size, num_samples_pe # data normalizers self.normalizers = [] - for _, field_info in enumerate(fields) : - self.normalizers.append( []) + for ifield, field_info in enumerate(fields) : corr_type = 'global' if len(field_info) <= 6 else field_info[6] - ner = NormalizerGlobal if corr_type == 'global' else NormalizerLocal + nf_name = 'global_norm' if corr_type == 'global' else 'norm' + self.normalizers.append( [] ) for vl in field_info[2]: - data_type = 'data_sfc' if vl == 0 else 'data' #surface field - self.normalizers[-1] += [ ner( field_info, vl)] #, self.range_lat, self.range_lon, self.res, self.ds_global ) ] + if vl == 0: + field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) + self.normalizers[ifield] += [self.ds[f'normalization/{nf_name}_sfc'].oindex[ :, :, field_idx]] + else: + vl_idx = self.ds.attrs['levels'].index(vl) + field_idx = self.ds.attrs['fields'].index( field_info[0]) + self.normalizers[ifield] += [self.ds[f'normalization/{nf_name}'].oindex[ :, :, field_idx, vl_idx]] + # extract indices for selected years self.times = pd.DatetimeIndex( self.ds['time']) self.idxs_years = np.arange( self.ds_len) @@ -153,23 +160,28 @@ def __iter__(self): for ifield, field_info in enumerate(self.fields): source_lvl, tok_info_lvl = [], [] - tok_size = field_info[4] + tok_size = field_info[4] + corr_type = 'global' if len(field_info) <= 6 else field_info[6] + for ilevel, vl in enumerate(field_info[2]): if vl == 0 : #surface level field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) - data_t = self.ds['data_sfc'].oindex[ idxs_t, field_idx] + data_t = self.ds['data_sfc'].oindex[field_idx, idxs_t] else : field_idx = self.ds.attrs['fields'].index( field_info[0]) vl_idx = self.ds.attrs['levels'].index(vl) - data_t = self.ds['data'].oindex[ idxs_t, field_idx, vl_idx] - - nf = self.normalizers[ifield][ilevel].normalize + data_t = self.ds['data'].oindex[ field_idx, vl_idx, idxs_t] + source_data, tok_info 
= [], [] - # extract data, normalize and tokenize cdata = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) - cdata = nf( year, month, cdata, (lats[lat_ran], lons[lon_ran]) ) + + normalizer = self.normalizers[ifield][ilevel] + if corr_type != 'global': + normalizer = np.take( np.take( normalizer, lat_ran, -2), lon_ran, -1) + + cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) source_data = tokenize( torch.from_numpy( cdata), tok_size ) # token_infos uses center of the token: *last* datetime and center in space diff --git a/atmorep/datasets/normalizer.py b/atmorep/datasets/normalizer.py new file mode 100644 index 0000000..0487af2 --- /dev/null +++ b/atmorep/datasets/normalizer.py @@ -0,0 +1,82 @@ +#################################################################################################### +# +# Copyright (C) 2022 +# +#################################################################################################### +# +# project : atmorep +# +# author : atmorep collaboration +# +# description : +# +# license : +# +#################################################################################################### + +import code +import numpy as np +import xarray as xr +import atmorep.config.config as config + +###################################################### +# Normalize # +###################################################### + +def normalize( data, norm, dates, year_base = 1979) : + corr_data = np.array([norm[12*(dt.year-year_base) + dt.month-1] for dt in dates]) + mean, var = corr_data[:, 0], corr_data[:, 1] + if (var == 0.).all() : + print( f'Warning: var == 0') + assert False + if len(norm.shape) > 2 : #global norm + return normalize_local(data, mean, var) + else: + return normalize_global( data, mean, var) + +###################################################### +def normalize_local( data, mean, var) : + data = (data - mean) / var + return data + +###################################################### +def normalize_global( data, mean, var) : + for i in range( data.shape[0]) : + data[i] = (data[i] - mean[i]) / var[i] + return data + + +###################################################### +# Denormalize # +###################################################### +def denormalize(data, norm, dates, year_base = 1979) : + corr_data = np.array([norm[12*(dt.year-year_base) + dt.month-1] for dt in dates]) + mean, var = corr_data[:, 0], corr_data[:, 1] + if len(norm.shape) > 2 : + return denormalize_local(data, mean, var) + else: + return denormalize_global(data, mean, var) + +###################################################### + +def denormalize_local(data, mean, var) : + if len(data.shape) > 3: #ensemble + for i in range( data.shape[0]) : + data[i] = (data[i] * var) + mean + else: + data = (data * var) + mean + return data + +###################################################### + +def denormalize_global(data, mean, var) : + if len(data.shape) > 3: #ensemble + data = data.swapaxes(0,1) + for i in range( data.shape[0]) : + data[i] = ((data[i] * var[i]) + mean[i]) + data = data.swapaxes(0,1) + else: + for i in range( data.shape[0]) : + data[i] = (data[i] * var[i]) + mean[i] + + return data \ No newline at end of file From ecc6621ff69cb4292ca0b76ceca3c0bb7b49e32a Mon Sep 17 00:00:00 2001 From: iluise Date: Tue, 16 Apr 2024 17:41:07 +0200 Subject: [PATCH 28/66] validated status --- atmorep/core/evaluate.py | 16 ++++++++-------- atmorep/core/trainer.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git 
a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index bf596b8..3decd4f 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -19,7 +19,7 @@ if __name__ == '__main__': # models for individual fields - model_id = '4nvwbetz' # vorticity + #model_id = '4nvwbetz' # vorticity #model_id = 'oxpycr7w' # divergence # model_id = '1565pb1f' # specific_humidity #model_id = '3kdutwqb' # total precip @@ -30,7 +30,7 @@ #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence - # model_id = '1jh2qvrx' # multiformer, velocity + model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity # model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting @@ -49,14 +49,14 @@ #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } #temporal interpolation - mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'attention' : False } + #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'attention' : False } # BERT forecast with patching to obtain global forecast - # mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], - # 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], - # 'token_overlap' : [0, 0], - # 'forecast_num_tokens' : 2, - # 'attention' : False } + mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], + 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], + 'token_overlap' : [0, 0], + 'forecast_num_tokens' : 2, + 'attention' : False } now = time.time() Evaluator.evaluate( mode, model_id, options) print("time", time.time() - now) \ No newline at end of file diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 80e0f42..a925666 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -691,7 +691,7 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : for vidx, vl in enumerate(field_info[2]) : normalizer, year_base = self.model.normalizer( self.fields_prediction_idx[fidx], vidx, lats_idx, lons_idx) pred[bidx,vidx] = denormalize( pred[bidx,vidx], normalizer, dates_t, year_base) - ensemble[bidx,:,vidx] = denormalize(ensemble[bidx,:,vidx].swapaxes(0,1), normalizer, dates_t, year_base).swapaxes(0,1) + ensemble[bidx,:,vidx] = denormalize(ensemble[bidx,:,vidx], normalizer, dates_t, year_base) # append preds_out.append( [fn[0], pred]) From e56952f12ea5d8c89d46b00b0a2a4a844fd35d03 Mon Sep 17 00:00:00 2001 From: iluise Date: Fri, 19 Apr 2024 15:54:08 +0200 Subject: [PATCH 29/66] temporal interpolation for multiple time steps --- atmorep/core/evaluate.py | 7 ++++--- atmorep/core/trainer.py | 1 - atmorep/datasets/multifield_data_sampler.py | 4 +--- atmorep/training/bert.py | 11 ++++++++--- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 3decd4f..3528b19 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -19,7 +19,7 @@ if __name__ == '__main__': # models for individual fields - #model_id = '4nvwbetz' # vorticity + model_id = '4nvwbetz' # vorticity #model_id = 'oxpycr7w' # divergence # model_id = '1565pb1f' # specific_humidity #model_id = '3kdutwqb' # total precip @@ -30,7 +30,7 @@ #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or 
voritcity+divergence - model_id = '1jh2qvrx' # multiformer, velocity + #model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity # model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting @@ -49,7 +49,8 @@ #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } #temporal interpolation - #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'attention' : False } + #idx_time_mask: list of relative time positions of the masked tokens within the cube wrt num_tokens[0] + #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'idx_time_mask': [5,6,7], 'attention' : False } # BERT forecast with patching to obtain global forecast mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index a925666..831fa27 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -656,7 +656,6 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : lats_idx = self.sources_idxs[bidx][1] lons_idx = self.sources_idxs[bidx][2] - #TODO: add support for multiple months for vidx, _ in enumerate(field_info[2]) : normalizer, year_base = self.model.normalizer( fidx, vidx, lats_idx, lons_idx) source[bidx,vidx] = denormalize( source[bidx,vidx], normalizer, dates, year_base) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index eccb75e..3065808 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -47,8 +47,6 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, num self.with_shuffle = with_shuffle self.pre_batch = pre_batch - - #fname_source = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_final.zarr' self.ds = zarr.open( file_path) self.ds_global = self.ds.attrs['is_global'] @@ -57,7 +55,7 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, num sh = self.ds['data'].shape st = self.ds['time'].shape - self.ds_len = st[0] #self.ds['data'].shape[2] + self.ds_len = st[0] print( f'self.ds[\'data\'] : {sh} :: {st}') print( f'self.lats : {self.lats.shape}', flush=True) print( f'self.lons : {self.lons.shape}', flush=True) diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 275af5d..a8c5bc3 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -194,9 +194,14 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : def prepare_batch_BERT_temporal_field( cf, ifield, source, token_info, rng) : num_tokens = source.shape[-6:-3] - num_tokens_space = num_tokens[1] * num_tokens[2] - idx_time_mask = int( np.floor(num_tokens[0] / 2.)) # TODO: masking of multiple time steps - idxs = idx_time_mask * num_tokens_space + torch.arange(num_tokens_space) + num_tokens_space = num_tokens[1] * num_tokens[2] + + #backward compatibility: mask only middle token + if not hasattr( cf, 'idx_time_mask'): + idx_time_mask = int( np.floor(num_tokens[0] / 2.)) + idxs = idx_time_mask * num_tokens_space + torch.arange(num_tokens_space) + else: #list of idx_time_mask + idxs = torch.concat([i*num_tokens_space + torch.arange(num_tokens_space) for i in cf.idx_time_mask]) # collapse token dimensions source_shape0 = source.shape From afbdf047048f1b7a192fc19f6fa114f015967ac4 Mon Sep 17 
00:00:00 2001 From: Christian Lessig Date: Mon, 22 Apr 2024 07:23:42 +0200 Subject: [PATCH 30/66] Adding sample/sec to console output. --- atmorep/core/trainer.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 2192e75..562067a 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -17,6 +17,7 @@ import torch import torchinfo import numpy as np +import time import code from pathlib import Path @@ -176,13 +177,10 @@ def run( self, epoch = -1) : test_loss = np.array( [1.0]) epoch += 1 - batch_size = cf.batch_size_start - cf.batch_size_delta - if cf.profile : lr = learn_rates[epoch] for g in self.optimizer.param_groups: g['lr'] = lr - self.profile() # training loop @@ -195,10 +193,8 @@ def run( self, epoch = -1) : for g in self.optimizer.param_groups: g['lr'] = lr - batch_size = min( cf.batch_size_max, batch_size + cf.batch_size_delta) - tstr = datetime.datetime.now().strftime("%H:%M:%S") - print( '{} : {} :: batch_size = {}, lr = {}'.format( epoch, tstr, batch_size, lr) ) + print( '{} : {} :: batch_size = {}, lr = {}'.format( epoch, tstr, cf.batch_size, lr) ) self.train( epoch) @@ -238,6 +234,7 @@ def train( self, epoch): ctr = 0 self.optimizer.zero_grad() + time_start = time.time() for batch_idx in range( model.len( NetMode.train)) : @@ -261,7 +258,7 @@ def train( self, epoch): # logging - if int((batch_idx * cf.batch_size_max) / 4) > ctr : + if int((batch_idx * cf.batch_size) / 4) > ctr : # wandb logging if cf.with_wandb and (0 == cf.par_rank) : @@ -279,14 +276,16 @@ def train( self, epoch): wandb.log( loss_dict ) # console output - print('train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:1.5f} : {:1.5f} :: {:1.5f}'.format( - epoch, batch_idx, model.len( NetMode.train), - 100. * batch_idx/model.len(NetMode.train), - torch.mean( torch.tensor( grad_loss_total)), torch.mean(torch.tensor(mse_loss_total)), - torch.mean( preds[0][1]) ), flush=True) + samples_sec = cf.batch_size / (time.time() - time_start) + str = 'epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:1.5f} : {:1.5f} :: {:1.5f} ({:2.2f} s/sec)' + print( str.format( epoch, batch_idx, model.len( NetMode.train), + 100. * batch_idx/model.len(NetMode.train), + torch.mean( torch.tensor( grad_loss_total)), + torch.mean(torch.tensor(mse_loss_total)), + torch.mean( preds[0][1]), samples_sec ), flush=True) # save model (use -2 as epoch to indicate latest, stored without epoch specification) - # self.save( -2) + self.save( -2) # reset loss_total = [[] for i in range(len(cf.losses)) ] @@ -295,6 +294,7 @@ def train( self, epoch): std_dev_total = [[] for i in range(len(self.fields_prediction_idx)) ] ctr += 1 + time_start = time.time() # save gradients if cf.save_grads and cf.with_wandb and (0 == cf.par_rank) : @@ -353,8 +353,8 @@ def profile( self): loss, mse_loss, losses = self.loss( preds, batch_idx) self.optimizer.zero_grad() - # loss.backward() - # self.optimizer.step() + loss.backward() + self.optimizer.step() prof.step() From 77a2aac64629d106dfc81a35228d079df9d2de05 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 22 Apr 2024 09:15:49 +0200 Subject: [PATCH 31/66] Removed variable length batch size. 
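
Training no longer ramps the batch size up over the epochs (starting at batch_size_start and adding batch_size_delta until batch_size_max; the ramp-up logic itself was already removed from the trainer in the previous patch). A single cf.batch_size is now used for training and cf.batch_size_validation for the validation loader. A rough sketch of the difference in behaviour, loop bodies elided (illustrative, not the exact trainer code):

    # old schedule: the batch size grew every epoch until it saturated
    batch_size = cf.batch_size_start - cf.batch_size_delta
    for epoch in range(cf.num_epochs):
      batch_size = min(cf.batch_size_max, batch_size + cf.batch_size_delta)
      ...

    # new behaviour: one fixed batch size per run
    for epoch in range(cf.num_epochs):
      batch_size = cf.batch_size   # cf.batch_size_validation for the validation loader
      ...
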
--- atmorep/core/atmorep_model.py | 12 ++++++------ atmorep/core/train.py | 34 ++++++++++++++++++---------------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 13a7248..e4cae9e 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -137,12 +137,12 @@ def normalizer( self, field, vl_idx, lats_idx, lons_idx ) : def mode( self, mode : NetMode) : if mode == NetMode.train : - #self.data_loader_iter = iter(self.data_loader_train) - self.data_loader_iter = iter(self.dataset_train) + self.data_loader_iter = iter(self.data_loader_train) + # self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : - #self.data_loader_iter = iter(self.data_loader_test) - self.data_loader_iter = iter(self.dataset_test) + self.data_loader_iter = iter(self.data_loader_test) + # self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -187,14 +187,14 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non 'num_workers': cf.num_loader_workers, 'pin_memory': True} self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, - cf.batch_size_start, + cf.batch_size, pre_batch, cf.n_size, cf.num_samples_per_epoch, with_shuffle = (cf.BERT_strategy != 'global_forecast') ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_test, - cf.batch_size_start, + cf.batch_size_validation, pre_batch, cf.n_size, cf.num_samples_validate, with_shuffle = (cf.BERT_strategy != 'global_forecast'), with_source_idxs = True ) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 1a1ab49..7c0bc45 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -114,9 +114,10 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - # cf.fields = [ [ 'vorticity', [ 1, 2048, [ ], 0 ], - # [ 123 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + cf.fields = [ [ 'vorticity', [ 1, 2048, [ ], 0 ], + [ 123 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] # cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0], # [ 96, 105, 114, 123, 137 ], @@ -150,20 +151,22 @@ def train() : # cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0], # [ 114, 123, 137 ], # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], - cf.fields = [ - [ 'velocity_v', [ 1, 1024, [ ], 0 ], - [ 114, 123, 137 ], - [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ], - [ 'total_precip', [ 1, 1536, [ ], 3 ], - [ 0 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] + # cf.fields = [ + # [ 'velocity_v', [ 1, 1024, [ ], 0 ], + # [ 114, 123, 137 ], + # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ], + # [ 'total_precip', [ 1, 1536, [ ], 3 ], + # [ 0 ], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] #cf.fields_prediction = [ [cf.fields[0][0], 0.33], [cf.fields[1][0], 0.33], [cf.fields[2][0], 0.33]] # cf.fields = [ # [ 'total_precip', [ 1, 1536, [ ], 3 ], # [ 0 ], # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] - cf.fields_prediction = [ [cf.fields[0][0], 0.5],[cf.fields[1][0], 0.5] ] + # cf.fields_prediction = [ [cf.fields[0][0], 0.5],[cf.fields[1][0], 0.5] ] + cf.fields_targets = [] + cf.years_train = [2021] # list( range( 1980, 2018)) cf.years_test = [2021] #[2018] cf.month = None @@ -180,9 +183,8 @@ def train() : # random seeds cf.torch_seed = torch.initial_seed() # training params - cf.batch_size_test = 64 - cf.batch_size_start = 16 - cf.batch_size_max = 32 + cf.batch_size_validation = 64 + cf.batch_size = 32 cf.batch_size_delta = 8 cf.num_epochs = 128 @@ -198,12 +200,12 @@ def train() : cf.learnable_mask = False cf.with_qk_lnorm = True # encoder - cf.encoder_num_layers = 10 + cf.encoder_num_layers = 4 cf.encoder_num_heads = 16 cf.encoder_num_mlp_layers = 2 cf.encoder_att_type = 'dense' # decoder - cf.decoder_num_layers = 10 + cf.decoder_num_layers = 4 cf.decoder_num_heads = 16 cf.decoder_num_mlp_layers = 2 cf.decoder_self_att = False From ffe7e18b5efae06c240c5007f1f578b888af34fe Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 24 Apr 2024 10:59:00 +0200 Subject: [PATCH 32/66] Fixes for new/old ordering of fields. 
--- atmorep/datasets/multifield_data_sampler.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 0209d12..b708459 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -30,8 +30,8 @@ class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### - def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, num_samples_per_epoch, - with_shuffle = False, time_sampling = 1, with_source_idxs = False, + def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, + num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False, fields_targets = None, pre_batch_targets = None ) : ''' Data set for single dynamic field at an arbitrary number of vertical levels @@ -43,7 +43,7 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, num self.fields = fields self.batch_size = batch_size self.n_size = n_size - self.num_samples = num_samples_per_epoch + self.num_samples = num_samples self.with_source_idxs = with_source_idxs self.with_shuffle = with_shuffle @@ -96,7 +96,6 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, num idxs_years = np.logical_and( idxs_years, self.times.year == year) self.idxs_years = np.where( idxs_years)[0] logging.getLogger('atmorep').info( f'Dataset size for years {years}: {len(self.idxs_years)}.') - # self.idxs_years = np.arange( self.ds_len) ################################################### def shuffle( self) : @@ -161,7 +160,6 @@ def __iter__(self): source_idxs += [ (idxs_t, lat_ran, lon_ran) ] # extract data - year, month = self.times[ idxs_t[-1] ].year, self.times[ idxs_t[-1] ].month for ifield, field_info in enumerate(self.fields): source_lvl, tok_info_lvl = [], [] @@ -172,11 +170,11 @@ def __iter__(self): if vl == 0 : #surface level field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) - data_t = self.ds['data_sfc'].oindex[field_idx, idxs_t] + data_t = self.ds['data_sfc'].oindex[idxs_t, field_idx] else : field_idx = self.ds.attrs['fields'].index( field_info[0]) vl_idx = self.ds.attrs['levels'].index(vl) - data_t = self.ds['data'].oindex[ field_idx, vl_idx, idxs_t] + data_t = self.ds['data'].oindex[ idxs_t, field_idx, vl_idx] source_data, tok_info = [], [] # extract data, normalize and tokenize From 2c73df22d02406530dc727b97241eca5b4049bf8 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Tue, 21 May 2024 11:16:20 +0200 Subject: [PATCH 33/66] Implemented efficient fused flash-attention (i.e. flash attention with heads fused in a common projection matrix). Minor fixes around. 
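The sampler change above relies on the zarr arrays being laid out with time as the leading dimension, so that several time steps for one field and one vertical level can be pulled in a single orthogonal-indexing call. A small sketch of that access pattern; the array shape, chunking and index values are made up for illustration.

import zarr

# toy store: (time, field, level, lat, lon), mirroring the assumed layout of 'data'
data = zarr.zeros((24, 6, 5, 10, 10), chunks=(8, 6, 5, 10, 10), dtype='f4')

idxs_t = [3, 4, 5]        # consecutive time steps of one sample
field_idx, vl_idx = 2, 1  # position of the field and of the vertical level

# orthogonal indexing with time first, as in the '+' lines of the hunk above
data_t = data.oindex[idxs_t, field_idx, vl_idx]   # -> (len(idxs_t), lat, lon)
print(data_t.shape)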
--- atmorep/core/atmorep_model.py | 4 +- atmorep/core/train.py | 93 ++++-- atmorep/core/trainer.py | 13 +- atmorep/datasets/multifield_data_sampler.py | 16 +- atmorep/transformer/transformer_attention.py | 330 +++++++++---------- atmorep/transformer/transformer_encoder.py | 7 +- 6 files changed, 247 insertions(+), 216 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index e4cae9e..2e3ff5f 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -189,14 +189,14 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, cf.batch_size, pre_batch, cf.n_size, cf.num_samples_per_epoch, - with_shuffle = (cf.BERT_strategy != 'global_forecast') ) + with_shuffle = (cf.BERT_strategy != 'global_forecast') ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_test, cf.batch_size_validation, pre_batch, cf.n_size, cf.num_samples_validate, - with_shuffle = (cf.BERT_strategy != 'global_forecast'), + with_shuffle = (cf.BERT_strategy != 'global_forecast'), with_source_idxs = True ) self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params, sampler = None) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 7c0bc45..e5220d8 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -86,6 +86,7 @@ def train() : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) + # device = ['cuda'] with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -102,7 +103,6 @@ def train() : # general cf.comment = '' cf.file_format = 'grib' - cf.file_path = str(config.path_data) cf.level_type = 'ml' # format: list of fields where for each field the list is @@ -115,7 +115,8 @@ def train() : # ] cf.fields = [ [ 'vorticity', [ 1, 2048, [ ], 0 ], - [ 123 ], + [ 96, 105, 114, 123, 137 ], + # [ 137 ], [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] @@ -148,23 +149,49 @@ def train() : #cf.fields_prediction = [ [cf.fields[0][0], 1.0] ] - # cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0], - # [ 114, 123, 137 ], + # cf.fields = [ [ 'velocity_u', [ 1, 1024, ['velocity_v'], 0], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], + # [ 'velocity_v', [ 1, 1024, ['velocity_u'], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # cf.fields_prediction = [ [cf.fields[0][0], 0.5], [cf.fields[1][0], 0.5] ] + + + # cf.fields = [ [ 'velocity_u', [ 1, 1024, ['velocity_v'], 0], + # [ 96, 105, 114, 123, 137 ], # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], - # cf.fields = [ - # [ 'velocity_v', [ 1, 1024, [ ], 0 ], - # [ 114, 123, 137 ], + # [ 'velocity_v', [ 1, 1024, ['velocity_u'], 0 ], + # [ 96, 105, 114, 123, 137 ], # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ], - # [ 'total_precip', [ 1, 1536, [ ], 3 ], + # [ 'total_precip', [ 1, 1024, ['velocity_u', 'velocity_v'], 0 ], # [ 0 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] - #cf.fields_prediction = [ [cf.fields[0][0], 0.33], [cf.fields[1][0], 0.33], [cf.fields[2][0], 0.33]] - # cf.fields = [ - # [ 'total_precip', [ 1, 1536, [ ], 3 ], - # [ 0 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] - # cf.fields_prediction = [ [cf.fields[0][0], 0.5],[cf.fields[1][0], 0.5] ] + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] + # cf.fields_prediction = [ [cf.fields[0][0], 0.33], [cf.fields[1][0], 0.33], + # [cf.fields[2][0], 0.33]] + cf.fields = [ [ 'vorticity', [ 1, 2048, ['divergence', 'temperature'], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], + [ 'velocity_z', [ 1, 1536, ['vorticity', 'divergence'], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], + [ 'divergence', [ 1, 2048, ['vorticity', 'temperature'], 1 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], + [ 'specific_humidity', [ 1, 2048, ['vorticity', 'divergence', 'velocity_z'], 2 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], + [ 'temperature', [ 1, 1024, ['vorticity', 'divergence'], 3 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], + [ 'total_precip', [ 1, 1536, ['vorticity', 'divergence', 'specific_humidity'], 3 ], + [ 0 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] + cf.fields_prediction = [['vorticity', 0.25], ['velocity_z', 0.1], + ['divergence', 0.25], ['specific_humidity', 0.15], + ['temperature', 0.15], ['total_precip', 0.1] ] + cf.fields_targets = [] cf.years_train = [2021] # list( range( 1980, 2018)) @@ -183,8 +210,8 @@ def train() : # random seeds cf.torch_seed = torch.initial_seed() # training params - cf.batch_size_validation = 64 - cf.batch_size = 32 + cf.batch_size_validation = 1 #64 + cf.batch_size = 16 # 4 # 32 cf.batch_size_delta = 8 cf.num_epochs = 128 @@ -198,7 +225,7 @@ def train() : cf.coupling_num_heads_per_field = 1 cf.dropout_rate = 0.05 cf.learnable_mask = False - cf.with_qk_lnorm = True + cf.with_qk_lnorm = False # encoder cf.encoder_num_layers = 4 cf.encoder_num_heads = 16 @@ -241,7 +268,7 @@ def train() : cf.log_test_num_ranks = 0 cf.save_grads = False cf.profile = False - cf.test_initial = True + cf.test_initial = False cf.attention = False cf.rng_seed = None @@ -250,18 +277,28 @@ def train() : cf.with_wandb = True setup_wandb( cf.with_wandb, cf, par_rank, 'train', mode='offline') - if cf.with_wandb and 0 == 
cf.par_rank : - cf.write_json( wandb) - cf.print() - cf.with_mixed_precision = True - # cf.n_size = [36, 1*9*6, 1.*9*12] + cf.num_samples_per_epoch = 4096 + cf.num_samples_validate = 128 + cf.num_loader_workers = 8 + + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' + # # in steps x lat_degrees x lon_degrees + cf.n_size = [36, 1*9*6, 1.*9*12] + + # # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' + # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' + # # + # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' + # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' # in steps x lat_degrees x lon_degrees # cf.n_size = [36, 0.25*9*6, 0.25*9*12] - cf.n_size = [36, 0.25*9*6, 0.25*9*12] - cf.num_samples_per_epoch = 1024 - cf.num_samples_validate = 128 - cf.num_loader_workers = 1 #8 + + if cf.with_wandb and 0 == cf.par_rank : + cf.write_json( wandb) + cf.print() trainer = Trainer_BERT( cf, device).create() trainer.run() diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index b8c2a6a..feac614 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -348,14 +348,15 @@ def profile( self): batch_data = self.model.next() - batch_data = self.prepare_batch( batch_data) - preds, _ = self.model_ddp( batch_data) - - loss, mse_loss, losses = self.loss( preds, batch_idx) + with torch.autocast(device_type='cuda',dtype=torch.float16, enabled=cf.with_mixed_precision): + batch_data = self.prepare_batch( batch_data) + preds, _ = self.model_ddp( batch_data) + loss, mse_loss, losses = self.loss( preds, batch_idx) + self.grad_scaler.scale(loss).backward() + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() prof.step() diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index b708459..2021a0a 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -106,7 +106,7 @@ def shuffle( self) : rng_seed = int(time.time()) // (worker_info.id+1) + worker_info.id rng = np.random.default_rng( rng_seed) - self.idxs_perm_t = rng.permutation( self.idxs_years)[ : self.num_samples] + self.idxs_perm_t = rng.permutation( self.idxs_years)[ : self.num_samples // self.batch_size] lats = rng.random(self.num_samples) * (self.range_lat[1] - self.range_lat[0]) +self.range_lat[0] lons = rng.random(self.num_samples) * (self.range_lon[1] - self.range_lon[0]) +self.range_lon[0] @@ -136,9 +136,15 @@ def __iter__(self): sources, token_infos = [[] for _ in self.fields], [[] for _ in self.fields] sources_infos, source_idxs = [], [] + i_bidx = self.idxs_perm_t[bidx] + idxs_t = list(np.arange( bidx - n_size[0]*ts, bidx, ts, dtype=np.int64)) + data_tt_sfc = self.ds['data_sfc'].oindex[idxs_t] + data_tt = self.ds['data'].oindex[idxs_t] + for sidx in range(self.batch_size) : - i_bidx = self.idxs_perm_t[bidx] - idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) + + # i_bidx = self.idxs_perm_t[bidx] + # idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) idx = self.idxs_perm[bidx*self.batch_size+sidx] # slight asymetry with 
offset by res/2 is required to match desired token count @@ -170,11 +176,11 @@ def __iter__(self): if vl == 0 : #surface level field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) - data_t = self.ds['data_sfc'].oindex[idxs_t, field_idx] + data_t = data_tt_sfc[ :, field_idx ] else : field_idx = self.ds.attrs['fields'].index( field_info[0]) vl_idx = self.ds.attrs['levels'].index(vl) - data_t = self.ds['data'].oindex[ idxs_t, field_idx, vl_idx] + data_t = data_tt[ :, field_idx, vl_idx ] source_data, tok_info = [], [] # extract data, normalize and tokenize diff --git a/atmorep/transformer/transformer_attention.py b/atmorep/transformer/transformer_attention.py index 2143365..59f7b1b 100644 --- a/atmorep/transformer/transformer_attention.py +++ b/atmorep/transformer/transformer_attention.py @@ -31,130 +31,69 @@ class CouplingAttentionMode( Enum) : kv_coupling = 2 -#################################################################################################### - -class AttentionHead(torch.nn.Module): - - def __init__(self, proj_dims, proj_dims_qs = -1, with_qk_lnorm = False, with_attention=False) : - '''Attention head''' - - super(AttentionHead, self).__init__() - - if proj_dims_qs == -1 : - proj_dims_qs = proj_dims[0] - - self.proj_qs = torch.nn.Linear( proj_dims_qs, proj_dims[1], bias = False) - self.proj_ks = torch.nn.Linear( proj_dims[0], proj_dims[1], bias = False) - self.proj_vs = torch.nn.Linear( proj_dims[0], proj_dims[1], bias = False) - - self.softmax = torch.nn.Softmax(dim=-1) - - if with_qk_lnorm : - self.lnorm_qs = torch.nn.LayerNorm( proj_dims[1], elementwise_affine=False) - self.lnorm_ks = torch.nn.LayerNorm( proj_dims[1], elementwise_affine=False) - else : - self.lnorm_qs = torch.nn.Identity() - self.lnorm_ks = torch.nn.Identity() - - self.forward = self.forward_attention if with_attention else self.forward_evaluate - - def attention( self, qs, ks) : - '''Compute attention''' - return torch.matmul( qs, torch.transpose( ks, -2, -1)) - - def forward_evaluate( self, xs_q, xs_k_v = None) : - '''Evaluate attention head''' - - xs_k_v = xs_q if None == xs_k_v else xs_k_v - - out_shape = xs_q.shape - qs = self.lnorm_qs( self.proj_qs( torch.flatten( xs_q, 1, -2) )) - ks = self.lnorm_ks( self.proj_ks( torch.flatten( xs_k_v, 1, -2) )) - vs = self.proj_vs( torch.flatten( xs_k_v, 1, -2) ) - # normalization increases interpretability since otherwise the scaling of the values - # interacts with the attention values - # torch.nn.functional.normalize( vs, dim=-1) - - scaling = 1. / torch.sqrt( torch.tensor(qs.shape[2])) - vsp = torch.matmul( self.softmax( scaling * self.attention( qs, ks)), vs) - return (vsp.reshape( [-1] + list(out_shape[1:-1]) + [vsp.shape[-1]]), None) - - def forward_attention( self, xs_q, xs_k_v = None) : - '''Evaluate attention head and also return attention''' - - xs_k_v = xs_q if None == xs_k_v else xs_k_v - - out_shape = xs_q.shape - kv_shape = xs_k_v.shape - qs = self.lnorm_qs( self.proj_qs( torch.flatten( xs_q, 1, -2) )) - ks = self.lnorm_ks( self.proj_ks( torch.flatten( xs_k_v, 1, -2) )) - vs = self.proj_vs( torch.flatten( xs_k_v, 1, -2) ) - # normalization increases interpretability since otherwise the scaling of the values - # interacts with the attention values - # torch.nn.functional.normalize( vs, dim=-1) - - scaling = 1. 
/ torch.sqrt( torch.tensor(qs.shape[2])) - att = self.attention( qs, ks) - vsp = torch.matmul( self.softmax( scaling * att), vs) - return ( vsp.reshape( [-1] + list(out_shape[1:-1]) + [vsp.shape[-1]]), - att.reshape( [-1] + list(out_shape[1:-1]) + list(kv_shape[1:-1])).detach().cpu() ) - -#################################################################################################### - class MultiSelfAttentionHead(torch.nn.Module): - def __init__(self, dim_embed, num_heads, dropout_rate = 0., att_type = 'dense', - with_qk_lnorm = False, grad_checkpointing = False, with_attention = False ) : + ######################################### + def __init__(self, dim_embed, num_heads, dropout_rate=0., with_qk_lnorm=True, with_flash=True) : super(MultiSelfAttentionHead, self).__init__() + self.num_heads = num_heads + self.with_flash = with_flash + assert 0 == dim_embed % num_heads self.dim_head_proj = int(dim_embed / num_heads) self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) + self.proj_heads = torch.nn.Linear( dim_embed, num_heads*3*self.dim_head_proj, bias = False) + self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) + self.dropout = torch.nn.Dropout( p=dropout_rate) if dropout_rate > 0. else torch.nn.Identity() - self.heads = torch.nn.ModuleList() - if 'dense' == att_type : - for n in range( num_heads) : - self.heads.append( AttentionHead( [dim_embed, self.dim_head_proj], - with_qk_lnorm= with_qk_lnorm, with_attention=with_attention)) - elif 'axial' in att_type : - self.heads.append( AxialAttention( dim = dim_embed, dim_index = -1, heads = num_heads, - num_dimensions = 3) ) + lnorm = torch.nn.LayerNorm if with_qk_lnorm else torch.nn.Identity + self.ln_q = lnorm( self.dim_head_proj, elementwise_affine=False) + self.ln_k = lnorm( self.dim_head_proj, elementwise_affine=False) + + # with_flash = False + if with_flash : + self.att = torch.nn.functional.scaled_dot_product_attention else : - assert False, 'Unsuppored attention type.' - - # proj_out is done is axial attention head so do not repeat it - self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) \ - if att_type == 'dense' else torch.nn.Identity() - self.dropout = torch.nn.Dropout( p=dropout_rate) - - self.checkpoint = identity - if grad_checkpointing : - self.checkpoint = checkpoint_wrapper + self.att = self.attention + self.softmax = torch.nn.Softmax(dim=-1) - def forward( self, x, y = None) : + ######################################### + def forward( self, x) : + split, tr = torch.tensor_split, torch.transpose + x_in = x x = self.lnorm( x) - outs, atts = [], [] - for head in self.heads : - y, att = self.checkpoint( head, x) - outs.append( y) - atts.append( y) - outs = torch.cat( outs, -1) - - outs = self.dropout( self.checkpoint( self.proj_out, outs) ) + # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention + s = [ *x.shape[:-1], self.num_heads, -1] + qs, ks, vs = split( self.proj_heads( x).reshape(s).transpose( 2, 1), 3, dim=-1) + qs, ks = self.ln_q( qs), self.ln_k( ks) + + # correct ordering of tensors with seq dimension second but last is critical + with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.FLASH_ATTENTION) : + outs = self.att( qs, ks, vs).transpose( 2, 1) + + return x_in + self.dropout( self.proj_out( outs.flatten( -2, -1)) ) - return x_in + outs, atts + ######################################### + def attention( self, q, k, v) : + scaling = 1. 
/ torch.sqrt( torch.tensor(q.shape[-1])) + return torch.matmul( self.softmax( scaling * self.score( q, k)), v) + + ######################################### + def score( self, q, k) : + return torch.matmul( q, torch.transpose( k, -2, -1)) #################################################################################################### class MultiCrossAttentionHead(torch.nn.Module): def __init__(self, dim_embed, num_heads, num_heads_other, dropout_rate = 0., with_qk_lnorm =False, - grad_checkpointing = False, with_attention=False): + grad_checkpointing = False, with_attention=False, with_flash=True): super(MultiCrossAttentionHead, self).__init__() self.num_heads = num_heads @@ -169,22 +108,24 @@ def __init__(self, dim_embed, num_heads, num_heads_other, dropout_rate = 0., wit else : self.lnorm_other = torch.nn.Identity() - # self attention heads - self.heads = torch.nn.ModuleList() - for n in range( num_heads) : - self.heads.append( AttentionHead( [dim_embed, self.dim_head_proj], - with_qk_lnorm = with_qk_lnorm, with_attention=with_attention)) + self.proj_heads = torch.nn.Linear( dim_embed, num_heads*3*self.dim_head_proj, bias = False) + + self.proj_heads_o_q = torch.nn.Linear(dim_embed, num_heads_other*self.dim_head_proj, bias=False) + self.proj_heads_o_kv= torch.nn.Linear(dim_embed,num_heads_other*2*self.dim_head_proj,bias=False) - # cross attention heads - self.heads_other = torch.nn.ModuleList() - for n in range( num_heads_other) : - self.heads_other.append( AttentionHead( [dim_embed, self.dim_head_proj], - with_qk_lnorm = with_qk_lnorm, with_attention=with_attention)) + self.ln_q = torch.nn.LayerNorm( self.dim_head_proj) + self.ln_k = torch.nn.LayerNorm( self.dim_head_proj) # proj_out is done is axial attention head so do not repeat it self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) self.dropout = torch.nn.Dropout( p=dropout_rate) + if with_flash : + self.att = torch.nn.functional.scaled_dot_product_attention + else : + self.att = self.attention + self.softmax = torch.nn.Softmax(dim=-1) + self.checkpoint = identity if grad_checkpointing : self.checkpoint = checkpoint_wrapper @@ -192,27 +133,32 @@ def __init__(self, dim_embed, num_heads, num_heads_other, dropout_rate = 0., wit def forward( self, x, x_other) : x_in = x - x = self.lnorm( x) - x_other = self.lnorm_other( x_other) - - # output tensor where output of heads is linearly concatenated - outs, atts = [], [] - - # self attention - for head in self.heads : - y, att = self.checkpoint( head, x) - outs.append( y) - atts.append( att) + x, x_other = self.lnorm( x), self.lnorm_other( x_other) - # cross attention - for head in self.heads_other : - y, att = self.checkpoint( head, x, x_other) - outs.append( y) - atts.append( att) + # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention + x = x.flatten( 1, -2) + s = [ *x.shape[:-1], self.num_heads, -1] + qs, ks, vs = torch.tensor_split( self.proj_heads( x).reshape(s).transpose( 2, 1), 3, dim=-1) + # qs, ks = self.ln_q( qs), self.ln_k( ks) + + s = [ *x.shape[:-1], self.num_heads_other, -1] + qs_o = self.proj_heads_o_q( x).reshape(s).transpose( 2, 1) + x_o = x_other.flatten( 1, -2) + s = [ *x_o.shape[:-1], self.num_heads_other, -1] + ks_o, vs_o = torch.tensor_split( self.proj_heads_o_kv(x_o).reshape(s).transpose( 2, 1),2,dim=-1) + # qs_o, ks_o = self.ln_q( qs_o), self.ln_k( ks_o) + + # correct ordering of tensors with seq dimension second but last is critical + with torch.nn.attention.sdpa_kernel( 
torch.nn.attention.SDPBackend.FLASH_ATTENTION) : + s = list(x_in.shape) + s[-1] = -1 + outs_self = self.att( qs, ks, vs).transpose( 2, 1).flatten( -2, -1).reshape(s) + outs_other = self.att( qs_o, ks_o, vs_o).transpose( 2, 1).flatten( -2, -1).reshape(s) + outs = torch.cat( [outs_self, outs_other], -1) - outs = torch.cat( outs, -1) outs = self.dropout( self.checkpoint( self.proj_out, outs) ) + atts = [] return x_in + outs, atts #################################################################################################### @@ -220,44 +166,57 @@ def forward( self, x, x_other) : class MultiInterAttentionHead(torch.nn.Module): ##################################### - def __init__( self, num_heads_self, num_heads_coupling, dims_embed, with_lnorm = True, - dropout_rate = 0., with_qk_lnorm = False, grad_checkpointing = False, - with_attention=False) : + def __init__( self, num_heads_self, num_fields_other, num_heads_coupling_per_field, dims_embed, + with_lnorm = True, dropout_rate = 0., with_qk_lnorm = False, + grad_checkpointing = False, with_attention=False, with_flash=True) : '''Multi-head attention with multiple interacting fields coupled through attention.''' super(MultiInterAttentionHead, self).__init__() + self.num_heads_self = num_heads_self + self.num_heads_coupling_per_field = num_heads_coupling_per_field self.num_fields = len(dims_embed) - # self.coupling_mode = coupling_mode - # assert 0 == (dims_embed[0] % (num_heads_self + num_heads_coupling)) - # self.dim_head_proj = int(dims_embed[0] / (num_heads_self + num_heads_coupling)) self.dim_head_proj = int(dims_embed[0] / num_heads_self) # layer norms for all fields self.lnorms = torch.nn.ModuleList() + ln = torch.nn.LayerNorm if with_lnorm else torch.nn.Identity for ifield in range( self.num_fields) : - if with_lnorm : - self.lnorms.append( torch.nn.LayerNorm( dims_embed[ifield], elementwise_affine=False)) - else : - self.lnorms.append( torch.nn.Identity()) - - # self attention heads - self.heads_self = torch.nn.ModuleList() - for n in range( num_heads_self) : - self.heads_self.append( AttentionHead( [dims_embed[0], self.dim_head_proj], - dims_embed[0], with_qk_lnorm, with_attention=with_attention )) + self.lnorms.append( ln( dims_embed[ifield], elementwise_affine=False)) + + # self-attention + + nnc = num_fields_other * num_heads_coupling_per_field + self.proj_out = torch.nn.Linear( self.dim_head_proj * (num_heads_self + nnc), + dims_embed[0], bias = False) + self.dropout = torch.nn.Dropout( p=dropout_rate) if dropout_rate > 0. 
else torch.nn.Identity() + + nhs = num_heads_self + self.proj_heads = torch.nn.Linear( dims_embed[0], nhs*3*self.dim_head_proj, bias = False) - # coupling attention heads - self.heads_coupling = torch.nn.ModuleList() - for ifield in range( num_heads_coupling) : - arg1 = [dims_embed[ifield+1], self.dim_head_proj] - self.heads_coupling.append( AttentionHead( arg1, dims_embed[0], with_qk_lnorm, - with_attention=with_attention )) - - # self.proj_out = torch.nn.Linear( dims_embed[0], dims_embed[0], bias = False) - self.proj_out = torch.nn.Linear( self.dim_head_proj * (num_heads_self + num_heads_coupling), dims_embed[0], bias = False) - self.dropout = torch.nn.Dropout( p=dropout_rate) + # cross-attention + + nhc_dim = num_heads_coupling_per_field * self.dim_head_proj + self.proj_heads_other = torch.nn.ModuleList() + # queries from primary source/target field + self.proj_heads_other.append( torch.nn.Linear( dims_embed[0], nhc_dim*num_fields_other, + bias=False)) + # keys, values for other fields + for i in range(num_fields_other) : + self.proj_heads_other.append( torch.nn.Linear( dims_embed[i+1], 2*nhc_dim, bias=False)) + + ln = torch.nn.LayerNorm if with_qk_lnorm else torch.nn.Identity + self.ln_qk = (ln( self.dim_head_proj, elementwise_affine=False), + ln( self.dim_head_proj, elementwise_affine=False)) + nfo = num_fields_other + self.ln_k_other = [ln(self.dim_head_proj,elementwise_affine=False) for _ in range(nfo)] + + if with_flash : + self.att = torch.nn.functional.scaled_dot_product_attention + else : + self.att = self.attention + self.softmax = torch.nn.Softmax(dim=-1) self.checkpoint = identity if grad_checkpointing : @@ -266,32 +225,59 @@ def __init__( self, num_heads_self, num_heads_coupling, dims_embed, with_lnorm = ##################################### def forward( self, *args) : '''Evaluate block''' - - x_in = args[0] + + x_in, atts = args[0], [] # layer norm for each field fields_lnormed = [] for ifield, field in enumerate( args) : fields_lnormed.append( self.lnorms[ifield](field) ) + + # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention + # collapse three space and time dimensions for dense space-time attention + field_proj = self.proj_heads( fields_lnormed[0].flatten(1,-2)) + s = [ *field_proj.shape[:-1], self.num_heads_self, -1 ] + qs, ks, vs = torch.tensor_split( field_proj.reshape(s).transpose(-3,-2), 3, dim=-1) + qs, ks = self.ln_qk[0]( qs), self.ln_qk[1]( ks) + + if len(fields_lnormed) > 1 : + + field_proj = self.proj_heads_other[0]( fields_lnormed[0].flatten(1,-2)) + s = [ *field_proj.shape[:-1], len(fields_lnormed)-1, self.num_heads_coupling_per_field, -1 ] + qs_other = field_proj.reshape(s).permute( [-3, 0, -2, 1, -1]) - # output tensor where output of heads is linearly concatenated - outs, atts = [], [] - - # self attention - for head in self.heads_self : - y, att = self.checkpoint( head, fields_lnormed[0], fields_lnormed[0]) - outs.append( y) - atts.append( att) + ofields_projs = [] + for i,f in enumerate(fields_lnormed[1:]) : + f_proj = self.proj_heads_other[i+1](f.flatten(1,-2)) + s = [ *f_proj.shape[:-1], self.num_heads_coupling_per_field, -1 ] + ks_o, vs_o = torch.tensor_split( f_proj.reshape(s).transpose(-3,-2), 2, dim=-1) + ofields_projs += [ (self.ln_k_other[i]( ks_o), vs_o) ] + + # correct ordering of tensors with seq dimension second but last is critical + with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.FLASH_ATTENTION) : + + # self-attention + s = list(fields_lnormed[0].shape) + outs = self.att( qs, 
ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s) - # inter attention - for ifield, head in enumerate(self.heads_coupling) : - y, att = self.checkpoint( head, fields_lnormed[0], fields_lnormed[ifield+1]) - outs.append( y) - atts.append( att) + # cross-attention + if len(fields_lnormed) > 1 : + s[-1] = -1 + outs_other = [self.att( q, k, v).transpose( -3, -2).flatten( -2, -1).reshape(s) + for (q,(k,v)) in zip(qs_other,ofields_projs)] + outs = torch.cat( [outs, *outs_other], -1) - outs = torch.cat( outs, -1) - outs = self.dropout( self.checkpoint( self.proj_out, outs) ) + # code.interact( local=locals()) + outs = self.dropout( self.proj_out( outs)) return x_in + outs, atts - + ######################################### + def attention( self, q, k, v) : + scaling = 1. / torch.sqrt( torch.tensor(q.shape[-1])) + return torch.matmul( self.softmax( scaling * self.score( q, k)), v) + + ######################################### + def score( self, q, k) : + # code.interact( local=locals()) + return torch.matmul( q, torch.transpose( k, -2, -1)) diff --git a/atmorep/transformer/transformer_encoder.py b/atmorep/transformer/transformer_encoder.py index 36581f9..054615e 100644 --- a/atmorep/transformer/transformer_encoder.py +++ b/atmorep/transformer/transformer_encoder.py @@ -57,7 +57,7 @@ def create( self) : self.mlps = torch.nn.ModuleList() for il in range( cf.encoder_num_layers) : - nhc = cf.coupling_num_heads_per_field * len( field_info[1][2]) + nhc = cf.coupling_num_heads_per_field # * len( field_info[1][2]) # nhs = cf.encoder_num_heads - nhc nhs = cf.encoder_num_heads # number of tokens @@ -78,8 +78,9 @@ def create( self) : # attention heads if 'dense' == cf.encoder_att_type : - head = MultiInterAttentionHead( nhs, nhc, dims_embed, with_ln, dor, cf.with_qk_lnorm, - cf.grad_checkpointing, with_attention=cf.attention ) + head = MultiInterAttentionHead( nhs, len(field_info[1][2]), nhc, dims_embed, with_ln, dor, + cf.with_qk_lnorm, cf.grad_checkpointing, + with_attention=cf.attention ) elif 'axial' in cf.encoder_att_type : par = True if 'parallel' in cf.encoder_att_type else False head = MultiFieldAxialAttention( [3,2,1], dims_embed, nhs, nhc, par, dor) From 92b097b58923ac79d1f8330a4b67a1ff66204a19 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 23 May 2024 10:07:18 +0200 Subject: [PATCH 34/66] Fixed bug with multi-year training data ranges. 
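The attention rewrite in the patch above replaces the per-head q/k/v projections with one fused projection matrix whose output is split into heads, and delegates the attention computation to PyTorch's scaled_dot_product_attention. A self-contained sketch of that pattern; the dimensions are illustrative, and the patch additionally pins the flash backend via torch.nn.attention.sdpa_kernel, which is omitted here so the snippet also runs on CPU.

import torch

batch, seq, dim_embed, num_heads = 2, 16, 64, 4
dim_head = dim_embed // num_heads

# one projection for q, k, v of all heads instead of 3 * num_heads small ones
proj_heads = torch.nn.Linear(dim_embed, num_heads * 3 * dim_head, bias=False)
proj_out   = torch.nn.Linear(dim_embed, dim_embed, bias=False)

x = torch.randn(batch, seq, dim_embed)
# [batch, seq, heads, 3*dim_head] -> [batch, heads, seq, 3*dim_head], then split into q, k, v
qkv = proj_heads(x).reshape(batch, seq, num_heads, 3 * dim_head).transpose(1, 2)
q, k, v = torch.tensor_split(qkv, 3, dim=-1)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # [batch, heads, seq, dim_head]
out = proj_out(out.transpose(1, 2).flatten(-2, -1))              # back to [batch, seq, dim_embed]
print(out.shape)

Fusing the heads turns many small matrix multiplications into a single large one, which is what lets the flash-attention kernels be used efficiently.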
--- atmorep/core/train.py | 17 +++++++++-------- atmorep/datasets/multifield_data_sampler.py | 7 +++++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index e5220d8..7da484e 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -194,7 +194,7 @@ def train() : cf.fields_targets = [] - cf.years_train = [2021] # list( range( 1980, 2018)) + cf.years_train = list( range( 1979, 2018)) cf.years_test = [2021] #[2018] cf.month = None cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]] @@ -286,15 +286,16 @@ def train() : # # in steps x lat_degrees x lon_degrees cf.n_size = [36, 1*9*6, 1.*9*12] - # # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' - # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' - # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' - # # - # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' + # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' + # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' + cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' + # + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' - # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' # in steps x lat_degrees x lon_degrees - # cf.n_size = [36, 0.25*9*6, 0.25*9*12] + cf.n_size = [36, 0.25*9*6, 0.25*9*12] if cf.with_wandb and 0 == cf.par_rank : cf.write_json( wandb) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 2021a0a..7839170 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -22,6 +22,8 @@ import time import logging +import code + # from atmorep.datasets.normalizer_global import NormalizerGlobal # from atmorep.datasets.normalizer_local import NormalizerLocal from atmorep.datasets.normalizer import normalize @@ -93,9 +95,10 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, self.times = pd.DatetimeIndex( self.ds['time']) idxs_years = self.times.year == years[0] for year in years[1:] : - idxs_years = np.logical_and( idxs_years, self.times.year == year) + idxs_years = np.logical_or( idxs_years, self.times.year == year) self.idxs_years = np.where( idxs_years)[0] - logging.getLogger('atmorep').info( f'Dataset size for years {years}: {len(self.idxs_years)}.') + # logging.getLogger('atmorep').info( f'Dataset size for years {years}: {len(self.idxs_years)}.') + print( f'Dataset size for years {years}: {len(self.idxs_years)}.', flush=True) ################################################### def shuffle( self) : From 6831ffc73373beab87a047f1f274fb2be90ec48e Mon Sep 17 00:00:00 2001 From: iluise Date: Mon, 27 May 2024 16:45:24 +0200 Subject: [PATCH 35/66] delete unused files --- atmorep/datasets/dynamic_field_level.py | 242 ------------------------ atmorep/datasets/file_io.py | 85 --------- atmorep/datasets/normalizer_global.py | 42 ---- atmorep/datasets/normalizer_local.py | 82 -------- 
atmorep/datasets/static_field.py | 223 ---------------------- 5 files changed, 674 deletions(-) delete mode 100644 atmorep/datasets/dynamic_field_level.py delete mode 100644 atmorep/datasets/file_io.py delete mode 100644 atmorep/datasets/normalizer_global.py delete mode 100644 atmorep/datasets/normalizer_local.py delete mode 100644 atmorep/datasets/static_field.py diff --git a/atmorep/datasets/dynamic_field_level.py b/atmorep/datasets/dynamic_field_level.py deleted file mode 100644 index 0a55d99..0000000 --- a/atmorep/datasets/dynamic_field_level.py +++ /dev/null @@ -1,242 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import pathlib -import numpy as np -import math -import os, sys -import time -import itertools -import gc -import code -# code.interact(local=locals()) - -import atmorep.utils.utils as utils -from atmorep.utils.utils import shape_to_str -from atmorep.utils.utils import days_until_month_in_year -from atmorep.utils.utils import days_in_month - -from atmorep.datasets.data_loader import DataLoader -from atmorep.datasets.normalizer_global import NormalizerGlobal -from atmorep.datasets.normalizer_local import NormalizerLocal - -class DynamicFieldLevel() : - - ################################################### - def __init__( self, file_path, years_data, field_info, - batch_size, data_type = 'era5', - file_shape = [-1, 721, 1440], file_geo_range = [[-90.,90.], [0.,360.]], - num_tokens = [3, 9, 9], token_size = [1, 9, 9], - level_type = 'pl', vl = 975, time_sampling = 1, - smoothing = 0, file_format = 'grib', corr_type = 'local', - log_transform_data = False ) : - ''' - Data set for single dynamic field at a single vertical level - ''' - - self.years_data = years_data - self.field_info = field_info - self.file_path = file_path - self.file_shape = file_shape - self.file_format = file_format - self.level_type = level_type - self.vl = vl - self.time_sampling = time_sampling - self.smoothing = smoothing - self.corr_type = corr_type - self.log_transform_data = log_transform_data - - self.years_months = [] - - # work internally with mathematical latitude coordinates in [0,180] - self.file_geo_range = [ -np.array(file_geo_range[0]) + 90. , np.array(file_geo_range[1]) ] - # enforce that georange is North to South - self.geo_range_flipped = False - if self.file_geo_range[0][0] > self.file_geo_range[0][1] : - self.file_geo_range[0] = np.flip( self.file_geo_range[0]) - self.geo_range_flipped = True - self.is_global = 0. == self.file_geo_range[0][0] and 0. == self.file_geo_range[1][0] \ - and 180. == self.file_geo_range[0][1] and 360. 
== self.file_geo_range[1][1] - - # resolution - # TODO: non-uniform resolution in latitude and longitude - self.res = (file_geo_range[1][1] - file_geo_range[1][0]) - self.res /= file_shape[2] if self.is_global else (file_shape[2]-1) - - self.batch_size = batch_size - self.num_tokens = torch.tensor( num_tokens, dtype=torch.int) - rem1 = (num_tokens[1]*token_size[1]) % 2 - rem2 = (num_tokens[2]*token_size[2]) % 2 - t1 = num_tokens[1]*token_size[1] - t2 = num_tokens[2]*token_size[2] - self.grid_delta = [ [int((t1+rem1)/2), int(t1/2)], [int((t2+rem2)/2), int(t2/2)] ] - assert( num_tokens[1] < file_shape[1]) - assert( num_tokens[2] < file_shape[2]) - self.tok_size = token_size - - self.data_field = None - - if self.corr_type == 'global' : - self.normalizer = NormalizerGlobal( field_info, vl, self.file_shape, data_type) - else : - self.normalizer = NormalizerLocal( field_info, vl, self.file_shape, data_type) - - self.loader = DataLoader( self.file_path, self.file_shape, data_type, field_info, - file_format = self.file_format, level_type = self.level_type, - smoothing = self.smoothing, log_transform=self.log_transform_data) - - ################################################### - def load_data( self, years_months, idxs_perm, batch_size = None) : - - self.idxs_perm = idxs_perm.copy() - - # nothing to be loaded - if set(years_months) in set(self.years_months): - return - - self.years_months = years_months - - if batch_size : - self.batch_size = batch_size - loader = self.loader - - self.files_offset_days = [] - for year, month in self.years_months : - self.files_offset_days.append( days_until_month_in_year( year, month) ) - - # load data - # self.data_field is a list of lists of torch tensors - # [i] : year/month - # [i][j] : field per year/month - # [i][j] : len_data_per_month x num_tokens_lat x num_tokens_lon x token_size x token_size - # this ensures coherence in the data access - del self.data_field - gc.collect() - self.data_field = loader.get_single_field( self.years_months, self.field_info[0], - self.level_type, self.vl, [-1, -1], - [self.num_tokens[0] * self.tok_size[0], 0, - self.time_sampling]) - - # apply normalization and log-transform for each year-month data - for j in range( len(self.data_field) ) : - - if self.corr_type == 'local' : - coords = [ np.linspace( 0., 180., num=180*4+1, endpoint=True), - np.linspace( 0., 360., num=360*4, endpoint=False) ] - else : - coords = None - - (year, month) = self.years_months[j] - self.data_field[j] = self.normalizer.normalize( year, month, self.data_field[j], coords) - - # basics statistics - print( 'INFO:: data stats {} : {} / {}'.format( self.field_info[0], - self.data_field[j].mean(), - self.data_field[j].std()) ) - - ############################################### - def __getitem__( self, bidx) : - - tn = self.grid_delta - num_tokens = self.num_tokens - tok_size = self.tok_size - tnt = self.num_tokens[0] * self.tok_size[0] - cat = torch.cat - geor = self.file_geo_range - - idx = bidx * self.batch_size - - # physical fields - patch_s = [nt*ts for nt,ts in zip(self.num_tokens,self.tok_size)] - x = torch.zeros( self.batch_size, patch_s[0], patch_s[1], patch_s[2] ) - cids = torch.zeros( self.batch_size, num_tokens.prod(), 8) - - # offset from previous month to be able to sample all time slices in current one - offset_t = int(num_tokens[0] * tok_size[0]) - # 721 etc have grid points at the beginning and end which leads to incorrect results in places - file_shape = np.array(self.file_shape) - file_shape = file_shape-1 if not self.is_global else 
np.array(self.file_shape)-np.array([0,1,0]) - - # for all items in batch - for jj in range( self.batch_size) : - - i_ym = int(self.idxs_perm[idx][0]) - # perform a deep copy to not overwrite cid for other fields - cid = np.array( self.idxs_perm[idx][1:]).copy() - cid_orig = cid.copy() - - # map to grid coordinates (first map to normalized [0,1] coords and then to grid coords) - cid[2] = np.mod( cid[2], 360.) if self.is_global else cid[2] - assert cid[1] >= geor[0][0] and cid[1] <= geor[0][1], 'invalid latitude for geo_range' - cid[1] = ( (cid[1] - geor[0][0]) / (geor[0][1] - geor[0][0]) ) * file_shape[1] - cid[2] = ( ((cid[2]) - geor[1][0]) / (geor[1][1] - geor[1][0]) ) * file_shape[2] - assert cid[1] >= 0 and cid[1] < self.file_shape[1] - assert cid[2] >= 0 and cid[2] < self.file_shape[2] - - # alignment when parent field has different resolution than this field - cid = np.round( cid).astype( np.int64) - - ran_t = list( range( cid[0]-tnt+1 + offset_t, cid[0]+1 + offset_t)) - if any(np.array(ran_t) >= self.data_field[i_ym].shape[0]) : - print( '{} : {} :: {}'.format( self.field_info[0], self.years_months[i_ym], ran_t )) - - # periodic boundary conditions around equator - ran_lon = np.array( list( range( cid[2]-tn[1][0], cid[2]+tn[1][1]))) - if self.is_global : - ran_lon = np.mod( ran_lon, self.file_shape[2]) - else : - # sanity check for indices for files with local window - # this should be controlled by georange_sampling for sampling - assert all( ran_lon >= 0) and all( ran_lon < self.file_shape[2]) - - ran_lat = np.array( list( range( cid[1]-tn[0][0], cid[1]+tn[0][1]))) - assert all( ran_lat >= 0) and all( ran_lat < self.file_shape[1]) - - # current data - # if self.geo_range_flipped : - # print( '{} : {} / {}'.format( self.field_info[0], ran_lat, ran_lon) ) - if np.max(ran_t) >= self.data_field[i_ym].shape[0] : - print( 'WARNING: {} : {} :: {}'.format( self.field_info[0], ran_t, self.years_months[i_ym]) ) - x[jj] = np.take( np.take( self.data_field[i_ym][ran_t], ran_lat, 1), ran_lon, 2) - - # set per token information - assert self.time_sampling == 1 - ran_tt = np.flip( np.arange( cid[0], cid[0]-tnt, -tok_size[0])) - years = self.years_months[i_ym][0] * np.ones( ran_tt.shape) - days_in_year = self.files_offset_days[i_ym] + (ran_tt / 24.) 
- # wrap year around - mask = days_in_year < 0 - years[ mask ] -= 1 - days_in_year[ mask ] += 365 - hours = np.mod( ran_tt, 24) - lats = ran_lat[int(tok_size[1]/2)::tok_size[1]] * self.res + self.file_geo_range[0][0] - lons = ran_lon[int(tok_size[2]/2)::tok_size[2]] * self.res + self.file_geo_range[1][0] - stencil = torch.tensor(list(itertools.product(lats,lons))) - tstencil = torch.tensor( [ [y, d, h, self.vl] for y,d,h in zip( years, days_in_year, hours)], - dtype=torch.float) - txlist = list( itertools.product( tstencil, stencil)) - cids[jj,:,:6] = torch.cat( [torch.cat(tx).unsqueeze(0) for tx in txlist], 0) - cids[jj,:,6] = self.vl - cids[jj,:,7] = self.res - - idx += 1 - - return (x, cids) - - ################################################### - def __len__(self): - return int(self.idxs_perm.shape[0] / self.batch_size) diff --git a/atmorep/datasets/file_io.py b/atmorep/datasets/file_io.py deleted file mode 100644 index 1974044..0000000 --- a/atmorep/datasets/file_io.py +++ /dev/null @@ -1,85 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import numpy as np -import xarray as xr - -#################################################################################################### - -def netcdf_file_loader(fname, field, time_padding = [0,0,1], days_in_month = 0, static=False) : - - ds = xr.open_dataset(fname, engine='netcdf4')[field] - - if not static: - # TODO: test that only time_padding[0] *or* time_padding[1] != 0 - if time_padding[0] != 0 : - ds = ds[-time_padding[0]*time_padding[2] : ] - elif time_padding[1] != 0 : - ds = ds[ : time_padding[1]*time_padding[2]] - if time_padding[2] > 1 : - ds = ds[::time_padding[2]] - - x = torch.from_numpy(np.array( ds, dtype=np.float32)) - ds.close() - - return x - -#################################################################################################### - -def grib_file_loader(fname, field, time_padding = [0,0,1], days_in_month = 0, static=False) : - - ds = xr.open_dataset(fname, engine='cfgrib', - backend_kwargs={'time_dims':('valid_time','indexing_time')})[field] - - # work-around for bug in download where for every month 31 days have been downloaded - if days_in_month > 0 : - ds = ds[:24*days_in_month] - - if not static: - # TODO: test that only time_padding[0] *or* time_padding[1] != 0 - if time_padding[0] != 0 : - ds = ds[-time_padding[0]*time_padding[2] : ] - elif time_padding[1] != 0 : - ds = ds[ : time_padding[1]*time_padding[2]] - if time_padding[2] > 1 : - ds = ds[::time_padding[2]] - - x = torch.from_numpy(np.array(ds, dtype=np.float32)) - ds.close() - - # assume grib files are clean and NaNs are introduced through handling of "missing values" - if np.isnan( x).any() : - x_shape = x.shape - x = x.flatten() - x[np.argwhere( np.isnan( x))] = 9999. 
- x = np.reshape( x, x_shape) - - return x - -#################################################################################################### - -def bin_file_loader( fname, field, time_padding = [0,0,1], static=False, file_shape = (-1, 721, 1440)) : - - ds = np.fromfile(fname, dtype=np.float32) - print("INFO:: reshaping binary file into {}".format(file_shape)) - if not static: - ds = np.reshape(ds, file_shape) - ds = ds[time_padding[0]:(ds.shape[0] - time_padding[0])] - else: - ds = np.reshape(ds, file_shape[1:]) - x = torch.from_numpy(ds) - return x diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py deleted file mode 100644 index 3837a02..0000000 --- a/atmorep/datasets/normalizer_global.py +++ /dev/null @@ -1,42 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import numpy as np -import atmorep.config.config as config - -class NormalizerGlobal() : - - def __init__(self, field_info, vlevel, data_type = 'era5', level_type = 'ml') : - - fname_base = '{}/{}/normalization/{}/global_normalization_mean_var_{}_{}{}.bin' - - fn = field_info[0] - corr_fname = fname_base.format( str(config.path_data), data_type, fn, fn, level_type, vlevel) - self.corr_data = np.fromfile(corr_fname, dtype=np.float32).reshape( (-1, 4)) - - - def normalize( self, year, month, data, coords = None) : - corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), - self.corr_data[:,1] == float(month))) , 2:].flatten() - data_temp = (data - corr_data_ym[0]) / corr_data_ym[1] - return (data - corr_data_ym[0]) / corr_data_ym[1] - - def denormalize( self, year, month, data, coords = None) : - corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), - self.corr_data[:,1] == float(month))) , 2:].flatten() - data_temp = (data * corr_data_ym[1]) + corr_data_ym[0] - return (data * corr_data_ym[1]) + corr_data_ym[0] - \ No newline at end of file diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py deleted file mode 100644 index 92e58c5..0000000 --- a/atmorep/datasets/normalizer_local.py +++ /dev/null @@ -1,82 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import code -import numpy as np -import xarray as xr -import atmorep.config.config as config - -class NormalizerLocal() : - - # def __init__(self, field_info, vlevel, range_lat, range_lon, res, is_global, data_type= 'era5', level_type = 'ml') : - def __init__(self, field_info, vlevel, data_type = 'era5', level_type = 'ml') : - - fname_base = './data/{}/normalization/{}/normalization_mean_var_{}_y{}_m{:02d}_{}{}.bin' - self.year_base = config.datasets[data_type]['extent'][0][0] - self.year_last = config.datasets[data_type]['extent'][0][1] - lat_min, lat_max = 
config.datasets[data_type]['extent'][1] - lat_min, lat_max = 90. - lat_min, 90. - lat_max - lat_min, lat_max = (lat_min, lat_max) if lat_min < lat_max else (lat_max, lat_min) - lon_min, lon_max = config.datasets[data_type]['extent'][2] - res = config.datasets[data_type]['resolution'][1] - is_global = config.datasets[data_type]['is_global'] - - self.corr_data = [ ] - for year in range( self.year_base, self.year_last+1) : - for month in range( 1, 12+1) : - corr_fname = fname_base.format( data_type, field_info[0], field_info[0], - year, month, level_type, vlevel) - ns_lat = int( (lat_max-lat_min) / res + 1) - ns_lon = int( (lon_max-lon_min) / res + (0 if is_global else 1) ) - - x = np.fromfile( corr_fname, dtype=np.float32).reshape( (ns_lat, ns_lon, 2)) - # TODO, TODO, TODO: remove once recomputed - if 'cerra' == data_type : - x[:,:,0] = 340. - x[:,:,1] = 600. - x = xr.DataArray( x, [ ('lat', np.linspace( lat_min, lat_max, num=ns_lat, endpoint=True)), - ('lon', np.linspace( lon_min, lon_max, num=ns_lon, endpoint=False)), - ('data', ['mean', 'var']) ]) - self.corr_data.append( x) - - - def normalize( self, year, month, data, coords) : - corr_data_ym = self.corr_data[ (year - self.year_base) * 12 + (month-1) ] - mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values - var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values - if (var == 0.).all() : - print( f'var == 0 :: ym : {year} / {month}') - assert False - - if len(data.shape) > 2 : - for i in range( data.shape[0]) : - data[i] = (data[i] - mean) / var - else : - data = (data - mean) / var - return data - - def denormalize( self, year, month, data, coords) : - corr_data_ym = self.corr_data[ (year - self.year_base) * 12 + (month-1) ] - mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values - var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values - if len(data.shape) > 2 : - for i in range( data.shape[0]) : - data[i] = (data[i] * var) + mean - else : - data = (data * var) + mean - return data - - \ No newline at end of file diff --git a/atmorep/datasets/static_field.py b/atmorep/datasets/static_field.py deleted file mode 100644 index f636757..0000000 --- a/atmorep/datasets/static_field.py +++ /dev/null @@ -1,223 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import pathlib -import numpy as np -import math -import os, sys -import time -import itertools - -import atmorep.utils.utils as utils -from atmorep.utils.utils import shape_to_str -from atmorep.utils.utils import days_until_month_in_year -from atmorep.utils.utils import days_in_month -from atmorep.utils.utils import tokenize - -from atmorep.datasets.data_loader import DataLoader - - -class StaticField() : - - ################################################### - def __init__( self, file_path, field_info, batch_size, data_type = 'reanalysis', - file_shape = (-1, 720, 1440), file_geo_range = [[90.,-90.], [0.,360.]], - num_tokens = [3, 9, 9], token_size = [1, 9, 9], - smoothing = 0, file_format = 'grib', corr_type = 'global') : - ''' - Data set for single dynamic field at a single vertical level - ''' - - 
self.field_info = field_info - self.file_path = file_path - self.file_shape = file_shape - self.file_format = file_format - self.smoothing = smoothing - self.corr_type = corr_type - - # # work internally with mathematical latitude coordinates in [0,180] - # self.is_global = np.abs(file_geo_range[0][0])==90. and file_geo_range[1][0]==0. \ - # and np.abs(file_geo_range[0][0])==90. and file_geo_range[1][1]==360. - # self.file_geo_range = [ -np.array(file_geo_range[0]) + 90. , file_geo_range[1] ] - # self.file_geo_range[0] = np.flip( self.file_geo_range[0]) \ - # if self.file_geo_range[0][0] > self.file_geo_range[0][1] else self.file_geo_range[0] - - # work internally with mathematical latitude coordinates in [0,180] - self.file_geo_range = [ -np.array(file_geo_range[0]) + 90. , np.array(file_geo_range[1]) ] - # enforce that georange is North to South - self.geo_range_flipped = False - if self.file_geo_range[0][0] > self.file_geo_range[0][1] : - self.file_geo_range[0] = np.flip( self.file_geo_range[0]) - self.geo_range_flipped = True - print( 'Flipped georange') - print( '{} :: geo_range : {}'.format( field_info[0], self.file_geo_range) ) - self.is_global = 0. == self.file_geo_range[0][0] and 0. == self.file_geo_range[1][0] \ - and 180. == self.file_geo_range[0][1] and 360. == self.file_geo_range[1][1] - print( '{} :: is_global : {}'.format( field_info[0], self.is_global) ) - - self.batch_size = batch_size - self.num_tokens = torch.tensor( num_tokens, dtype=torch.int) - rem1 = (num_tokens[1]*token_size[1]) % 2 - rem2 = (num_tokens[2]*token_size[2]) % 2 - t1 = num_tokens[1]*token_size[1] - t2 = num_tokens[2]*token_size[2] - self.grid_delta = [ [int((t1+rem1)/2), int(t1/2)], [int((t2+rem2)/2), int(t2/2)] ] - assert( num_tokens[1] < file_shape[1]) - assert( num_tokens[2] < file_shape[2]) - self.tok_size = token_size - #assert( file_shape[1] % token_size[1] == 0) - #assert( file_shape[2] % token_size[2] == 0) - - # resolution - # TODO: non-uniform resolution in latitude and longitude - self.res = (file_geo_range[1][1] - file_geo_range[1][0]) - self.res /= file_shape[2] if self.is_global else (file_shape[2]-1) - - self.data_field = None - - self.loader = DataLoader( self.file_path, self.file_shape, data_type, - file_format = self.file_format, - smoothing = self.smoothing ) - - ################################################### - def load_data( self, years_months, idxs_perm, batch_size = None) : - - self.idxs_perm = idxs_perm - loader = self.loader - - if batch_size : - self.batch_size = batch_size - - # load data - self.data_field = loader.get_static_field( self.field_info[0], [-1, -1]) - - # # corrections: - self.correction_field = loader.get_correction_static_field( self.field_info[0], self.corr_type ) - - mean = self.correction_field[0] - std = self.correction_field[1] - - self.data_field = (self.data_field - mean) / std - - if self.geo_range_flipped : - self.data_field = torch.flip( self.data_field, [0]) - - # # basics statistics - # print( 'INFO:: data stats {} : {} / {}'.format( self.field_info[0], - # self.data_field.mean(), - # self.data_field.std()) ) - - ################################################### - def set_data( self, date_pos ) : - ''' - date_pos = np.array( [ [year, month, day, hour, lat, lon], ...] 
) - - lat \in [-90,90] = [90N, 90S] - - (year,month) pairs should be a limited number since all data for these is loaded - ''' - - # extract required years and months - years_months_all = np.array( [ [it[0], it[1]] for it in date_pos ], dtype=np.int64) - self.years_months = list( zip( np.unique(years_months_all[:,0]), - np.unique( years_months_all[:,1] ))) - - # load data and corrections - self.load_data() - - # generate all the data - self.idxs_perm = np.zeros( (date_pos.shape[0], 4), dtype=np.int64) - for idx, item in enumerate( date_pos) : - - assert item[2] >= 1 and item[2] <= 31 - assert item[3] >= 0 and item[3] < int(24 / self.time_sampling) - assert item[4] >= -90. and item[4] <= 90. - - # find year - for i_ym, ym in enumerate( self.years_months) : - if ym[0] == item[0] and ym[1] == item[1] : - break - - it = (item[2] - 1.) * 24. + item[3] + self.tok_size[0] - idx_lat = int( (item[4] + 90.) * 720. / 180.) - idx_lon = int( (item[5] % 360) * 1440. / 360.) - - self.idxs_perm[idx] = np.array( [i_ym, it, idx_lat, idx_lon], dtype=np.int64) - - ############################################### - def __getitem__( self, bidx) : - - tn = self.grid_delta - num_tokens = self.num_tokens - tok_size = self.tok_size - geor = self.file_geo_range - - idx = bidx * self.batch_size - - # physical fields - patch_s = [nt*ts for nt,ts in zip(self.num_tokens,self.tok_size)] - x = torch.zeros( self.batch_size, 1, patch_s[1], patch_s[2] ) - cids = torch.zeros( self.batch_size, num_tokens.prod(), 8) - - # 721 etc have grid points at the beginning and end which leads to incorrect results in places - file_shape = np.array(self.file_shape) - file_shape = file_shape-1 if not self.is_global else np.array(self.file_shape)-np.array([0,1,0]) - - # for all items in batch - for jj in range( self.batch_size) : - - # perform a deep copy to not overwrite cid for other fields - cid = np.array( self.idxs_perm[idx][1:]).copy() - - # map to grid coordinates (first map to normalized [0,1] coords and then to grid coords) - cid[2] = np.mod( cid[2], 360.) 
if self.is_global else cid[2] - assert cid[1] >= geor[0][0] and cid[1] <= geor[0][1], 'invalid latitude for geo_range' - cid[1] = ( (cid[1] - geor[0][0]) / (geor[0][1] - geor[0][0]) ) * file_shape[1] - cid[2] = ( ((cid[2]) - geor[1][0]) / (geor[1][1] - geor[1][0]) ) * file_shape[2] - assert cid[1] >= 0 and cid[1] < self.file_shape[1] - assert cid[2] >= 0 and cid[2] < self.file_shape[2] - - # alignment when parent field has different resolution than this field - cid = np.round( cid).astype( np.int64) - - # periodic boundary conditions around equator - ran_lon = np.array( list( range( cid[2]-tn[1][0], cid[2]+tn[1][1]))) - if self.is_global : - ran_lon = np.mod( ran_lon, self.file_shape[2]) - else : - # sanity check for indices for files with local window - # this should be controlled by georange_sampling for sampling - assert any( ran_lon >= 0) or any( ran_lon < self.file_shape[2]) - - ran_lat = np.array( list( range( cid[1]-tn[0][0], cid[1]+tn[0][1]))) - assert any( ran_lat >= 0) or any( ran_lat < self.file_shape[1]) - - # current data - x[jj,0] = np.take( np.take( self.data_field, ran_lat, 0), ran_lon, 1) - - # set per token information - lats = ran_lat[int(tok_size[1]/2)::tok_size[1]] * self.res + self.file_geo_range[0][0] - lons = ran_lon[int(tok_size[2]/2)::tok_size[2]] * self.res + self.file_geo_range[1][0] - stencil = torch.tensor(list(itertools.product(lats,lons))) - cids[jj,:,4:6] = stencil - cids[jj,:,7] = self.res - - idx += 1 - - return (x, cids) - - ################################################### - def __len__(self): - return int(self.idxs_perm.shape[0] / self.batch_size) From e3a7faa6d1c2adbf13748d2da7b0054513ca9498 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 6 Jun 2024 19:57:59 +0200 Subject: [PATCH 36/66] first implem of weight_translate + time bug fix --- atmorep/config/config.py | 1 - atmorep/core/atmorep_model.py | 80 +++++++++++++++- atmorep/core/evaluate.py | 10 +- atmorep/core/evaluator.py | 22 ++++- atmorep/core/train.py | 97 +++----------------- atmorep/core/trainer.py | 7 +- atmorep/datasets/multifield_data_sampler.py | 6 +- atmorep/transformer/transformer_attention.py | 3 +- 8 files changed, 123 insertions(+), 103 deletions(-) diff --git a/atmorep/config/config.py b/atmorep/config/config.py index 6c6e20d..6c0fe01 100644 --- a/atmorep/config/config.py +++ b/atmorep/config/config.py @@ -3,7 +3,6 @@ fpath = os.path.dirname(os.path.realpath(__file__)) -path_data = Path(fpath, '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_final.zarr') path_models = Path( fpath, '../../models/') path_results = Path( fpath, '../../results') path_plots = Path( fpath, '../results/plots/') diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 2e3ff5f..79a9227 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -73,7 +73,7 @@ def set_global( self, mode : NetMode, times, batch_size = -1, num_loader_workers cf = self.net.cf if batch_size < 0 : batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test - + print(times) dataset = self.dataset_train if mode == NetMode.train else self.dataset_test dataset.set_global( times, batch_size, cf.token_overlap) @@ -138,11 +138,11 @@ def mode( self, mode : NetMode) : if mode == NetMode.train : self.data_loader_iter = iter(self.data_loader_train) - # self.data_loader_iter = iter(self.dataset_train) + #self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : self.data_loader_iter = iter(self.data_loader_test) - # 
self.data_loader_iter = iter(self.dataset_test) + #self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -375,6 +375,74 @@ def load_block( self, field_info, block_name, block ) : print( 'Loaded {} for {} from id = {} (ignoring/missing {} elements).'.format( block_name, field_info[0], field_info[1][4][0], len(mkeys) ) ) + ################################################### + def translate_weights(self, mloaded, mkeys, ukeys): + cf = self.cf + + #encoder: + for layer in range(cf.encoder_num_layers) : + qs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] + ks = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] + vs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] + mw = torch.cat( [*qs, *ks, *vs]) + #print(qs[0][:3,:3]) + mloaded[f'encoders.0.heads.{layer}.proj_heads.weight'] = mw + for head in range(cf.encoder_num_heads): + del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight'] + + #cross attention + if f'encoders.0.heads.{layer}.heads_other.0.proj_qs.weight' in ukeys: + qs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] + ks = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] + vs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] + mw = torch.cat( [*qs, *ks, *vs]) + + for i in range(cf.encoder_num_heads): + del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] + + else: + mw = torch.tensor(np.zeros([0,2048])) + + mloaded[f'encoders.0.heads.{layer}.proj_heads_other.0.weight'] = mw + + #decoder + for iblock in range(0, 19, 2) : + print(iblock) + qs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] for i in range(8)] + ks = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight'] for i in range(8)] + vs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight'] for i in range(8)] + mw = torch.cat( [*qs, *ks, *vs]) + mloaded[f'decoders.0.blocks.{iblock}.proj_heads.weight'] = mw + + qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight'] for i in range(8)] + ks = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] for i in range(8)] + vs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] for i in range(8)] + + mw = torch.cat( [*ks, *vs]) + mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_q.weight'] = torch.cat([*qs]) + mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_kv.weight'] = mw + + #self.num_samples_validate + mloaded[f'decoders.0.blocks.{iblock}.ln_q.weight'] = torch.tensor(np.ones([128])) + mloaded[f'decoders.0.blocks.{iblock}.ln_k.weight'] = torch.tensor(np.ones([128])) + mloaded[f'decoders.0.blocks.{iblock}.ln_q.bias'] = torch.tensor(np.ones([128])) + mloaded[f'decoders.0.blocks.{iblock}.ln_k.bias'] = torch.tensor(np.ones([128])) + + for i in range(8): + del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] + del 
mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] + + + return mloaded + ################################################### @staticmethod def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : @@ -386,7 +454,11 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : model = AtmoRep( cf).create( devices, load_pretrained=False) mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) ) - mkeys, _ = model.load_state_dict( mloaded, False ) + mkeys, ukeys = model.load_state_dict( mloaded, False ) + mloaded = model.translate_weights(mloaded, mkeys, ukeys) + mkeys, ukeys = model.load_state_dict( mloaded, False ) + + # breakpoint() if len(mkeys) > 0 : print( f'Loaded AtmoRep: ignoring {len(mkeys)} elements: {mkeys}') diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 3528b19..13dbde4 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -54,10 +54,10 @@ # BERT forecast with patching to obtain global forecast mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], - 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 2, - 'attention' : False } + 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 12]], + 'token_overlap' : [0, 0], + 'forecast_num_tokens' : 2, + 'attention' : False } now = time.time() Evaluator.evaluate( mode, model_id, options) - print("time", time.time() - now) \ No newline at end of file + print("time", time.time() - now) diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 1b2f74b..8d93a8c 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -54,6 +54,11 @@ def parse_args( cf, args) : ############################################## @staticmethod def run( cf, model_id, model_epoch, devices) : + + cf.batch_size = cf.batch_size_max + cf.batch_size_validation = cf.batch_size_max + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' + cf.with_mixed_precision = True # set/over-write options as desired evaluator = Evaluator.load( cf, model_id, model_epoch, devices) @@ -82,7 +87,6 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : cf.par_rank = par_rank cf.par_size = par_size # overwrite old config - cf.file_path = str(config.path_data) cf.attention = False setup_wandb( cf.with_wandb, cf, par_rank, '', mode='offline') if 0 == cf.par_rank : @@ -132,14 +136,26 @@ def forecast( cf, model_id, model_epoch, devices, args = {}) : ############################################## @staticmethod def global_forecast( cf, model_id, model_epoch, devices, args = {}) : - + cf.BERT_strategy = 'global_forecast' cf.batch_size_test = 24 cf.num_loader_workers = 1 cf.log_test_num_ranks = 1 - cf.batch_size_start = 14 + + if not hasattr(cf, 'file_path'): + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' + + if not hasattr(cf, 'batch_size'): + cf.batch_size = 14 + if not hasattr(cf, 'batch_size_validation'): + cf.batch_size_validation = 1 #64 + if not hasattr(cf, 'batch_size_delta'): + cf.batch_size_delta = 8 if not hasattr(cf, 'num_samples_validate'): cf.num_samples_validate = 196 + #if not 
hasattr(cf,'with_mixed_precision'): + cf.with_mixed_precision = True + Evaluator.parse_args( cf, args) dates = args['dates'] diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 7da484e..e078dce 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -114,83 +114,11 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'vorticity', [ 1, 2048, [ ], 0 ], - [ 96, 105, 114, 123, 137 ], - # [ 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - cf.fields_prediction = [ [cf.fields[0][0], 1.] ] - - # cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ] ] - - # cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - - # cf.fields = [ [ 'velocity_z', [ 1, 1024, [ ], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - - # cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - - # [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ] - - # cf.fields = [ [ 'total_precip', [ 1, 2048, [ ], 0 ], - # [ 0 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - - # cf.fields = [ [ 'geopotential', [ 1, 1024, [], 0 ], - # [ 0 ], - # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields_prediction = [ ['geopotential', 1.] ] - - #cf.fields_prediction = [ [cf.fields[0][0], 1.0] ] - - # cf.fields = [ [ 'velocity_u', [ 1, 1024, ['velocity_v'], 0], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], - # [ 'velocity_v', [ 1, 1024, ['velocity_u'], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields_prediction = [ [cf.fields[0][0], 0.5], [cf.fields[1][0], 0.5] ] - - - # cf.fields = [ [ 'velocity_u', [ 1, 1024, ['velocity_v'], 0], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ], - # [ 'velocity_v', [ 1, 1024, ['velocity_u'], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ], - # [ 'total_precip', [ 1, 1024, ['velocity_u', 'velocity_v'], 0 ], - # [ 0 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] - # cf.fields_prediction = [ [cf.fields[0][0], 0.33], [cf.fields[1][0], 0.33], - # [cf.fields[2][0], 0.33]] - - cf.fields = [ [ 'vorticity', [ 1, 2048, ['divergence', 'temperature'], 0 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], - [ 'velocity_z', [ 1, 1536, ['vorticity', 'divergence'], 0 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], - [ 'divergence', [ 1, 2048, ['vorticity', 'temperature'], 1 ], + cf.fields = [ [ 'temperature',[ 1, 512, [ ], 3 ], [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], - [ 'specific_humidity', [ 1, 2048, ['vorticity', 'divergence', 'velocity_z'], 2 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], - [ 'temperature', [ 1, 1024, ['vorticity', 'divergence'], 3 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ], - [ 'total_precip', [ 1, 1536, ['vorticity', 'divergence', 'specific_humidity'], 3 ], - [ 0 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.2, 0.05] ] ] - cf.fields_prediction = [['vorticity', 0.25], ['velocity_z', 0.1], - ['divergence', 
0.25], ['specific_humidity', 0.15], - ['temperature', 0.15], ['total_precip', 0.1] ] + [12, 2, 4], [3, 27, 27], [0.25, 0.9, 0.2, 0.05], 'local' ] ] + + cf.fields_prediction = [ [cf.fields[0][0], 1.] ] cf.fields_targets = [] @@ -211,8 +139,7 @@ def train() : cf.torch_seed = torch.initial_seed() # training params cf.batch_size_validation = 1 #64 - cf.batch_size = 16 # 4 # 32 - cf.batch_size_delta = 8 + cf.batch_size = 8 #16 #4 # 32 cf.num_epochs = 128 # additional infos @@ -227,12 +154,12 @@ def train() : cf.learnable_mask = False cf.with_qk_lnorm = False # encoder - cf.encoder_num_layers = 4 + cf.encoder_num_layers = 6 #10 #4 cf.encoder_num_heads = 16 cf.encoder_num_mlp_layers = 2 cf.encoder_att_type = 'dense' # decoder - cf.decoder_num_layers = 4 + cf.decoder_num_layers = 6 #10 #4 cf.decoder_num_heads = 16 cf.decoder_num_mlp_layers = 2 cf.decoder_self_att = False @@ -280,17 +207,17 @@ def train() : cf.with_mixed_precision = True cf.num_samples_per_epoch = 4096 cf.num_samples_validate = 128 - cf.num_loader_workers = 8 + cf.num_loader_workers = 6 - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' + #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' # # in steps x lat_degrees x lon_degrees - cf.n_size = [36, 1*9*6, 1.*9*12] + #cf.n_size = [36, 1*9*6, 1.*9*12] # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' - cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' + #cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' # - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' + #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index feac614..ae06272 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -499,6 +499,11 @@ def loss( self, preds, batch_idx = 0) : crps_loss = torch.mean( CRPS( target, pred[0], pred[1])) losses['crps'].append( crps_loss) + #cosine weighted RMSE loss + if 'cos_weighted_mse' in self.cf.losses : + cw_mse_loss = torch.mean( CosW_MSELoss( pred[0], target = target)) + losses['cos_weighted_mse'].append( cw_mse_loss) + loss = torch.tensor( 0., device=self.device_out) tot_weight = torch.tensor( 0., device=self.device_out) for key in losses : @@ -540,7 +545,7 @@ def prepare_batch( self, xin) : cf = self.cf devs = self.devices - + # unpack loader output # xin[0] since BERT does not have targets (sources, token_infos, targets, fields_tokens_masked_idx_list) = xin[0] diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 7839170..8d77248 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -140,10 +140,10 @@ def __iter__(self): sources_infos, source_idxs = [], [] i_bidx = self.idxs_perm_t[bidx] - idxs_t = list(np.arange( bidx - n_size[0]*ts, bidx, ts, dtype=np.int64)) + idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) data_tt_sfc = 
self.ds['data_sfc'].oindex[idxs_t] data_tt = self.ds['data'].oindex[idxs_t] - + for sidx in range(self.batch_size) : # i_bidx = self.idxs_perm_t[bidx] @@ -253,7 +253,7 @@ def set_data( self, times_pos, batch_size = None) : ################################################### def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : ''' generate patch/token positions for global grid ''' - + print("inside set_global") token_overlap = np.array( token_overlap).astype(np.int64) # assumed that sanity checking that field data is consistent has been done diff --git a/atmorep/transformer/transformer_attention.py b/atmorep/transformer/transformer_attention.py index 59f7b1b..abf3583 100644 --- a/atmorep/transformer/transformer_attention.py +++ b/atmorep/transformer/transformer_attention.py @@ -235,11 +235,12 @@ def forward( self, *args) : # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention # collapse three space and time dimensions for dense space-time attention + #proj_heads: torch.Size([3, 16, 128, 2048]) field_proj = self.proj_heads( fields_lnormed[0].flatten(1,-2)) s = [ *field_proj.shape[:-1], self.num_heads_self, -1 ] qs, ks, vs = torch.tensor_split( field_proj.reshape(s).transpose(-3,-2), 3, dim=-1) + breakpoint() qs, ks = self.ln_qk[0]( qs), self.ln_qk[1]( ks) - if len(fields_lnormed) > 1 : field_proj = self.proj_heads_other[0]( fields_lnormed[0].flatten(1,-2)) From ecee5b50799b859c707c5f6e7405992ba425cfee Mon Sep 17 00:00:00 2001 From: iluise Date: Fri, 7 Jun 2024 18:00:37 +0200 Subject: [PATCH 37/66] validate MultiCrossAttentionHead --- atmorep/core/atmorep_model.py | 47 +++++++++++++------- atmorep/transformer/transformer_attention.py | 11 +++-- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 79a9227..dd5962a 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -381,10 +381,20 @@ def translate_weights(self, mloaded, mkeys, ukeys): #encoder: for layer in range(cf.encoder_num_layers) : - qs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] - ks = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] - vs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] - mw = torch.cat( [*qs, *ks, *vs]) + # qs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] + # ks = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] + # vs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] + # mw = torch.cat( [*qs, *ks, *vs]) + #torch.Size([3, 16, 128, 2048]) + # att = temp.reshape([16, 3, 128, 2048]) + # att.shape + # (Pdb) (Pdb) torch.Size([16, 3, 128, 2048]) + # att1 = temp1.reshape([16, 3, 128, 2048]) + # mw.shape + # (Pdb) (Pdb) torch.Size([6144, 2048]) + # att_mw = mw.reshape([3, 16, 128, 2048]) + mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]]) + #breakpoint() #print(qs[0][:3,:3]) mloaded[f'encoders.0.heads.{layer}.proj_heads.weight'] = mw for head in range(cf.encoder_num_heads): @@ -394,11 +404,12 @@ def translate_weights(self, mloaded, mkeys, ukeys): #cross attention if 
f'encoders.0.heads.{layer}.heads_other.0.proj_qs.weight' in ukeys: - qs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] - ks = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] - vs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] - mw = torch.cat( [*qs, *ks, *vs]) - + # qs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] + # ks = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] + # vs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] + # mw = torch.cat( [*qs, *ks, *vs]) + mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]]) + for i in range(cf.encoder_num_heads): del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] @@ -412,17 +423,19 @@ def translate_weights(self, mloaded, mkeys, ukeys): #decoder for iblock in range(0, 19, 2) : print(iblock) - qs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] for i in range(8)] - ks = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight'] for i in range(8)] - vs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight'] for i in range(8)] - mw = torch.cat( [*qs, *ks, *vs]) + # qs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] for i in range(8)] + # ks = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight'] for i in range(8)] + # vs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight'] for i in range(8)] + # mw = torch.cat( [*qs, *ks, *vs]) + mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads.{head}.proj_{k}.weight'] for head in range(8) for k in ["qs", "ks", "vs"]]) mloaded[f'decoders.0.blocks.{iblock}.proj_heads.weight'] = mw - qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight'] for i in range(8)] - ks = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] for i in range(8)] - vs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] for i in range(8)] + qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_qs.weight'] for head in range(8)] + # ks = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] for i in range(8)] + # vs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] for i in range(8)] + # mw = torch.cat( [*ks, *vs]) + mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_{k}.weight'] for head in range(8) for k in ["ks", "vs"]]) - mw = torch.cat( [*ks, *vs]) mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_q.weight'] = torch.cat([*qs]) mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_kv.weight'] = mw diff --git a/atmorep/transformer/transformer_attention.py b/atmorep/transformer/transformer_attention.py index abf3583..4356394 100644 --- a/atmorep/transformer/transformer_attention.py +++ b/atmorep/transformer/transformer_attention.py @@ -139,14 +139,14 @@ def forward( self, x, x_other) : x = x.flatten( 1, -2) s = [ *x.shape[:-1], self.num_heads, -1] qs, ks, vs = torch.tensor_split( self.proj_heads( x).reshape(s).transpose( 2, 1), 3, dim=-1) 
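    # Illustrative shape walk-through for the fused projection above (assuming
    # B = batch, N = tokens after flatten( 1, -2), H = self.num_heads, dh = per-head dim):
    #   x                                    : [B, N, D]
    #   self.proj_heads( x)                  : [B, N, H*3*dh]  (q, k, v fused for all heads)
    #   .reshape( s).transpose( 2, 1)        : [B, H, N, 3*dh]
    #   torch.tensor_split( ..., 3, dim=-1)  : qs, ks, vs, each [B, H, N, dh]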
- # qs, ks = self.ln_q( qs), self.ln_k( ks) + qs, ks = self.ln_q( qs), self.ln_k( ks) s = [ *x.shape[:-1], self.num_heads_other, -1] qs_o = self.proj_heads_o_q( x).reshape(s).transpose( 2, 1) x_o = x_other.flatten( 1, -2) s = [ *x_o.shape[:-1], self.num_heads_other, -1] ks_o, vs_o = torch.tensor_split( self.proj_heads_o_kv(x_o).reshape(s).transpose( 2, 1),2,dim=-1) - # qs_o, ks_o = self.ln_q( qs_o), self.ln_k( ks_o) + qs_o, ks_o = self.ln_q( qs_o), self.ln_k( ks_o) # correct ordering of tensors with seq dimension second but last is critical with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.FLASH_ATTENTION) : @@ -155,9 +155,8 @@ def forward( self, x, x_other) : outs_self = self.att( qs, ks, vs).transpose( 2, 1).flatten( -2, -1).reshape(s) outs_other = self.att( qs_o, ks_o, vs_o).transpose( 2, 1).flatten( -2, -1).reshape(s) outs = torch.cat( [outs_self, outs_other], -1) - + outs = self.dropout( self.checkpoint( self.proj_out, outs) ) - atts = [] return x_in + outs, atts @@ -235,11 +234,11 @@ def forward( self, *args) : # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention # collapse three space and time dimensions for dense space-time attention - #proj_heads: torch.Size([3, 16, 128, 2048]) + #proj_heads: torch.Size([16, 3, 128, 2048]) field_proj = self.proj_heads( fields_lnormed[0].flatten(1,-2)) s = [ *field_proj.shape[:-1], self.num_heads_self, -1 ] qs, ks, vs = torch.tensor_split( field_proj.reshape(s).transpose(-3,-2), 3, dim=-1) - breakpoint() + #breakpoint() qs, ks = self.ln_qk[0]( qs), self.ln_qk[1]( ks) if len(fields_lnormed) > 1 : From 705f03e1cb79a7746089bfdd07c50f341cf0a8e9 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Mon, 10 Jun 2024 18:29:21 +0200 Subject: [PATCH 38/66] prepare full config example for Epicure --- atmorep/core/atmorep_model.py | 41 ++++++++----------------------- atmorep/core/train.py | 33 +++++++++++++++++++++---- slurm_atmorep.sh | 45 +++++++++++++++++++++++++++++++++++ slurm_atmorep_evaluate.sh | 45 +++++++++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 36 deletions(-) create mode 100755 slurm_atmorep.sh create mode 100755 slurm_atmorep_evaluate.sh diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index dd5962a..6663ec1 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -313,6 +313,7 @@ def create( self, devices, load_pretrained=True) : device = self.devices[0] if len(field_info[1]) > 3 : assert field_info[1][3] < 4, 'Only single node model parallelism supported' + print(devices, field_info[1][3]) assert field_info[1][3] < len(devices), 'Per field device id larger than max devices' device = self.devices[ field_info[1][3] ] # set device @@ -377,26 +378,18 @@ def load_block( self, field_info, block_name, block ) : ################################################### def translate_weights(self, mloaded, mkeys, ukeys): + ''' + Function used for backward compatibility + ''' cf = self.cf #encoder: for layer in range(cf.encoder_num_layers) : - # qs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] - # ks = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] - # vs = [mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] - # mw = torch.cat( [*qs, *ks, *vs]) - #torch.Size([3, 16, 128, 2048]) - # att = temp.reshape([16, 3, 128, 2048]) - # att.shape - # (Pdb) 
(Pdb) torch.Size([16, 3, 128, 2048]) - # att1 = temp1.reshape([16, 3, 128, 2048]) - # mw.shape - # (Pdb) (Pdb) torch.Size([6144, 2048]) - # att_mw = mw.reshape([3, 16, 128, 2048]) + + #shape([16, 3, 128, 2048]) mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]]) - #breakpoint() - #print(qs[0][:3,:3]) mloaded[f'encoders.0.heads.{layer}.proj_heads.weight'] = mw + for head in range(cf.encoder_num_heads): del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] @@ -404,10 +397,6 @@ def translate_weights(self, mloaded, mkeys, ukeys): #cross attention if f'encoders.0.heads.{layer}.heads_other.0.proj_qs.weight' in ukeys: - # qs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] for head in range(cf.encoder_num_heads)] - # ks = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] for head in range(cf.encoder_num_heads)] - # vs = [mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] for head in range(cf.encoder_num_heads)] - # mw = torch.cat( [*qs, *ks, *vs]) mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]]) for i in range(cf.encoder_num_heads): @@ -422,18 +411,10 @@ def translate_weights(self, mloaded, mkeys, ukeys): #decoder for iblock in range(0, 19, 2) : - print(iblock) - # qs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] for i in range(8)] - # ks = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight'] for i in range(8)] - # vs = [mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight'] for i in range(8)] - # mw = torch.cat( [*qs, *ks, *vs]) mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads.{head}.proj_{k}.weight'] for head in range(8) for k in ["qs", "ks", "vs"]]) mloaded[f'decoders.0.blocks.{iblock}.proj_heads.weight'] = mw qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_qs.weight'] for head in range(8)] - # ks = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] for i in range(8)] - # vs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] for i in range(8)] - # mw = torch.cat( [*ks, *vs]) mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_{k}.weight'] for head in range(8) for k in ["ks", "vs"]]) mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_q.weight'] = torch.cat([*qs]) @@ -453,7 +434,6 @@ def translate_weights(self, mloaded, mkeys, ukeys): del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] - return mloaded ################################################### @@ -468,10 +448,9 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : model = AtmoRep( cf).create( devices, load_pretrained=False) mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) ) mkeys, ukeys = model.load_state_dict( mloaded, False ) - mloaded = model.translate_weights(mloaded, mkeys, ukeys) - mkeys, ukeys = model.load_state_dict( mloaded, False ) - - # breakpoint() + if (f'encoders.0.heads.0.proj_heads.weight') in mkeys: + mloaded = model.translate_weights(mloaded, mkeys, ukeys) + mkeys, ukeys = model.load_state_dict( mloaded, False ) if len(mkeys) > 0 : print( f'Loaded AtmoRep: ignoring 
{len(mkeys)} elements: {mkeys}') diff --git a/atmorep/core/train.py b/atmorep/core/train.py index e078dce..3f5d179 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -114,15 +114,38 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'temperature',[ 1, 512, [ ], 3 ], + # cf.fields = [ [ 'temperature',[ 1, 512, [ ], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 2, 4], [3, 27, 27], [0.25, 0.9, 0.2, 0.05], 'local' ] ] + + # cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + + cf.fields = [ [ 'velocity_u', [ 1, 2048, ['velocity_v', 'temperature'], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + [ 'velocity_v', [ 1, 2048, ['velocity_u', 'temperature'], 1 ], [ 96, 105, 114, 123, 137 ], - [12, 2, 4], [3, 27, 27], [0.25, 0.9, 0.2, 0.05], 'local' ] ] + [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + [ 'specific_humidity', [ 1, 2048, ['velocity_u', 'velocity_v', 'temperature'], 2 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + [ 'velocity_z', [ 1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 3 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'global' ], + [ 'temperature', [ 1, 1024, ['velocity_u', 'velocity_v', 'specific_humidity'], 3 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + ['total_precip', [1, 1536, ['velocity_u', 'velocity_v', 'velocity_z', 'specific_humidity'], 3], + [0], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05]] ] - cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + cf.fields_prediction = [['velocity_u', 0.225], ['velocity_v', 0.225], + ['specific_humidity', 0.15], ['velocity_z', 0.1], ['temperature', 0.2], + ['total_precip', 0.1] ] cf.fields_targets = [] - cf.years_train = list( range( 1979, 2018)) + cf.years_train = list( range( 2010, 2021)) cf.years_test = [2021] #[2018] cf.month = None cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]] @@ -219,7 +242,7 @@ def train() : # #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' + cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' # in steps x lat_degrees x lon_degrees cf.n_size = [36, 0.25*9*6, 0.25*9*12] diff --git a/slurm_atmorep.sh b/slurm_atmorep.sh new file mode 100755 index 0000000..c1cf9cf --- /dev/null +++ b/slurm_atmorep.sh @@ -0,0 +1,45 @@ +#!/bin/bash -x +#SBATCH --account=ehpc03 +#SBATCH --time=0-0:09:59 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=80 #####nodes * gpus/node * 20 +#SBATCH --gres=gpu:4 +#SBATCH --chdir=. 
+#SBATCH --qos=acc_ehpc +#SBATCH --output=logs/atmorep-%x.%j.out +#SBATCH --error=logs/atmorep-%x.%j.err + +# import modules +source pyenv/bin/activate + +export UCX_TLS="^cma" +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# so processes know who to talk to +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "MASTER_ADDR: $MASTER_ADDR" + +export NCCL_DEBUG=TRACE +echo "nccl_debug: $NCCL_DEBUG" + +# work-around for flipping links issue on JUWELS-BOOSTER +export NCCL_IB_TIMEOUT=250 +export UCX_RC_TIMEOUT=16s +export NCCL_IB_RETRY_CNT=50 + +echo "Starting job." +echo "Number of Nodes: $SLURM_JOB_NUM_NODES" +echo "Number of Tasks: $SLURM_NTASKS" +date + +CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_train_${SLURM_JOBID} +mkdir ${CONFIG_DIR} +cp ${SLURM_SUBMIT_DIR}/atmorep/core/train.py ${CONFIG_DIR} +echo "${CONFIG_DIR}/train.py" +srun --label --cpu-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/train.py > output/output_${SLURM_JOBID}.txt + +echo "Finished job." +date diff --git a/slurm_atmorep_evaluate.sh b/slurm_atmorep_evaluate.sh new file mode 100755 index 0000000..2994104 --- /dev/null +++ b/slurm_atmorep_evaluate.sh @@ -0,0 +1,45 @@ +#!/bin/bash -x +#SBATCH --account=ehpc03 +#SBATCH --time=0-0:09:59 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=80 #####nodes * gpus/node * 20 +#SBATCH --gres=gpu:4 +#SBATCH --chdir=. +#SBATCH --qos=acc_ehpc +#SBATCH --output=logs/atmorep-%x.%j.out +#SBATCH --error=logs/atmorep-%x.%j.err + +# import modules +source pyenv/bin/activate + +export UCX_TLS="^cma" +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# so processes know who to talk to +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "MASTER_ADDR: $MASTER_ADDR" + +export NCCL_DEBUG=TRACE +echo "nccl_debug: $NCCL_DEBUG" + +# work-around for flipping links issue on JUWELS-BOOSTER +export NCCL_IB_TIMEOUT=250 +export UCX_RC_TIMEOUT=16s +export NCCL_IB_RETRY_CNT=50 + +echo "Starting job." +echo "Number of Nodes: $SLURM_JOB_NUM_NODES" +echo "Number of Tasks: $SLURM_NTASKS" +date + +CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_eval_${SLURM_JOBID} +mkdir ${CONFIG_DIR} +cp ${SLURM_SUBMIT_DIR}/atmorep/core/evaluate.py ${CONFIG_DIR} +echo "${CONFIG_DIR}/train.py" +srun --label --cpu-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/evaluate.py > output/output_${SLURM_JOBID}.txt + +echo "Finished job." 
+date From f211001b764d589d72fda4e18890a4681998d69d Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Mon, 10 Jun 2024 18:43:53 +0200 Subject: [PATCH 39/66] comment out requirements for wheel installation --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 012c5b1..38e5bb4 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ description='AtmoRep', packages=find_packages(), # if packages are available in a native form fo the host system then these should be used - install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo'], + #install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo'], data_files=[('./output', []), ('./logs', []), ('./results',[])], ) From 4fcff359c504990c2b01af49597f94353473a023 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 12 Jun 2024 10:54:13 +0200 Subject: [PATCH 40/66] increase n workers evaluate --- atmorep/core/evaluate.py | 3 ++- atmorep/core/evaluator.py | 7 ++++--- atmorep/core/trainer.py | 15 +++----------- atmorep/datasets/multifield_data_sampler.py | 22 ++++++--------------- atmorep/utils/utils.py | 2 -- slurm_atmorep_evaluate.sh | 4 ++-- 6 files changed, 17 insertions(+), 36 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 13dbde4..3f385e5 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -19,7 +19,8 @@ if __name__ == '__main__': # models for individual fields - model_id = '4nvwbetz' # vorticity + model_id='m5vo5vqz' + #model_id = '4nvwbetz' # vorticity #model_id = 'oxpycr7w' # divergence # model_id = '1565pb1f' # specific_humidity #model_id = '3kdutwqb' # total precip diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 8d93a8c..f1fbc6e 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -91,8 +91,9 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : setup_wandb( cf.with_wandb, cf, par_rank, '', mode='offline') if 0 == cf.par_rank : print( 'Running Evaluate.evaluate with mode =', mode) - - cf.num_loader_workers = cf.loader_num_workers + + # if not hasattr( cf, 'num_loader_workers'): + cf.num_loader_workers = 12 #cf.loader_num_workers cf.rng_seed = None #backward compatibility @@ -139,7 +140,7 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : cf.BERT_strategy = 'global_forecast' cf.batch_size_test = 24 - cf.num_loader_workers = 1 + cf.num_loader_workers = 12 #1 cf.log_test_num_ranks = 1 if not hasattr(cf, 'file_path'): diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index ae06272..7278256 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -91,7 +91,6 @@ def __init__( self, cf, devices ) : ################################################### def create( self, load_embeds=True) : - net = AtmoRep( self.cf) self.model = AtmoRepData( net) @@ -100,13 +99,11 @@ def create( self, load_embeds=True) : # TODO: pass the properly to model / net self.model.net.encoder_to_decoder = self.encoder_to_decoder self.model.net.decoder_to_tail = self.decoder_to_tail - return self ################################################### @classmethod def load( Typename, cf, model_id, epoch, devices) : - trainer = Typename( cf, devices).create( load_embeds=False) trainer.model.net = trainer.model.net.load( 
model_id, devices, cf, epoch) @@ -116,7 +113,6 @@ def load( Typename, cf, model_id, epoch, devices) : str = 'Loaded model id = {}{}.'.format( model_id, f' at epoch = {epoch}' if epoch> -2 else '') print( str) - return trainer ################################################### @@ -365,7 +361,6 @@ def profile( self): ################################################### def validate( self, epoch, BERT_test_strategy = 'BERT'): - cf = self.cf BERT_strategy_train = cf.BERT_strategy cf.BERT_strategy = BERT_test_strategy @@ -383,7 +378,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory (sources, token_infos, targets, tmis_list) = batch_data[0] - + log_sources = ( [source.detach().clone().cpu() for source in sources ], [target.detach().clone().cpu() for target in targets ], tmis_list) @@ -391,11 +386,10 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): with torch.autocast(device_type='cuda',dtype=torch.float16,enabled=cf.with_mixed_precision): batch_data = self.prepare_batch( batch_data) preds, atts = self.model( batch_data) - loss = torch.tensor( 0.) ifield = 0 for pred, idx in zip( preds, self.fields_prediction_idx) : - + target = self.targets[idx] # hook for custom test loss self.test_loss( pred, target) @@ -405,7 +399,6 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): loss += cur_loss total_losses[ifield] += cur_loss ifield += 1 - total_loss += loss test_len += 1 @@ -413,10 +406,9 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): if cf.par_rank < cf.log_test_num_ranks : log_preds = [[p.detach().clone().cpu() for p in pred] for pred in preds] self.log_validate( epoch, it, log_sources, log_preds) - if cf.attention: self.log_attention( epoch, it, atts) - + # average over all nodes total_loss /= test_len * len(self.cf.fields_prediction) total_losses /= test_len @@ -440,7 +432,6 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): loss_dict[idx_name] = total_losses[i] print( 'validation loss for {} : {}'.format( field[0], total_losses[i] )) wandb.log( loss_dict) - batch_data = [] torch.cuda.empty_cache() diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 8d77248..c31a83c 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -123,7 +123,6 @@ def shuffle( self) : ################################################### def __iter__(self): - if self.with_shuffle : self.shuffle() @@ -135,7 +134,6 @@ def __iter__(self): iter_start, iter_end = self.worker_workset() for bidx in range( iter_start, iter_end) : - sources, token_infos = [[] for _ in self.fields], [[] for _ in self.fields] sources_infos, source_idxs = [], [] @@ -143,9 +141,7 @@ def __iter__(self): idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) data_tt_sfc = self.ds['data_sfc'].oindex[idxs_t] data_tt = self.ds['data'].oindex[idxs_t] - for sidx in range(self.batch_size) : - # i_bidx = self.idxs_perm_t[bidx] # idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) @@ -169,14 +165,12 @@ def __iter__(self): source_idxs += [ (idxs_t, lat_ran, lon_ran) ] # extract data - for ifield, field_info in enumerate(self.fields): - + for ifield, field_info in enumerate(self.fields): source_lvl, tok_info_lvl = [], [] tok_size = field_info[4] corr_type = 'global' if len(field_info) <= 6 else field_info[6] for ilevel, vl in enumerate(field_info[2]): - if vl == 0 
: #surface level field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) data_t = data_tt_sfc[ :, field_idx ] @@ -188,14 +182,12 @@ def __iter__(self): source_data, tok_info = [], [] # extract data, normalize and tokenize cdata = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) - + normalizer = self.normalizers[ifield][ilevel] if corr_type != 'global': - normalizer = np.take( np.take( normalizer, lat_ran, -2), lon_ran, -1) - + normalizer = np.take( np.take( normalizer, lat_ran, -2), lon_ran, -1) cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) - source_data = tokenize( torch.from_numpy( cdata), tok_size ) - + source_data = tokenize( torch.from_numpy( cdata), tok_size ) # token_infos uses center of the token: *last* datetime and center in space dates = self.ds['time'][ idxs_t ].astype(datetime) cdates = dates[tok_size[0]-1::tok_size[0]] @@ -208,16 +200,14 @@ def __iter__(self): for (year, day, hour) in dates]] source_lvl += [ source_data ] - tok_info_lvl += [ torch.tensor(tok_info, dtype=torch.float32).flatten( 1, -2)] - + tok_info_lvl += [ torch.tensor(tok_info, dtype=torch.float32).flatten( 1, -2)] sources[ifield] += [ torch.stack(source_lvl, 0) ] token_infos[ifield] += [ torch.stack(tok_info_lvl, 0) ] - # concatenate batches + # concatenate batches sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources] token_infos = [torch.stack(tis_field).transpose(1,0) for tis_field in token_infos] sources = self.pre_batch( sources, token_infos ) - # TODO: implement (only required when prediction target comes from different data stream) targets, target_info = None, None target_idxs = None diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index 9fc4dd9..f2f1271 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -98,12 +98,10 @@ def write_json( self, wandb) : f.write( json_str) def load_json( self, wandb_id) : - if '/' in wandb_id : # assumed to be full path instead of just id fname = wandb_id else : fname = Path( config.path_models, 'id{}/model_id{}.json'.format( wandb_id, wandb_id)) - try : with open(fname, 'r') as f : json_str = f.readlines() diff --git a/slurm_atmorep_evaluate.sh b/slurm_atmorep_evaluate.sh index 2994104..017e07a 100755 --- a/slurm_atmorep_evaluate.sh +++ b/slurm_atmorep_evaluate.sh @@ -1,6 +1,6 @@ #!/bin/bash -x #SBATCH --account=ehpc03 -#SBATCH --time=0-0:09:59 +#SBATCH --time=0-02:30:00 #SBATCH --nodes=1 #SBATCH --cpus-per-task=80 #####nodes * gpus/node * 20 #SBATCH --gres=gpu:4 @@ -38,7 +38,7 @@ date CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_eval_${SLURM_JOBID} mkdir ${CONFIG_DIR} cp ${SLURM_SUBMIT_DIR}/atmorep/core/evaluate.py ${CONFIG_DIR} -echo "${CONFIG_DIR}/train.py" +echo "${CONFIG_DIR}/evaluate.py" srun --label --cpu-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/evaluate.py > output/output_${SLURM_JOBID}.txt echo "Finished job." 
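The translate_weights() conversion introduced in the patches above relies on a fixed memory layout: the legacy per-head proj_qs/proj_ks/proj_vs matrices are concatenated head-major along the output dimension, so that the fused proj_heads output can later be reshaped to [..., num_heads, 3*dh] and split back into q, k, v. Below is a minimal, self-contained sketch of that layout with toy sizes and hypothetical names (not the repository's modules):

import torch

num_heads, dim_embed, dim_head = 4, 32, 8        # toy sizes, not AtmoRep's
heads = [ { k : torch.nn.Linear( dim_embed, dim_head, bias=False)
            for k in ('qs', 'ks', 'vs') } for _ in range(num_heads) ]

# head-major concatenation: h0.q, h0.k, h0.v, h1.q, ... (ordering used in translate_weights)
w_fused = torch.cat( [ heads[h][k].weight for h in range(num_heads)
                                           for k in ('qs', 'ks', 'vs') ])
proj_heads = torch.nn.Linear( dim_embed, 3 * num_heads * dim_head, bias=False)
with torch.no_grad() :
  proj_heads.weight.copy_( w_fused)

x = torch.randn( 2, 16, dim_embed)               # [batch, tokens, embed]
s = [ *x.shape[:-1], num_heads, -1 ]
qs, ks, vs = torch.tensor_split( proj_heads( x).reshape( s).transpose( 2, 1), 3, dim=-1)

# per-head q from the fused projection matches the original per-head projection
q_ref = torch.stack( [ heads[h]['qs']( x) for h in range(num_heads) ], dim=1)
assert torch.allclose( qs, q_ref, atol=1e-6)

If the assertion holds, loading an older per-head checkpoint through translate_weights should reproduce the same q/k/v tensors in the fused modules, which is the backward compatibility the conversion is meant to provide.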
From 88ee5c0446c610e01da11ff57ea56763e11484ba Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 12 Jun 2024 17:04:11 +0200 Subject: [PATCH 41/66] fix path in evaluate --- atmorep/core/evaluator.py | 13 ++++++++----- atmorep/datasets/multifield_data_sampler.py | 13 +++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index f1fbc6e..5c0d852 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -55,13 +55,17 @@ def parse_args( cf, args) : @staticmethod def run( cf, model_id, model_epoch, devices) : - cf.batch_size = cf.batch_size_max - cf.batch_size_validation = cf.batch_size_max - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' + if not hasattr(cf, 'batch_size'): + cf.batch_size = cf.batch_size_max + if not hasattr(cf, 'batch_size_validation'): + cf.batch_size_validation = cf.batch_size_max + + cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' cf.with_mixed_precision = True # set/over-write options as desired evaluator = Evaluator.load( cf, model_id, model_epoch, devices) + if 0 == cf.par_rank : cf.print() cf.write_json( wandb) @@ -117,7 +121,6 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : if not hasattr(cf, 'num_samples_validate'): cf.num_samples_validate = 128 #1472 Evaluator.parse_args( cf, args) - Evaluator.run( cf, model_id, model_epoch, devices) ############################################## @@ -144,7 +147,7 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : cf.log_test_num_ranks = 1 if not hasattr(cf, 'file_path'): - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y1979_2021_res025_chunk8.zarr' + cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' if not hasattr(cf, 'batch_size'): cf.batch_size = 14 diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index c31a83c..23e4059 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -20,9 +20,7 @@ import pandas as pd from datetime import datetime import time -import logging - -import code +import os # from atmorep.datasets.normalizer_global import NormalizerGlobal # from atmorep.datasets.normalizer_local import NormalizerLocal @@ -48,9 +46,11 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, self.num_samples = num_samples self.with_source_idxs = with_source_idxs self.with_shuffle = with_shuffle - self.pre_batch = pre_batch - self.ds = zarr.open( file_path) + + assert os.path.exists(file_path), f"File path {file_path} does not exist" + self.ds = zarr.open( file_path) + self.ds_global = self.ds.attrs['is_global'] self.lats = np.array( self.ds['lats']) @@ -75,7 +75,6 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, # lon: no change for periodic case if self.ds_global < 1.: self.range_lon += np.array([n_size[2]/2., -n_size[2]/2.]) - # data normalizers self.normalizers = [] for ifield, field_info in enumerate(fields) : @@ -90,7 +89,6 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, vl_idx = self.ds.attrs['levels'].index(vl) field_idx = self.ds.attrs['fields'].index( field_info[0]) self.normalizers[ifield] += [self.ds[f'normalization/{nf_name}'].oindex[ :, :, field_idx, vl_idx]] - # extract indices for selected years self.times = pd.DatetimeIndex( self.ds['time']) idxs_years = self.times.year == years[0] @@ 
-243,7 +241,6 @@ def set_data( self, times_pos, batch_size = None) : ################################################### def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : ''' generate patch/token positions for global grid ''' - print("inside set_global") token_overlap = np.array( token_overlap).astype(np.int64) # assumed that sanity checking that field data is consistent has been done From e286c7e7aa3236d1cef04963626aa0913ffabef0 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 3 Jul 2024 15:57:00 +0200 Subject: [PATCH 42/66] new wip weighted area --- atmorep/core/atmorep_model.py | 32 ++--- atmorep/core/evaluate.py | 50 ++++++-- atmorep/core/evaluator.py | 9 +- atmorep/core/train.py | 116 +++++++++++-------- atmorep/core/trainer.py | 101 +++++++++++----- atmorep/datasets/data_writer.py | 2 +- atmorep/datasets/multifield_data_sampler.py | 21 +++- atmorep/datasets/normalizer.py | 1 + atmorep/training/bert.py | 1 - atmorep/transformer/transformer_attention.py | 1 - atmorep/utils/utils.py | 19 +++ slurm_atmorep.sh | 11 +- slurm_atmorep_evaluate.sh | 8 +- 13 files changed, 254 insertions(+), 118 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 6663ec1..79834d4 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -73,7 +73,7 @@ def set_global( self, mode : NetMode, times, batch_size = -1, num_loader_workers cf = self.net.cf if batch_size < 0 : batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test - print(times) + print("dates:", times) dataset = self.dataset_train if mode == NetMode.train else self.dataset_test dataset.set_global( times, batch_size, cf.token_overlap) @@ -137,12 +137,12 @@ def normalizer( self, field, vl_idx, lats_idx, lons_idx ) : def mode( self, mode : NetMode) : if mode == NetMode.train : - self.data_loader_iter = iter(self.data_loader_train) - #self.data_loader_iter = iter(self.dataset_train) + #self.data_loader_iter = iter(self.data_loader_train) + self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : - self.data_loader_iter = iter(self.data_loader_test) - #self.data_loader_iter = iter(self.dataset_test) + #self.data_loader_iter = iter(self.data_loader_test) + self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -189,7 +189,7 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, cf.batch_size, pre_batch, cf.n_size, cf.num_samples_per_epoch, - with_shuffle = (cf.BERT_strategy != 'global_forecast') ) + with_shuffle = (cf.BERT_strategy != 'global_forecast'), with_source_idxs = True ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) @@ -405,10 +405,11 @@ def translate_weights(self, mloaded, mkeys, ukeys): del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] else: - mw = torch.tensor(np.zeros([0,2048])) - + dim_mw = self.encoders[0].heads[0].proj_heads_other[0].weight.shape + mw = torch.tensor(np.zeros(dim_mw)) + mloaded[f'encoders.0.heads.{layer}.proj_heads_other.0.weight'] = mw - + #decoder for iblock in range(0, 19, 2) : mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads.{head}.proj_{k}.weight'] for head in range(8) for k in ["qs", "ks", "vs"]]) @@ -421,10 +422,11 @@ def translate_weights(self, mloaded, mkeys, ukeys): 
mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_kv.weight'] = mw #self.num_samples_validate - mloaded[f'decoders.0.blocks.{iblock}.ln_q.weight'] = torch.tensor(np.ones([128])) - mloaded[f'decoders.0.blocks.{iblock}.ln_k.weight'] = torch.tensor(np.ones([128])) - mloaded[f'decoders.0.blocks.{iblock}.ln_q.bias'] = torch.tensor(np.ones([128])) - mloaded[f'decoders.0.blocks.{iblock}.ln_k.bias'] = torch.tensor(np.ones([128])) + decoder_dim = self.decoders[0].blocks[iblock].ln_q.weight.shape #128 + mloaded[f'decoders.0.blocks.{iblock}.ln_q.weight'] = torch.tensor(np.ones(decoder_dim)) + mloaded[f'decoders.0.blocks.{iblock}.ln_k.weight'] = torch.tensor(np.ones(decoder_dim)) + mloaded[f'decoders.0.blocks.{iblock}.ln_q.bias'] = torch.tensor(np.ones(decoder_dim)) + mloaded[f'decoders.0.blocks.{iblock}.ln_k.bias'] = torch.tensor(np.ones(decoder_dim)) for i in range(8): del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] @@ -433,7 +435,7 @@ def translate_weights(self, mloaded, mkeys, ukeys): del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight'] del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] - + return mloaded ################################################### @@ -458,7 +460,7 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : # TODO: remove, only for backward if model.embeds_token_info[0].weight.abs().max() == 0. : model.embeds_token_info = torch.nn.ModuleList() - + return model ################################################### diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 3f385e5..7eccb21 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -19,10 +19,38 @@ if __name__ == '__main__': # models for individual fields - model_id='m5vo5vqz' + + #2 nodes + #model_id='1b43bynq' + # model_id='p20z3ilu' + # model_id='10y42b1u' + #model_id='99tb5lcy' + # model_id='sn6h8wvq' + # model_id='085saknn' + + #1node train continue + #model_id='h7orvjna' + #model_id='ocpn87si' + + # 1node + #temperature + #model_id='66zlffty' + #model_id='fmzy4mxr' + model_id='ezi4shmb' + + #velocity_u + #model_id='hg8cy3c4' + #model_id='av0rp1mj' + #model_id='fc5o31h2' + + #specific_humidity + #model_id='w965qy0o' + #model_id='gpksrtrl' + #model_id='c6am1m3j' + #model_id = '4nvwbetz' # vorticity #model_id = 'oxpycr7w' # divergence - # model_id = '1565pb1f' # specific_humidity + #model_id = '1565pb1f' # specific_humidity #model_id = '3kdutwqb' # total precip #model_id = 'dys79lgw' # velocity_u #model_id = '22j6gysw' # velocity_v @@ -45,7 +73,7 @@ # BERT masked token model #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} - #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} + mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} # BERT forecast mode #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } @@ -54,11 +82,17 @@ #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'idx_time_mask': [5,6,7], 'attention' : False } # BERT forecast with patching to obtain global forecast - mode, options = 'global_forecast', { 'fields[0][2]' : [114], #[123, 137], #[105, 137], - 'dates' : [[2021, 1, 10, 18]], #[[2021, 2, 10, 
12]], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 2, - 'attention' : False } + # mode, options = 'global_forecast', { #'fields[0][2]' : [114], #[123, 137], #[105, 137], + # #'dates' : [[2021, 1, 10, 18]], + # 'dates' : [ [2021, 1, 10, 18] + # # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], + # # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], + # # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], + # # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] + # ], + # 'token_overlap' : [0, 0], + # 'forecast_num_tokens' : 2, + # 'attention' : False } now = time.time() Evaluator.evaluate( mode, model_id, options) print("time", time.time() - now) diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 5c0d852..1b1d0a5 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -81,8 +81,10 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : with_ddp = False num_accs_per_task = 1 else : - num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) - devices = init_torch( num_accs_per_task) + num_accs_per_task = 1 #int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) + #devices = init_torch( num_accs_per_task) + devices = ['cuda'] + par_rank, par_size = setup_ddp( with_ddp) cf = Config().load_json( model_id) @@ -90,6 +92,7 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : cf.with_ddp = with_ddp cf.par_rank = par_rank cf.par_size = par_size + cf.losses = cf.losses + ['weighted_mse'] # overwrite old config cf.attention = False setup_wandb( cf.with_wandb, cf, par_rank, '', mode='offline') @@ -150,7 +153,7 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' if not hasattr(cf, 'batch_size'): - cf.batch_size = 14 + cf.batch_size = 196 #14 if not hasattr(cf, 'batch_size_validation'): cf.batch_size_validation = 1 #64 if not hasattr(cf, 'batch_size_delta'): diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 3f5d179..1fa8283 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -34,6 +34,7 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) + #device = ['cuda'] with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -44,20 +45,25 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : cf.par_size = par_size cf.optimizer_zero = False cf.attention = False + + cf.batch_size = 96 #16 #4 # 32 + cf.lr_max = 0.00005*3 + cf.num_samples_per_epoch = 4096*12 + cf.num_samples_validate = 128*12 # name has changed but ensure backward compatibility - if hasattr( cf, 'loader_num_workers') : - cf.num_loader_workers = cf.loader_num_workers - if not hasattr( cf, 'n_size'): - cf.n_size = [36, 0.25*9*6, 0.25*9*12] - if not hasattr(cf, 'num_samples_per_epoch'): - cf.num_samples_per_epoch = 1024 - if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 128 - if not hasattr(cf, 'with_mixed_precision'): - cf.with_mixed_precision = True + # if hasattr( cf, 'loader_num_workers') : + # cf.num_loader_workers = cf.loader_num_workers + # if not hasattr( cf, 'n_size'): + # cf.n_size = [36, 
0.25*9*6, 0.25*9*12] + # if not hasattr(cf, 'num_samples_per_epoch'): + # cf.num_samples_per_epoch = 1024 + # if not hasattr(cf, 'num_samples_validate'): + # cf.num_samples_validate = 128 + # if not hasattr(cf, 'with_mixed_precision'): + # cf.with_mixed_precision = True - cf.years_train = [2021] # list( range( 1980, 2018)) - cf.years_test = [2021] #[2018] + # cf.years_train = [2021] # list( range( 1980, 2018)) + # cf.years_test = [2021] #[2018] # any parameter in cf can be overwritten when training is continued, e.g. we can increase the # masking rate @@ -84,9 +90,9 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : #################################################################################################### def train() : - num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) + num_accs_per_task = 1 #int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) - # device = ['cuda'] + #device = ['cuda'] with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -114,37 +120,46 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] + cf.fields = [ [ 'specific_humidity', [ 1, 1024, [ ], 0 ], + [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local'] ] + [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] + # cf.fields = [ [ 'temperature',[ 1, 512, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 2, 4], [3, 27, 27], [0.25, 0.9, 0.2, 0.05], 'local' ] ] + # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.2, 0.05] ] ] - # cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + # cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] - cf.fields = [ [ 'velocity_u', [ 1, 2048, ['velocity_v', 'temperature'], 0 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - [ 'velocity_v', [ 1, 2048, ['velocity_u', 'temperature'], 1 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - [ 'specific_humidity', [ 1, 2048, ['velocity_u', 'velocity_v', 'temperature'], 2 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - [ 'velocity_z', [ 1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 3 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'global' ], - [ 'temperature', [ 1, 1024, ['velocity_u', 'velocity_v', 'specific_humidity'], 3 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - ['total_precip', [1, 1536, ['velocity_u', 'velocity_v', 'velocity_z', 'specific_humidity'], 3], - [0], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05]] ] + cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] + + # cf.fields = [ [ 'velocity_u', [ 1, 2048, ['velocity_v', 'temperature'], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + # [ 'velocity_v', [ 1, 2048, ['velocity_u', 'temperature'], 1 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + # [ 'specific_humidity', [ 1, 2048, ['velocity_u', 'velocity_v', 'temperature'], 2 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + # [ 'velocity_z', [ 1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 3 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'global' ], + # [ 'temperature', [ 1, 1024, ['velocity_u', 'velocity_v', 'specific_humidity'], 3 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], + # ['total_precip', [1, 1536, ['velocity_u', 'velocity_v', 'velocity_z', 'specific_humidity'], 3], + # [0], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05]] ] - cf.fields_prediction = [['velocity_u', 0.225], ['velocity_v', 0.225], - ['specific_humidity', 0.15], ['velocity_z', 0.1], ['temperature', 0.2], - ['total_precip', 0.1] ] + # cf.fields_prediction = [['velocity_u', 0.225], ['velocity_v', 0.225], + # ['specific_humidity', 0.15], ['velocity_z', 0.1], ['temperature', 0.2], + # ['total_precip', 0.1] ] cf.fields_targets = [] - + cf.years_train = list( range( 2010, 2021)) cf.years_test = [2021] #[2018] cf.month = None @@ -162,7 +177,7 @@ def train() : cf.torch_seed = torch.initial_seed() # training params cf.batch_size_validation = 1 #64 - cf.batch_size = 8 #16 #4 # 32 + cf.batch_size = 96 #16 #4 # 32 cf.num_epochs = 128 # additional infos @@ -194,26 +209,28 @@ def train() : cf.net_tail_num_layers = 0 # loss # supported: see Trainer for supported losses - # cf.losses = ['mse', 'stats'] - cf.losses = ['mse_ensemble', 'stats'] + #cf.losses = ['mse', 'stats'] + #cf.losses = ['mse_ensemble'] #, 'stats'] + cf.losses = ['weighted_mse'] # cf.losses = ['mse'] # cf.losses = ['stats'] # cf.losses = ['crps'] # training cf.optimizer_zero = False cf.lr_start = 5. 
* 10e-7 - cf.lr_max = 0.00005 + cf.lr_max = 0.00005*3 cf.lr_min = 0.00004 cf.weight_decay = 0.05 cf.lr_decay_rate = 1.025 cf.lr_start_epochs = 3 - cf.lat_sampling_weighted = True + cf.lat_sampling_weighted = False #True # BERT # strategies: 'BERT', 'forecast', 'temporal_interpolation', 'identity' - cf.BERT_strategy = 'BERT' + cf.BERT_strategy = 'forecast' #'BERT' cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution + cf.forecast_num_tokens = 2 #only when training in forecast mode # debug / output cf.log_test_num_ranks = 0 cf.save_grads = False @@ -228,10 +245,10 @@ def train() : setup_wandb( cf.with_wandb, cf, par_rank, 'train', mode='offline') cf.with_mixed_precision = True - cf.num_samples_per_epoch = 4096 - cf.num_samples_validate = 128 + cf.num_samples_per_epoch = 4096*12 + cf.num_samples_validate = 128*12 cf.num_loader_workers = 6 - + #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' # # in steps x lat_degrees x lon_degrees #cf.n_size = [36, 1*9*6, 1.*9*12] @@ -257,9 +274,10 @@ def train() : #################################################################################################### if __name__ == '__main__': - train() + train() -# wandb_id, epoch = '1jh2qvrx', 392 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 +# wandb_id, epoch = '66zlffty', 26 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 +# #wandb_id, epoch = 'fc5o31h2', 27 # epoch_continue = epoch # Trainer = Trainer_BERT diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 7278256..042bb62 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -42,14 +42,7 @@ import atmorep.utils.token_infos_transformations as token_infos_transformations -import atmorep.utils.utils as utils -from atmorep.utils.utils import shape_to_str -from atmorep.utils.utils import relMSELoss -from atmorep.utils.utils import Gaussian -from atmorep.utils.utils import CRPS -from atmorep.utils.utils import NetMode -from atmorep.utils.utils import sgn_exp -from atmorep.utils.utils import tokenize, detokenize +from atmorep.utils.utils import Gaussian, CRPS, get_weights, weighted_mse, NetMode, tokenize, detokenize from atmorep.datasets.data_writer import write_forecast, write_BERT, write_attention from atmorep.datasets.normalizer import denormalize @@ -236,11 +229,11 @@ def train( self, epoch): for batch_idx in range( model.len( NetMode.train)) : batch_data = self.model.next() - + (_, _ , _, tmis_list) = batch_data[0] with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=cf.with_mixed_precision): batch_data = self.prepare_batch( batch_data) preds, _ = self.model_ddp( batch_data) - loss, mse_loss, losses = self.loss( preds, batch_idx) + loss, mse_loss, losses = self.loss( preds, batch_idx, tmis_list) self.grad_scaler.scale(loss).backward() self.grad_scaler.step(self.optimizer) @@ -255,7 +248,7 @@ def train( self, epoch): # logging - if int((batch_idx * cf.batch_size) / 4) > ctr : + if int((batch_idx * cf.batch_size) / 8) > ctr : # wandb logging if cf.with_wandb and (0 == cf.par_rank) : @@ -361,6 +354,7 @@ def profile( self): ################################################### def validate( self, epoch, BERT_test_strategy = 'BERT'): + print('inside_validate') cf = self.cf BERT_strategy_train = cf.BERT_strategy cf.BERT_strategy = BERT_test_strategy @@ -377,8 +371,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): batch_data = 
self.model.next() if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory - (sources, token_infos, targets, tmis_list) = batch_data[0] - + (sources, _ , targets, tmis_list) = batch_data[0] log_sources = ( [source.detach().clone().cpu() for source in sources ], [target.detach().clone().cpu() for target in targets ], tmis_list) @@ -395,23 +388,27 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): self.test_loss( pred, target) # base line loss cur_loss = self.MSELoss( pred[0], target = target ).cpu().item() - + + print(cur_loss, flush = True) loss += cur_loss total_losses[ifield] += cur_loss ifield += 1 + print(f"total_loss {total_loss}", flush = True) total_loss += loss test_len += 1 - + # store detailed results on current test set for book keeping if cf.par_rank < cf.log_test_num_ranks : log_preds = [[p.detach().clone().cpu() for p in pred] for pred in preds] self.log_validate( epoch, it, log_sources, log_preds) if cf.attention: self.log_attention( epoch, it, atts) - + + print(f"FINAL total_loss {total_loss}", flush = True) # average over all nodes total_loss /= test_len * len(self.cf.fields_prediction) total_losses /= test_len + print(f"FINAL total_loss after ratio {total_loss}", flush = True) if cf.with_ddp : total_loss_cuda = total_loss.cuda() total_losses_cuda = total_losses.cuda() @@ -419,7 +416,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): dist.all_reduce( total_losses_cuda, op=torch.distributed.ReduceOp.AVG ) total_loss = total_loss_cuda.cpu() total_losses = total_losses_cuda.cpu() - + print(f"FINAL total_loss after DDP {total_loss}", flush = True) if 0 == cf.par_rank : print( 'validation loss for strategy={} at epoch {} : {}'.format( BERT_test_strategy, epoch, total_loss), @@ -446,7 +443,7 @@ def test_loss( self, pred, target) : pass ################################################### - def loss( self, preds, batch_idx = 0) : + def loss( self, preds, batch_idx = 0, tmidx_list = None) : # TODO: move implementations to individual files @@ -471,6 +468,50 @@ def loss( self, preds, batch_idx = 0) : loss_en += self.MSELoss( en, target = target) losses['mse_ensemble'].append( loss_en / pred[2].shape[1]) + if 'weighted_mse' in self.cf.losses : + loss_en = torch.tensor( 0., device=target.device) + field_info = cf.fields[idx] + nlvls = len(field_info[2]) + num_tokens = field_info[3] + token_size = field_info[4] + + #idx_loc = [tokens_masked_idx_list[0][vlvl][batch_idx] - np.prod(num_tokens) * batch_idx for vlvl in range(nlvls)] + #targets_temp = self.get_masked_data(field_info, target, tokens_masked_idx_list[0]) + #preds_temp = self.get_masked_data(field_info, pred[0], tokens_masked_idx_list[0]) + lats_mskd = [] + weights = [] + for vidx in range(nlvls): + for bidx in range(cf.batch_size): + + lats_idx = self.sources_idxs[bidx][1] + lons_idx = self.sources_idxs[bidx][2] + grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + grid = torch.from_numpy( np.array( np.broadcast_to( grid, + shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) + + grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) + + idx_base = tmidx_list[idx][vidx][bidx] + idx_loc = idx_base - np.prod(num_tokens) * bidx + #save only useful info for each bidx. shape e.g. 
[n_bidx, lat_token_size*lat_num_tokens] + + lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + lats_mskd.append(lats_mskd_b) + weights.append([get_weights(la) for la in lats_mskd_b]) + lats_mskd = torch.Tensor([l for l in lats_mskd]) + weights = torch.Tensor(np.array([w for batch in weights for w in batch])) + breakpoint() + weights = weights.view(*weights.shape, 1, 1).repeat(1, 1, token_size[0], token_size[2]).swapaxes(1, 2) + + weights = weights.reshape([weights.shape[0], -1]).to(target.get_device()) + + #target_temp = detokenize(target.reshape([nlvls, -1] + tok_size + ntokens).cpu().detach().numpy()) + #preds_temp = detokenize(preds[0].reshape([nlvls, ]).cpu().detach().numpy()) + for en in torch.transpose( pred[2], 1, 0) : + loss_en += weighted_mse( en, target, weights) + + losses['weighted_mse'].append( loss_en / pred[2].shape[1]) + # Generalized cross entroy loss for continuous distributions if 'stats' in self.cf.losses : stats_loss = Gaussian( target, pred[0], pred[1]) @@ -489,16 +530,11 @@ def loss( self, preds, batch_idx = 0) : if 'crps' in self.cf.losses : crps_loss = torch.mean( CRPS( target, pred[0], pred[1])) losses['crps'].append( crps_loss) - - #cosine weighted RMSE loss - if 'cos_weighted_mse' in self.cf.losses : - cw_mse_loss = torch.mean( CosW_MSELoss( pred[0], target = target)) - losses['cos_weighted_mse'].append( cw_mse_loss) loss = torch.tensor( 0., device=self.device_out) tot_weight = torch.tensor( 0., device=self.device_out) for key in losses : - #print( 'LOSS : {} :: {}'.format( key, losses[key])) + print( 'LOSS : {} :: {}'.format( key, losses[key])) for ifield, val in enumerate(losses[key]) : loss += self.loss_weights[ifield] * val.to( self.device_out) tot_weight += self.loss_weights[ifield] @@ -643,7 +679,8 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : target = detokenize( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, -1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) - coords_b, targ_coords_b = [], [] + coords_b = [] + for bidx in range(batch_size): dates = self.sources_info[bidx][0] lats = self.sources_info[bidx][1] @@ -657,7 +694,7 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : normalizer, year_base = self.model.normalizer( fidx, vidx, lats_idx, lons_idx) source[bidx,vidx] = denormalize( source[bidx,vidx], normalizer, dates, year_base) target[bidx,vidx] = denormalize( target[bidx,vidx], normalizer, dates_t, year_base) - + coords_b += [[dates, 90.-lats, lons, dates_t]] # append @@ -688,7 +725,7 @@ def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : normalizer, year_base = self.model.normalizer( self.fields_prediction_idx[fidx], vidx, lats_idx, lons_idx) pred[bidx,vidx] = denormalize( pred[bidx,vidx], normalizer, dates_t, year_base) ensemble[bidx,:,vidx] = denormalize(ensemble[bidx,:,vidx], normalizer, dates_t, year_base) - + # append preds_out.append( [fn[0], pred]) ensembles_out.append( [fn[0], ensemble]) @@ -710,6 +747,7 @@ def split_data(self, data, idx_list, token_size) : return [torch.split( data_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] def get_masked_data(self, field_info, data, idx_list, ensemble = False): + cf = self.cf batch_size = len(self.sources_info) num_levels = len(field_info[2]) @@ -747,7 +785,7 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : num_tokens = field_info[3] token_size = field_info[4] sources_b = detokenize( 
sources[fidx].numpy()) - + if is_predicted : targets_b = self.get_masked_data(field_info, targets[fidx], tokens_masked_idx_list[fidx]) preds_mu_b = self.get_masked_data(field_info, log_preds[fidx][0], tokens_masked_idx_list[fidx]) @@ -785,7 +823,7 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) grid_lons_toked = tokenize( grid[1], token_size).flatten( 0, 2) - + idx_loc = idx - np.prod(num_tokens) * bidx #save only useful info for each bidx. shape e.g. [n_bidx, lat_token_size*lat_num_tokens] lats_mskd = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) @@ -803,8 +841,9 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : if len(normalizer.shape) > 2: #local normalization lats_mskd_idx = np.where(np.isin(lats,la))[0] lons_mskd_idx = np.where(np.isin(lons,lo))[0] - normalizer_ii = normalizer[:, :, lats_mskd_idx, lons_mskd_idx] - + #normalizer_ii = normalizer[:, :, lats_mskd_idx, lons_mskd_idx] problems in python 3.9 + normalizer_ii = normalizer[:, :, lats_mskd_idx[0]:lats_mskd_idx[-1]+1, lons_mskd_idx[0]:lons_mskd_idx[-1]+1] + targets_b[vidx][bidx][ii] = denormalize(t, normalizer_ii, da, year_base) preds_mu_b[vidx][bidx][ii] = denormalize(p, normalizer_ii, da, year_base) preds_ens_b[vidx][bidx][ii] = denormalize(e, normalizer_ii, da, year_base) diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index 941bfa0..011f566 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -25,7 +25,7 @@ def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): ds_batch_item = ds_field.create_group( f'{name}={name_idx:05d}' ) ds_batch_item.create_dataset( 'data', data=data) ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=coords[0].astype(np.datetime64)) + ds_batch_item.create_dataset( 'datetime', data=coords[0].astype('datetime64[ns]')) ds_batch_item.create_dataset( 'lat', data=np.array(coords[1]).astype(np.float32)) ds_batch_item.create_dataset( 'lon', data=np.array(coords[2]).astype(np.float32)) return ds_batch_item diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 23e4059..2181843 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -206,9 +206,26 @@ def __iter__(self): sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources] token_infos = [torch.stack(tis_field).transpose(1,0) for tis_field in token_infos] sources = self.pre_batch( sources, token_infos ) + + # tmidx_list = sources[-1] + # breakpoint() + # for ifield, field_info in enumerate(self.fields): + # for ilevel, vl in enumerate(field_info[2]): + # idx_base = tmidx_list[ifield][ilevel][bidx] + # idx_loc = idx_base - np.prod(num_tokens) * bidx + #save only useful info for each bidx. shape e.g. [n_bidx, lat_token_size*lat_num_tokens] + # lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + #lons_mskd_b = np.array([np.unique(t) for t in grid_lons_toked[ idx_loc ].numpy()]) + + #weights_idx_list += [get_weights(la, lo) for la, lo in zip(lats_mskd_b, lons_mskd_b)] + # 1. retrieve token_masked_idx + # 2. get_weights_idx + # 3. 
propagate get_weights_idx + # TODO: implement (only required when prediction target comes from different data stream) targets, target_info = None, None target_idxs = None + yield ( sources, targets, (source_idxs, sources_infos), (target_idxs, target_info)) ################################################### @@ -283,8 +300,8 @@ def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : lon += side_len[1].item() - overlap[1].item() # adjust batch size if necessary so that the evaluations split up across batches of equal size - batch_size = num_tiles_lon - + batch_size = len(times_pos) #num_tiles_lon + print(batch_size) print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) self.set_data( times_pos, batch_size) diff --git a/atmorep/datasets/normalizer.py b/atmorep/datasets/normalizer.py index 0487af2..906989e 100644 --- a/atmorep/datasets/normalizer.py +++ b/atmorep/datasets/normalizer.py @@ -36,6 +36,7 @@ def normalize( data, norm, dates, year_base = 1979) : ###################################################### def normalize_local( data, mean, var) : + #breakpoint() data = (data - mean) / var return data diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index a8c5bc3..a4d1ea1 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -94,7 +94,6 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : batch_dim = source.shape[0] num_tokens = source.shape[1] - # masking_ratios = rng.random( batch_dim) * BERT_frac # number of tokens masked per batch entry nums_masked = np.ceil( num_tokens * masking_ratios).astype(np.int64) diff --git a/atmorep/transformer/transformer_attention.py b/atmorep/transformer/transformer_attention.py index 4356394..e312ea4 100644 --- a/atmorep/transformer/transformer_attention.py +++ b/atmorep/transformer/transformer_attention.py @@ -43,7 +43,6 @@ def __init__(self, dim_embed, num_heads, dropout_rate=0., with_qk_lnorm=True, wi assert 0 == dim_embed % num_heads self.dim_head_proj = int(dim_embed / num_heads) - self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) self.proj_heads = torch.nn.Linear( dim_embed, num_heads*3*self.dim_head_proj, bias = False) self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index f2f1271..3ae6280 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -352,6 +352,8 @@ def erf( x, mu=0., std_dev=1.) : val = c1 * ( 1./c2 - std_dev * torch.special.erf( (mu - x) / (c3 * std_dev) ) ) return val +######################################## + def CRPS( y, mu, std_dev) : # see Eq. A2 in S. Rasp and S. Lerch. Neural networks for postprocessing ensemble weather forecasts. Monthly Weather Review, 146(11):3885 – 3900, 2018. c1 = np.sqrt(1./np.pi) @@ -359,3 +361,20 @@ def CRPS( y, mu, std_dev) : t2 = 2. 
* Gaussian( (y-mu) / std_dev) val = std_dev * ( (y-mu)/std_dev * t1 + t2 - c1 ) return val + + +######################################## + +def get_weights(lats_idx, lat_min = -90., lat_max = 90., reso = 0.25): + lat_range = lat_max - lat_min + bins = lat_range/reso+1 + + theta_weight = np.array([np.cos(w) for w in np.arange( lat_max * np.pi/lat_range , lat_min * np.pi/lat_range, -np.pi/bins)], dtype = np.float32) + + return theta_weight[lats_idx] + +######################################## + +def weighted_mse(x, target, weights): + return torch.sum(weights * (x - target) **2 )/torch.sum(weights) + diff --git a/slurm_atmorep.sh b/slurm_atmorep.sh index c1cf9cf..ddb82c3 100755 --- a/slurm_atmorep.sh +++ b/slurm_atmorep.sh @@ -1,8 +1,9 @@ #!/bin/bash -x #SBATCH --account=ehpc03 -#SBATCH --time=0-0:09:59 +#SBATCH --time=0-71:59:59 #SBATCH --nodes=1 -#SBATCH --cpus-per-task=80 #####nodes * gpus/node * 20 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=20 #SBATCH --gres=gpu:4 #SBATCH --chdir=. #SBATCH --qos=acc_ehpc @@ -16,7 +17,7 @@ export UCX_TLS="^cma" export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=0 #,1,2,3 # so processes know who to talk to export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" @@ -35,11 +36,13 @@ echo "Number of Nodes: $SLURM_JOB_NUM_NODES" echo "Number of Tasks: $SLURM_NTASKS" date +export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK} + CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_train_${SLURM_JOBID} mkdir ${CONFIG_DIR} cp ${SLURM_SUBMIT_DIR}/atmorep/core/train.py ${CONFIG_DIR} echo "${CONFIG_DIR}/train.py" -srun --label --cpu-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/train.py > output/output_${SLURM_JOBID}.txt +srun --label --cpu-bind=v --accel-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/train.py > output/output_${SLURM_JOBID}.txt echo "Finished job." date diff --git a/slurm_atmorep_evaluate.sh b/slurm_atmorep_evaluate.sh index 017e07a..548a81f 100755 --- a/slurm_atmorep_evaluate.sh +++ b/slurm_atmorep_evaluate.sh @@ -1,9 +1,9 @@ #!/bin/bash -x #SBATCH --account=ehpc03 -#SBATCH --time=0-02:30:00 +#SBATCH --time=0-0:10:00 #SBATCH --nodes=1 -#SBATCH --cpus-per-task=80 #####nodes * gpus/node * 20 -#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=40 +#SBATCH --gres=gpu:2 #SBATCH --chdir=. 
#SBATCH --qos=acc_ehpc #SBATCH --output=logs/atmorep-%x.%j.out @@ -35,6 +35,8 @@ echo "Number of Nodes: $SLURM_JOB_NUM_NODES" echo "Number of Tasks: $SLURM_NTASKS" date +export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK} + CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_eval_${SLURM_JOBID} mkdir ${CONFIG_DIR} cp ${SLURM_SUBMIT_DIR}/atmorep/core/evaluate.py ${CONFIG_DIR} From 9c49dedbb238b4b9e0a8b81f5b19e65913349564 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 3 Jul 2024 19:20:53 +0200 Subject: [PATCH 43/66] final weighted average LOSS --- atmorep/core/atmorep_model.py | 8 +-- atmorep/core/train.py | 8 +-- atmorep/core/trainer.py | 58 +++++++++------------ atmorep/datasets/multifield_data_sampler.py | 47 ++++++++++------- 4 files changed, 63 insertions(+), 58 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 79834d4..7db37f1 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -137,12 +137,12 @@ def normalizer( self, field, vl_idx, lats_idx, lons_idx ) : def mode( self, mode : NetMode) : if mode == NetMode.train : - #self.data_loader_iter = iter(self.data_loader_train) - self.data_loader_iter = iter(self.dataset_train) + self.data_loader_iter = iter(self.data_loader_train) + #self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : - #self.data_loader_iter = iter(self.data_loader_test) - self.data_loader_iter = iter(self.dataset_test) + self.data_loader_iter = iter(self.data_loader_test) + #self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 1fa8283..d76b379 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -210,11 +210,11 @@ def train() : # loss # supported: see Trainer for supported losses #cf.losses = ['mse', 'stats'] - #cf.losses = ['mse_ensemble'] #, 'stats'] - cf.losses = ['weighted_mse'] + #cf.losses = ['mse_ensemble', 'stats'] + #cf.losses = ['weighted_mse', 'stats'] # cf.losses = ['mse'] # cf.losses = ['stats'] - # cf.losses = ['crps'] + cf.losses = ['mse_ensemble', 'crps'] # training cf.optimizer_zero = False cf.lr_start = 5. 
* 10e-7 @@ -226,7 +226,7 @@ def train() : cf.lat_sampling_weighted = False #True # BERT # strategies: 'BERT', 'forecast', 'temporal_interpolation', 'identity' - cf.BERT_strategy = 'forecast' #'BERT' + cf.BERT_strategy = 'BERT' #'BERT' cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 042bb62..841c520 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -229,11 +229,12 @@ def train( self, epoch): for batch_idx in range( model.len( NetMode.train)) : batch_data = self.model.next() - (_, _ , _, tmis_list) = batch_data[0] + _, _, _, tmksd_list, weight_list = batch_data[0] with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=cf.with_mixed_precision): batch_data = self.prepare_batch( batch_data) preds, _ = self.model_ddp( batch_data) - loss, mse_loss, losses = self.loss( preds, batch_idx, tmis_list) + #breakpoint() + loss, mse_loss, losses = self.loss( preds, batch_idx, tmksd_list, weight_list) self.grad_scaler.scale(loss).backward() self.grad_scaler.step(self.optimizer) @@ -371,7 +372,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): batch_data = self.model.next() if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory - (sources, _ , targets, tmis_list) = batch_data[0] + (sources, _ , targets, tmis_list, _) = batch_data[0] log_sources = ( [source.detach().clone().cpu() for source in sources ], [target.detach().clone().cpu() for target in targets ], tmis_list) @@ -443,7 +444,7 @@ def test_loss( self, pred, target) : pass ################################################### - def loss( self, preds, batch_idx = 0, tmidx_list = None) : + def loss( self, preds, batch_idx = 0, tmidx_list = None, weights_list = None) : # TODO: move implementations to individual files @@ -475,38 +476,31 @@ def loss( self, preds, batch_idx = 0, tmidx_list = None) : num_tokens = field_info[3] token_size = field_info[4] - #idx_loc = [tokens_masked_idx_list[0][vlvl][batch_idx] - np.prod(num_tokens) * batch_idx for vlvl in range(nlvls)] - #targets_temp = self.get_masked_data(field_info, target, tokens_masked_idx_list[0]) - #preds_temp = self.get_masked_data(field_info, pred[0], tokens_masked_idx_list[0]) - lats_mskd = [] weights = [] - for vidx in range(nlvls): - for bidx in range(cf.batch_size): + # for vidx in range(nlvls): + # for bidx in range(cf.batch_size): - lats_idx = self.sources_idxs[bidx][1] - lons_idx = self.sources_idxs[bidx][2] - grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 - grid = torch.from_numpy( np.array( np.broadcast_to( grid, - shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) + # lats_idx = self.sources_idxs[bidx][1] + # lons_idx = self.sources_idxs[bidx][2] - grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) + # idx_base = tmidx_list[idx][vidx][bidx] + # idx_loc = idx_base - np.prod(num_tokens) * bidx + + # grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + # grid = torch.from_numpy( np.array( np.broadcast_to( grid, + # shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) + # grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) - idx_base = tmidx_list[idx][vidx][bidx] - idx_loc = idx_base - 
np.prod(num_tokens) * bidx - #save only useful info for each bidx. shape e.g. [n_bidx, lat_token_size*lat_num_tokens] - - lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) - lats_mskd.append(lats_mskd_b) - weights.append([get_weights(la) for la in lats_mskd_b]) - lats_mskd = torch.Tensor([l for l in lats_mskd]) - weights = torch.Tensor(np.array([w for batch in weights for w in batch])) - breakpoint() + # lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + + # weights.append([get_weights(la) for la in lats_mskd_b]) + #breakpoint() + + weights = torch.Tensor(np.array([w for batch in weights_list[idx] for w in batch])) weights = weights.view(*weights.shape, 1, 1).repeat(1, 1, token_size[0], token_size[2]).swapaxes(1, 2) - weights = weights.reshape([weights.shape[0], -1]).to(target.get_device()) - - #target_temp = detokenize(target.reshape([nlvls, -1] + tok_size + ntokens).cpu().detach().numpy()) - #preds_temp = detokenize(preds[0].reshape([nlvls, ]).cpu().detach().numpy()) + # weights = torch.ones(weights.shape).to(target.get_device()) + #breakpoint() for en in torch.transpose( pred[2], 1, 0) : loss_en += weighted_mse( en, target, weights) @@ -534,7 +528,7 @@ def loss( self, preds, batch_idx = 0, tmidx_list = None) : loss = torch.tensor( 0., device=self.device_out) tot_weight = torch.tensor( 0., device=self.device_out) for key in losses : - print( 'LOSS : {} :: {}'.format( key, losses[key])) + # print( 'LOSS : {} :: {}'.format( key, losses[key])) for ifield, val in enumerate(losses[key]) : loss += self.loss_weights[ifield] * val.to( self.device_out) tot_weight += self.loss_weights[ifield] @@ -575,7 +569,7 @@ def prepare_batch( self, xin) : # unpack loader output # xin[0] since BERT does not have targets - (sources, token_infos, targets, fields_tokens_masked_idx_list) = xin[0] + (sources, token_infos, targets, fields_tokens_masked_idx_list, _) = xin[0] (self.sources_idxs, self.sources_info) = xin[2] # network input diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 2181843..5ad9dc4 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -25,7 +25,7 @@ # from atmorep.datasets.normalizer_global import NormalizerGlobal # from atmorep.datasets.normalizer_local import NormalizerLocal from atmorep.datasets.normalizer import normalize -from atmorep.utils.utils import tokenize +from atmorep.utils.utils import tokenize, get_weights class MultifieldDataSampler( torch.utils.data.IterableDataset): @@ -140,9 +140,7 @@ def __iter__(self): data_tt_sfc = self.ds['data_sfc'].oindex[idxs_t] data_tt = self.ds['data'].oindex[idxs_t] for sidx in range(self.batch_size) : - # i_bidx = self.idxs_perm_t[bidx] - # idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) - + idx = self.idxs_perm[bidx*self.batch_size+sidx] # slight asymetry with offset by res/2 is required to match desired token count lat_ran = np.where(np.logical_and(lats>idx[0]-ns_2[1]-res[0]/2.,lats Date: Tue, 23 Jul 2024 17:28:40 +0200 Subject: [PATCH 44/66] clean code version. time still off. 
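
A minimal usage sketch of the area-weighted loss introduced in the preceding patches; it relies on the get_weights and weighted_mse helpers added to atmorep/utils/utils.py, while the toy tensor shapes and the assumption that grid row 0 sits at 90N are purely illustrative:

    import numpy as np
    import torch
    from atmorep.utils.utils import get_weights, weighted_mse

    # toy example: three masked latitude rows around 45N on the 0.25 deg grid
    lats = np.array([45.0, 45.25, 45.5])
    lats_idx = ((90. - lats) / 0.25).astype(np.int64)    # grid row indices, row 0 assumed at 90N

    w = torch.from_numpy(get_weights(lats_idx))           # cos(lat)-based weights, shape [3]
    pred, target = torch.randn(3, 6), torch.randn(3, 6)   # toy fields: 3 lat rows x 6 lon points

    w = w.view(-1, 1).expand_as(pred)                     # broadcast weights along longitude
    loss = weighted_mse(pred, target, w)                  # sum(w * (pred - target)**2) / sum(w)

Weighting by cos(latitude) compensates for the shrinking grid-cell area towards the poles, so high-latitude rows no longer dominate the mean squared error.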
--- atmorep/core/atmorep_model.py | 1 - atmorep/core/evaluate.py | 66 +++++++-------------- atmorep/core/evaluator.py | 16 ++--- atmorep/core/train.py | 55 ++++++++--------- atmorep/core/trainer.py | 48 ++++----------- atmorep/datasets/multifield_data_sampler.py | 2 +- atmorep/datasets/normalizer.py | 1 - atmorep/utils/utils.py | 27 ++++++++- slurm_atmorep.sh | 4 +- slurm_atmorep_evaluate.sh | 2 +- 10 files changed, 95 insertions(+), 127 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 7db37f1..06c4843 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -73,7 +73,6 @@ def set_global( self, mode : NetMode, times, batch_size = -1, num_loader_workers cf = self.net.cf if batch_size < 0 : batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test - print("dates:", times) dataset = self.dataset_train if mode == NetMode.train else self.dataset_test dataset.set_global( times, batch_size, cf.token_overlap) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 7eccb21..694f430 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -16,38 +16,10 @@ from atmorep.core.evaluator import Evaluator import time + if __name__ == '__main__': # models for individual fields - - #2 nodes - #model_id='1b43bynq' - # model_id='p20z3ilu' - # model_id='10y42b1u' - #model_id='99tb5lcy' - # model_id='sn6h8wvq' - # model_id='085saknn' - - #1node train continue - #model_id='h7orvjna' - #model_id='ocpn87si' - - # 1node - #temperature - #model_id='66zlffty' - #model_id='fmzy4mxr' - model_id='ezi4shmb' - - #velocity_u - #model_id='hg8cy3c4' - #model_id='av0rp1mj' - #model_id='fc5o31h2' - - #specific_humidity - #model_id='w965qy0o' - #model_id='gpksrtrl' - #model_id='c6am1m3j' - #model_id = '4nvwbetz' # vorticity #model_id = 'oxpycr7w' # divergence #model_id = '1565pb1f' # specific_humidity @@ -71,28 +43,30 @@ # e.g. 
global_forecast where a start date can be specified # BERT masked token model - #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} - #mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123], 'attention' : False} - mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} + #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} + # BERT forecast mode - #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'fields[0][2]' : [123], 'attention' : False } + #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } #temporal interpolation #idx_time_mask: list of relative time positions of the masked tokens within the cube wrt num_tokens[0] - #mode, options = 'temporal_interpolation', {'fields[0][2]' : [123], 'idx_time_mask': [5,6,7], 'attention' : False } + #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'attention' : False } # BERT forecast with patching to obtain global forecast - # mode, options = 'global_forecast', { #'fields[0][2]' : [114], #[123, 137], #[105, 137], - # #'dates' : [[2021, 1, 10, 18]], - # 'dates' : [ [2021, 1, 10, 18] - # # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], - # # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], - # # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], - # # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] - # ], - # 'token_overlap' : [0, 0], - # 'forecast_num_tokens' : 2, - # 'attention' : False } + mode, options = 'global_forecast', { + 'dates' : [[2021, 1, 10, 18]], + # 'dates' : [ #[2021, 1, 10, 18] + # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], + # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], + # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], + # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] + # ], + 'token_overlap' : [0, 0], + 'forecast_num_tokens' : 2, + 'attention' : False } + + file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' + now = time.time() - Evaluator.evaluate( mode, model_id, options) + Evaluator.evaluate( mode, model_id, file_path, options) print("time", time.time() - now) diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 1b1d0a5..687c62a 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -60,7 +60,6 @@ def run( cf, model_id, model_epoch, devices) : if not hasattr(cf, 'batch_size_validation'): cf.batch_size_validation = cf.batch_size_max - cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' cf.with_mixed_precision = True # set/over-write options as desired @@ -73,7 +72,7 @@ def run( cf, model_id, model_epoch, devices) : ############################################## @staticmethod - def evaluate( mode, model_id, args = {}, model_epoch=-2) : + def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : # SLURM_TASKS_PER_NODE is controlled by #SBATCH --ntasks-per-node=1; should be 1 for multiformer with_ddp = True 
@@ -81,18 +80,18 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : with_ddp = False num_accs_per_task = 1 else : - num_accs_per_task = 1 #int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) - #devices = init_torch( num_accs_per_task) - devices = ['cuda'] + num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) + devices = init_torch( num_accs_per_task) + #devices = ['cuda'] par_rank, par_size = setup_ddp( with_ddp) - + cf.file_path = file_path cf = Config().load_json( model_id) cf.with_wandb = True cf.with_ddp = with_ddp cf.par_rank = par_rank cf.par_size = par_size - cf.losses = cf.losses + ['weighted_mse'] + cf.losses = cf.losses # overwrite old config cf.attention = False setup_wandb( cf.with_wandb, cf, par_rank, '', mode='offline') @@ -149,9 +148,6 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : cf.num_loader_workers = 12 #1 cf.log_test_num_ranks = 1 - if not hasattr(cf, 'file_path'): - cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' - if not hasattr(cf, 'batch_size'): cf.batch_size = 196 #14 if not hasattr(cf, 'batch_size_validation'): diff --git a/atmorep/core/train.py b/atmorep/core/train.py index d76b379..be33954 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -50,17 +50,20 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : cf.lr_max = 0.00005*3 cf.num_samples_per_epoch = 4096*12 cf.num_samples_validate = 128*12 + + cf.losses = ['weighted_mse', 'stats'] + # name has changed but ensure backward compatibility - # if hasattr( cf, 'loader_num_workers') : - # cf.num_loader_workers = cf.loader_num_workers - # if not hasattr( cf, 'n_size'): - # cf.n_size = [36, 0.25*9*6, 0.25*9*12] - # if not hasattr(cf, 'num_samples_per_epoch'): - # cf.num_samples_per_epoch = 1024 - # if not hasattr(cf, 'num_samples_validate'): - # cf.num_samples_validate = 128 - # if not hasattr(cf, 'with_mixed_precision'): - # cf.with_mixed_precision = True + if hasattr( cf, 'loader_num_workers') : + cf.num_loader_workers = cf.loader_num_workers + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = True # cf.years_train = [2021] # list( range( 1980, 2018)) # cf.years_test = [2021] #[2018] @@ -90,9 +93,10 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : #################################################################################################### def train() : - num_accs_per_task = 1 #int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) + num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) #device = ['cuda'] + with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -153,10 +157,6 @@ def train() : # ['total_precip', [1, 1536, ['velocity_u', 'velocity_v', 'velocity_z', 'specific_humidity'], 3], # [0], # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05]] ] - - # cf.fields_prediction = [['velocity_u', 0.225], ['velocity_v', 0.225], - # ['specific_humidity', 0.15], ['velocity_z', 0.1], ['temperature', 0.2], - # ['total_precip', 0.1] ] cf.fields_targets = [] @@ -209,12 +209,7 @@ def train() : cf.net_tail_num_layers = 0 # loss # supported: see Trainer for supported losses - #cf.losses = ['mse', 'stats'] - #cf.losses = 
['mse_ensemble', 'stats'] - #cf.losses = ['weighted_mse', 'stats'] - # cf.losses = ['mse'] - # cf.losses = ['stats'] - cf.losses = ['mse_ensemble', 'crps'] + cf.losses = ['mse_ensemble', 'stats'] # training cf.optimizer_zero = False cf.lr_start = 5. * 10e-7 @@ -223,10 +218,10 @@ def train() : cf.weight_decay = 0.05 cf.lr_decay_rate = 1.025 cf.lr_start_epochs = 3 - cf.lat_sampling_weighted = False #True + # BERT # strategies: 'BERT', 'forecast', 'temporal_interpolation', 'identity' - cf.BERT_strategy = 'BERT' #'BERT' + cf.BERT_strategy = 'BERT' cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution @@ -274,11 +269,13 @@ def train() : #################################################################################################### if __name__ == '__main__': - train() +# train() -# wandb_id, epoch = '66zlffty', 26 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 -# #wandb_id, epoch = 'fc5o31h2', 27 -# epoch_continue = epoch + #wandb_id, epoch = '66zlffty', 26 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 + wandb_id, epoch = 'h7orvjna', 82 + #wandb_id, epoch = 'ocpn87si', 103 + #wandb_id, epoch = 'fc5o31h2', 27 + epoch_continue = epoch -# Trainer = Trainer_BERT -# train_continue( wandb_id, epoch, Trainer, epoch_continue) + Trainer = Trainer_BERT + train_continue( wandb_id, epoch, Trainer, epoch_continue) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 841c520..02b5b3e 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -23,9 +23,7 @@ from pathlib import Path import os import datetime -from typing import TypeVar import functools -import pandas as pd import wandb # import horovod.torch as hvd @@ -42,7 +40,7 @@ import atmorep.utils.token_infos_transformations as token_infos_transformations -from atmorep.utils.utils import Gaussian, CRPS, get_weights, weighted_mse, NetMode, tokenize, detokenize +from atmorep.utils.utils import Gaussian, CRPS, kernel_crps, weighted_mse, NetMode, tokenize, detokenize from atmorep.datasets.data_writer import write_forecast, write_BERT, write_attention from atmorep.datasets.normalizer import denormalize @@ -355,7 +353,7 @@ def profile( self): ################################################### def validate( self, epoch, BERT_test_strategy = 'BERT'): - print('inside_validate') + cf = self.cf BERT_strategy_train = cf.BERT_strategy cf.BERT_strategy = BERT_test_strategy @@ -390,11 +388,10 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): # base line loss cur_loss = self.MSELoss( pred[0], target = target ).cpu().item() - print(cur_loss, flush = True) loss += cur_loss total_losses[ifield] += cur_loss ifield += 1 - print(f"total_loss {total_loss}", flush = True) + total_loss += loss test_len += 1 @@ -404,12 +401,11 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): self.log_validate( epoch, it, log_sources, log_preds) if cf.attention: self.log_attention( epoch, it, atts) - - print(f"FINAL total_loss {total_loss}", flush = True) + # average over all nodes total_loss /= test_len * len(self.cf.fields_prediction) total_losses /= test_len - print(f"FINAL total_loss after ratio {total_loss}", flush = True) + if cf.with_ddp : total_loss_cuda = total_loss.cuda() total_losses_cuda = total_losses.cuda() @@ -417,7 +413,7 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): dist.all_reduce( total_losses_cuda, op=torch.distributed.ReduceOp.AVG ) total_loss = total_loss_cuda.cpu() 
total_losses = total_losses_cuda.cpu() - print(f"FINAL total_loss after DDP {total_loss}", flush = True) + if 0 == cf.par_rank : print( 'validation loss for strategy={} at epoch {} : {}'.format( BERT_test_strategy, epoch, total_loss), @@ -472,35 +468,12 @@ def loss( self, preds, batch_idx = 0, tmidx_list = None, weights_list = None) : if 'weighted_mse' in self.cf.losses : loss_en = torch.tensor( 0., device=target.device) field_info = cf.fields[idx] - nlvls = len(field_info[2]) - num_tokens = field_info[3] token_size = field_info[4] - weights = [] - # for vidx in range(nlvls): - # for bidx in range(cf.batch_size): - - # lats_idx = self.sources_idxs[bidx][1] - # lons_idx = self.sources_idxs[bidx][2] - - # idx_base = tmidx_list[idx][vidx][bidx] - # idx_loc = idx_base - np.prod(num_tokens) * bidx - - # grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 - # grid = torch.from_numpy( np.array( np.broadcast_to( grid, - # shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) - # grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) - - # lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) - - # weights.append([get_weights(la) for la in lats_mskd_b]) - #breakpoint() - weights = torch.Tensor(np.array([w for batch in weights_list[idx] for w in batch])) weights = weights.view(*weights.shape, 1, 1).repeat(1, 1, token_size[0], token_size[2]).swapaxes(1, 2) weights = weights.reshape([weights.shape[0], -1]).to(target.get_device()) - # weights = torch.ones(weights.shape).to(target.get_device()) - #breakpoint() + for en in torch.transpose( pred[2], 1, 0) : loss_en += weighted_mse( en, target, weights) @@ -524,11 +497,16 @@ def loss( self, preds, batch_idx = 0, tmidx_list = None, weights_list = None) : if 'crps' in self.cf.losses : crps_loss = torch.mean( CRPS( target, pred[0], pred[1])) losses['crps'].append( crps_loss) + + if 'kernel_crps' in self.cf.losses : + kcrps_loss = torch.mean( kernel_crps( target,torch.transpose( pred[2], 1, 0))) + losses['kernel_crps'].append( kcrps_loss) + loss = torch.tensor( 0., device=self.device_out) tot_weight = torch.tensor( 0., device=self.device_out) for key in losses : - # print( 'LOSS : {} :: {}'.format( key, losses[key])) + #print( 'LOSS : {} :: {}'.format( key, losses[key])) for ifield, val in enumerate(losses[key]) : loss += self.loss_weights[ifield] * val.to( self.device_out) tot_weight += self.loss_weights[ifield] diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 5ad9dc4..b34d063 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -312,7 +312,7 @@ def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : # adjust batch size if necessary so that the evaluations split up across batches of equal size batch_size = len(times_pos) #num_tiles_lon - print(batch_size) + print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) self.set_data( times_pos, batch_size) diff --git a/atmorep/datasets/normalizer.py b/atmorep/datasets/normalizer.py index 906989e..0487af2 100644 --- a/atmorep/datasets/normalizer.py +++ b/atmorep/datasets/normalizer.py @@ -36,7 +36,6 @@ def normalize( data, norm, dates, year_base = 1979) : ###################################################### def normalize_local( data, mean, var) : - #breakpoint() data = (data - mean) / var return data diff --git a/atmorep/utils/utils.py 
b/atmorep/utils/utils.py index 3ae6280..fb2c40c 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -22,7 +22,7 @@ import wandb import code from calendar import monthrange - +#import properscoring as ps import numpy as np import torch.distributed as dist @@ -353,8 +353,12 @@ def erf( x, mu=0., std_dev=1.) : return val ######################################## +# def CRPS_ps( y, mu, std_dev) : +# val = ps.crps_gaussian(y.cpu().detach().numpy(), mu=mu.cpu().detach().numpy(), sig=std_dev.cpu().detach().numpy()) +# return torch.tensor(val) def CRPS( y, mu, std_dev) : + # see Eq. A2 in S. Rasp and S. Lerch. Neural networks for postprocessing ensemble weather forecasts. Monthly Weather Review, 146(11):3885 – 3900, 2018. c1 = np.sqrt(1./np.pi) t1 = 2. * erf( (y-mu) / std_dev) - 1. @@ -362,6 +366,27 @@ def CRPS( y, mu, std_dev) : val = std_dev * ( (y-mu)/std_dev * t1 + t2 - c1 ) return val +######################################## +# def kernel_crps_ps( target, ens) : +# #breakpoint() +# val = ps.crps_ensemble(target.cpu().detach().numpy(), ens.permute([1,2,0]).cpu().detach().numpy()) +# return torch.tensor(val) + +def kernel_crps( target, ens, fair = True) : + #breakpoint() + ens_size = ens.shape[0] + mae = torch.cat( [(target - mem).abs().mean().unsqueeze(0) for mem in ens], 0).mean() + + if ens_size == 1: + return mae + + coef = -1.0 / (2.0 * ens_size * (ens_size - 1)) if fair else -1.0 / (2.0 * ens_size**2) + temp = [(p1 - p2).abs().sum() for p1 in ens for p2 in ens] + # breakpoint() + ens_var = coef * torch.tensor( [(p1 - p2).abs().sum() for p1 in ens for p2 in ens]).sum() + ens_var /= (ens.shape[1]*ens.shape[2]) + + return mae + ens_var ######################################## diff --git a/slurm_atmorep.sh b/slurm_atmorep.sh index ddb82c3..2f2cbfc 100755 --- a/slurm_atmorep.sh +++ b/slurm_atmorep.sh @@ -1,6 +1,6 @@ #!/bin/bash -x #SBATCH --account=ehpc03 -#SBATCH --time=0-71:59:59 +#SBATCH --time=0-24:00:00 #SBATCH --nodes=1 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=20 @@ -17,7 +17,7 @@ export UCX_TLS="^cma" export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export CUDA_VISIBLE_DEVICES=0 #,1,2,3 +export CUDA_VISIBLE_DEVICES=0,1,2,3 # so processes know who to talk to export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" diff --git a/slurm_atmorep_evaluate.sh b/slurm_atmorep_evaluate.sh index 548a81f..32e160a 100755 --- a/slurm_atmorep_evaluate.sh +++ b/slurm_atmorep_evaluate.sh @@ -1,6 +1,6 @@ #!/bin/bash -x #SBATCH --account=ehpc03 -#SBATCH --time=0-0:10:00 +#SBATCH --time=0-3:30:00 #SBATCH --nodes=1 #SBATCH --cpus-per-task=40 #SBATCH --gres=gpu:2 From 46b22b8d3cbfa1881ca77d6930dd80a8ca80a897 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 24 Jul 2024 18:26:01 +0200 Subject: [PATCH 45/66] add test --- atmorep/tests/__init__.py | 1 + atmorep/tests/conftest.py | 5 ++ atmorep/tests/test_utils.py | 54 +++++++++++++++ atmorep/tests/validation_test.py | 115 +++++++++++++++++++++++++++++++ 4 files changed, 175 insertions(+) create mode 100644 atmorep/tests/__init__.py create mode 100644 atmorep/tests/conftest.py create mode 100644 atmorep/tests/test_utils.py create mode 100644 atmorep/tests/validation_test.py diff --git a/atmorep/tests/__init__.py b/atmorep/tests/__init__.py new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/atmorep/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/atmorep/tests/conftest.py b/atmorep/tests/conftest.py new file mode 100644 index 0000000..248e16e 
--- /dev/null +++ b/atmorep/tests/conftest.py @@ -0,0 +1,5 @@ +def pytest_addoption(parser): + parser.addoption("--field", action="store", help="field to run the test on") + parser.addoption("--model_id", action="store", help="wandb ID of the atmorep model") + parser.addoption("--epoch", action="store", help="field to run the test on", default = "0") + diff --git a/atmorep/tests/test_utils.py b/atmorep/tests/test_utils.py new file mode 100644 index 0000000..bf6d5ed --- /dev/null +++ b/atmorep/tests/test_utils.py @@ -0,0 +1,54 @@ +import numpy as np + +def era5_fname(): + return "/gpfs/scratch/ehpc03/data/{}/ml{}/era5_{}_y{}_m{}_ml{}.grib" + +def atmorep_pred(): + return "./results/id{}/results_id{}_epoch{}_pred.zarr" + +def atmorep_target(): + return "./results/id{}/results_id{}_epoch{}_target.zarr" + +def grib_index(field): + grib_idxs = {"velocity_u": "u", + "temperature": "t", + "total_precip": "tp", + "velocity_v": "v", + "velocity_z": "z", + "specific_humidity": "q"} + + return grib_idxs[field] + +###################################### + +def check_lats(lats_pred, lats_target): + assert (lats_pred[:] == lats_target[:]).all(), "Mismatch between latitudes" + assert (lats_pred[:] <= 90.).all(), f"latitudes are between {np.amin(lats_pred)}- {np.amax(lats_pred)}" + assert (lats_pred[:] >= -90.).all(), f"latitudes are between {np.amin(lats_pred)}- {np.amax(lats_pred)}" + +def check_lons(lons_pred, lons_target): + assert (lons_pred[:] == lons_target[:]).all(), "Mismatch between longitudes" + assert (lons_pred[:] >= 0.).all(), "longitudes are between {np.amin(lons_pred)}- {np.amax(lons_pred)}" + assert (lons_pred[:] <= 360.).all(), "longitudes are between {np.amin(lons_pred)}- {np.amax(lons_pred)}" + +def check_datetimes(datetimes_pred, datetimes_target): + assert (datetimes_pred == datetimes_target), "Mismatch between datetimes" + +###################################### + +#calculate RMSE +def compute_RMSE(pred, target): + return np.sqrt(np.mean((pred-target)**2)) + + +def get_max_RMSE(field): + #TODO: optimize thresholds + values = {"temperature" : 1.8, + "velocity_u" : 0.005, #???? + "velocity_v": 0.005, #???? + "velocity_z": 0.005, #???? + "specific_humidity": 0.7, #???? + "total_precip": 9999, #????? 
+ } + + return values[field] \ No newline at end of file diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py new file mode 100644 index 0000000..cb87c16 --- /dev/null +++ b/atmorep/tests/validation_test.py @@ -0,0 +1,115 @@ +import pytest +import zarr +import cfgrib +import xarray as xr +import pandas as pd +import numpy as np +import random as rnd +import warnings +import os + +from atmorep.tests.test_utils import * + + +@pytest.fixture +def field(request): + return request.config.getoption("field") + +@pytest.fixture +def model_id(request): + return request.config.getoption("model_id") + +@pytest.fixture +def epoch(request): + request.config.getoption("epoch") + +################################################################## + +#@pytest.mark.parametrize(metafunc): #"field, level, model_id", [("temperature", 105, "ztsut0mr")]) +def test_datetime(field, model_id, epoch = 0): + + """ + Check against ERA5 timestamps + """ + #field, model_id, epoch = get_fixtures(metafunc) + + level = 137 + store = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) + atmorep = zarr.group(store) + + #TODO: make it more elegant + ml_idx = np.where(atmorep[f"{field}/sample=00000"].ml[:] == level)[0].tolist()[0] + + nsamples = min(len(atmorep[field]), 50) + samples = rnd.sample(range(len(atmorep[field])), nsamples) + + for s in samples: + + data = atmorep[f"{field}/sample={s:05d}"].data[ml_idx, 0] + datetime = pd.Timestamp(atmorep[f"{field}/sample={s:05d}"].datetime[0]) + lats = atmorep[f"{field}/sample={s:05d}"].lat + lons = atmorep[f"{field}/sample={s:05d}"].lon + + year, month = datetime.year, str(datetime.month).zfill(2) + + era5_path = era5_fname().format(field, level, field, year, month, level) + if not os.path.isfile(era5_path): + warnings.warn(UserWarning((f"Timestamp {datetime} not found in ERA5. Skipping"))) + continue + era5 = xr.open_dataset(era5_path, engine = "cfgrib")[grib_index(field)].sel(time = datetime, latitude = lats, longitude = lons) + + assert (data[0] == era5.values[0]).all(), "Mismatch between ERA5 and AtmoRep Timestamps" + +############################################################################# + +#@pytest.mark.parametrize("field, model_id", [("temperature", "ztsut0mr")]) + +def test_coordinates(field, model_id, epoch = 0): + """ + Check that coordinates match between target and prediction. 
+ Check also that latitude and longitudes are in geographical coordinates + """ + #field, model_id, epoch = get_fixtures(metafunc) + store_t = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) + target = zarr.group(store_t) + + store_p = zarr.ZipStore(atmorep_pred().format(model_id, model_id, str(epoch).zfill(5))) + pred = zarr.group(store_p) + + nsamples = min(len(target[field]), 50) + samples = rnd.sample(range(len(target[field])), nsamples) + + for s in samples: + datetime_target = [pd.Timestamp(i) for i in target[f"{field}/sample={s:05d}"].datetime] + lats_target = target[f"{field}/sample={s:05d}"].lat + lons_target = target[f"{field}/sample={s:05d}"].lon + + datetime_pred = [pd.Timestamp(i) for i in pred[f"{field}/sample={s:05d}"].datetime] + lats_pred = pred[f"{field}/sample={s:05d}"].lat + lons_pred = pred[f"{field}/sample={s:05d}"].lon + + check_lats(lats_pred, lats_target) + check_lons(lons_pred, lons_target) + check_datetimes(datetime_pred, datetime_target) + +######################################################################### + +#@pytest.mark.parametrize("field, model_id", [("temperature", "ztsut0mr")]) + +def test_rmse(field, model_id, epoch = 0): + # field, model_id, epoch = get_fixtures(metafunc) + + store_t = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) + target = zarr.group(store_t) + + store_p = zarr.ZipStore(atmorep_pred().format(model_id, model_id, str(epoch).zfill(5))) + pred = zarr.group(store_p) + + nsamples = min(len(target[field]), 50) + samples = rnd.sample(range(len(target[field])), nsamples) + + for s in samples: + sample_target = target[f"{field}/sample={s:05d}"].data[:] + sample_pred = pred[f"{field}/sample={s:05d}"].data[:] + + assert compute_RMSE(sample_target, sample_pred).mean() < get_max_RMSE(field) From fe05e2826a16fb85151170001c4198bd07d324cb Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 24 Jul 2024 18:29:06 +0200 Subject: [PATCH 46/66] add test --- atmorep/tests/validation_test.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py index cb87c16..88ce98d 100644 --- a/atmorep/tests/validation_test.py +++ b/atmorep/tests/validation_test.py @@ -10,6 +10,7 @@ from atmorep.tests.test_utils import * +# run it with e.g. pytest -s atmorep/tests/validation_test.py --field temperature --model_id ztsut0mr @pytest.fixture def field(request): @@ -25,13 +26,12 @@ def epoch(request): ################################################################## -#@pytest.mark.parametrize(metafunc): #"field, level, model_id", [("temperature", 105, "ztsut0mr")]) def test_datetime(field, model_id, epoch = 0): """ - Check against ERA5 timestamps + Check against ERA5 timestamps. + 50 random samples. """ - #field, model_id, epoch = get_fixtures(metafunc) level = 137 store = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) @@ -62,14 +62,13 @@ def test_datetime(field, model_id, epoch = 0): ############################################################################# -#@pytest.mark.parametrize("field, model_id", [("temperature", "ztsut0mr")]) - def test_coordinates(field, model_id, epoch = 0): """ Check that coordinates match between target and prediction. Check also that latitude and longitudes are in geographical coordinates + 50 random samples. 
""" - #field, model_id, epoch = get_fixtures(metafunc) + store_t = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) target = zarr.group(store_t) @@ -94,11 +93,11 @@ def test_coordinates(field, model_id, epoch = 0): ######################################################################### -#@pytest.mark.parametrize("field, model_id", [("temperature", "ztsut0mr")]) - def test_rmse(field, model_id, epoch = 0): - # field, model_id, epoch = get_fixtures(metafunc) - + """ + Test that for each field the RMSE does not exceed a certain value. + 50 random samples. + """ store_t = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) target = zarr.group(store_t) From cf3a1650c215368b642ce66b8c4a2b34c0c0668e Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 25 Jul 2024 12:11:27 +0200 Subject: [PATCH 47/66] include validation tests within evaluate.py --- atmorep/core/evaluate.py | 35 +++++++++++++------------ atmorep/core/evaluator.py | 11 ++++++-- atmorep/datasets/data_writer.py | 18 ++++++------- atmorep/tests/test_utils.py | 4 +++ atmorep/tests/validation_test.py | 44 ++++++++++++++++---------------- setup.py | 2 +- 6 files changed, 62 insertions(+), 52 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 694f430..f47c4d7 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -27,46 +27,47 @@ #model_id = 'dys79lgw' # velocity_u #model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z - #model_id = '3qou60es' # temperature (also 2147fkco) + model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) # multi-field configurations with either velocity or voritcity+divergence #model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity - # model_id = '3cizyl1q' # 3 field config: u,v,T + #model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting # model_id = '1m79039j' # pre-trained, 6h forecasting - + model_id='ezi4shmb' # supported modes: test, forecast, fixed_location, temporal_interpolation, global_forecast, # global_forecast_range # options can be used to over-write parameters in config; some modes also have specific options, # e.g. 
global_forecast where a start date can be specified # BERT masked token model - #mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} + mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} # BERT forecast mode #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } - #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } + #temporal interpolation #idx_time_mask: list of relative time positions of the masked tokens within the cube wrt num_tokens[0] #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'attention' : False } # BERT forecast with patching to obtain global forecast - mode, options = 'global_forecast', { - 'dates' : [[2021, 1, 10, 18]], - # 'dates' : [ #[2021, 1, 10, 18] - # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], - # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], - # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], - # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] - # ], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 2, - 'attention' : False } +# mode, options = 'global_forecast', { +# 'dates' : [[2021, 1, 10, 18]], +# # 'dates' : [ #[2021, 1, 10, 18] +# # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], +# # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], +# # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], +# # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] +# # ], +# 'token_overlap' : [0, 0], +# 'forecast_num_tokens' : 2, +# 'attention' : False, +# 'with_pytest' : True } file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' now = time.time() Evaluator.evaluate( mode, model_id, file_path, options) - print("time", time.time() - now) + print("time", time.time() - now) \ No newline at end of file diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 687c62a..1177f0f 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -17,7 +17,7 @@ import numpy as np import os import code - +import pytest import datetime import wandb @@ -85,8 +85,8 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : #devices = ['cuda'] par_rank, par_size = setup_ddp( with_ddp) - cf.file_path = file_path cf = Config().load_json( model_id) + cf.file_path = file_path cf.with_wandb = True cf.with_ddp = with_ddp cf.par_rank = par_rank @@ -110,8 +110,15 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : cf.num_samples_per_epoch = 1024 if not hasattr(cf, 'with_mixed_precision'): cf.with_mixed_precision = False + if not hasattr(cf, 'with_pytest'): + cf.pytest = False func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) + + if cf.with_pytest: + fields = [field[0] for field in cf.fields_prediction] + for field in fields: + pytest.main(["-x", "./atmorep/tests/validation_test.py", "--field", field, "--model_id", cf.wandb_id]) ############################################## @staticmethod diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index 
011f566..a63c3d8 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -17,8 +17,6 @@ import numpy as np import xarray as xr import zarr -import code -import datetime import atmorep.config.config as config def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): @@ -53,7 +51,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, sources_coords[fidx][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, sources_coords[fidx][bidx]) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -63,7 +61,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -73,7 +71,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -83,7 +81,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_ens.close() #################################################################################################### @@ -111,7 +109,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] ) + write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] ) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -125,7 +123,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sample = batch_idx * batch_size + bidx ds_target_b = ds_field.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_target_b_l = write_item(ds_target_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) + write_item(ds_target_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -139,7 +137,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sample = batch_idx * batch_size + bidx ds_pred_b = ds_pred.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_pred_b_l = write_item(ds_pred_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], + write_item(ds_pred_b, levels[fidx][vidx], 
field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) store_pred.close() @@ -154,7 +152,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sample = batch_idx * batch_size + bidx ds_ens_b = ds_ens.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_ens_b_l = write_item(ds_ens_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], + write_item(ds_ens_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) store_ens.close() diff --git a/atmorep/tests/test_utils.py b/atmorep/tests/test_utils.py index bf6d5ed..6552601 100644 --- a/atmorep/tests/test_utils.py +++ b/atmorep/tests/test_utils.py @@ -15,6 +15,8 @@ def grib_index(field): "total_precip": "tp", "velocity_v": "v", "velocity_z": "z", + "vorticity" : "vo", + "divergence" : "d", "specific_humidity": "q"} return grib_idxs[field] @@ -47,6 +49,8 @@ def get_max_RMSE(field): "velocity_u" : 0.005, #???? "velocity_v": 0.005, #???? "velocity_z": 0.005, #???? + "vorticity" : 0.2, #???? + "divergence": 0.2, #???? "specific_humidity": 0.7, #???? "total_precip": 9999, #????? } diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py index 88ce98d..f0156e9 100644 --- a/atmorep/tests/validation_test.py +++ b/atmorep/tests/validation_test.py @@ -30,35 +30,35 @@ def test_datetime(field, model_id, epoch = 0): """ Check against ERA5 timestamps. - 50 random samples. + Loop over all levels individually. 50 random samples for each level. """ - level = 137 store = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) atmorep = zarr.group(store) - #TODO: make it more elegant - ml_idx = np.where(atmorep[f"{field}/sample=00000"].ml[:] == level)[0].tolist()[0] - nsamples = min(len(atmorep[field]), 50) samples = rnd.sample(range(len(atmorep[field])), nsamples) - - for s in samples: - - data = atmorep[f"{field}/sample={s:05d}"].data[ml_idx, 0] - datetime = pd.Timestamp(atmorep[f"{field}/sample={s:05d}"].datetime[0]) - lats = atmorep[f"{field}/sample={s:05d}"].lat - lons = atmorep[f"{field}/sample={s:05d}"].lon - - year, month = datetime.year, str(datetime.month).zfill(2) - - era5_path = era5_fname().format(field, level, field, year, month, level) - if not os.path.isfile(era5_path): - warnings.warn(UserWarning((f"Timestamp {datetime} not found in ERA5. Skipping"))) - continue - era5 = xr.open_dataset(era5_path, engine = "cfgrib")[grib_index(field)].sel(time = datetime, latitude = lats, longitude = lons) - - assert (data[0] == era5.values[0]).all(), "Mismatch between ERA5 and AtmoRep Timestamps" + levels = atmorep[f"{field}/sample=00000"].ml[:] + + for level in levels: + #TODO: make it more elegant + ml_idx = np.where(levels == level)[0].tolist()[0] + for s in samples: + + data = atmorep[f"{field}/sample={s:05d}"].data[ml_idx, 0] + datetime = pd.Timestamp(atmorep[f"{field}/sample={s:05d}"].datetime[0]) + lats = atmorep[f"{field}/sample={s:05d}"].lat + lons = atmorep[f"{field}/sample={s:05d}"].lon + + year, month = datetime.year, str(datetime.month).zfill(2) + + era5_path = era5_fname().format(field, level, field, year, month, level) + if not os.path.isfile(era5_path): + warnings.warn(UserWarning((f"Timestamp {datetime} not found in ERA5. 
Skipping"))) + continue + era5 = xr.open_dataset(era5_path, engine = "cfgrib")[grib_index(field)].sel(time = datetime, latitude = lats, longitude = lons) + + assert (data[0] == era5.values[0]).all(), "Mismatch between ERA5 and AtmoRep Timestamps" ############################################################################# diff --git a/setup.py b/setup.py index 38e5bb4..f08feed 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ description='AtmoRep', packages=find_packages(), # if packages are available in a native form fo the host system then these should be used - #install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo'], + install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo', 'pytest', 'cfgrib'], data_files=[('./output', []), ('./logs', []), ('./results',[])], ) From 1f9489e3f869f6e645d1d8b3a481f5bca8efbf52 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 25 Jul 2024 19:01:24 +0200 Subject: [PATCH 48/66] tests now working in BERT, forecast and global_forecast mode --- atmorep/core/evaluate.py | 13 ++++--- atmorep/core/evaluator.py | 14 +++---- atmorep/tests/conftest.py | 3 ++ atmorep/tests/test_utils.py | 21 +++++++++- atmorep/tests/validation_test.py | 67 +++++++++++++++++--------------- 5 files changed, 72 insertions(+), 46 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index f47c4d7..1dcd729 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -29,28 +29,30 @@ # model_id = '15oisw8d' # velocity_z model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) - + # multi-field configurations with either velocity or voritcity+divergence #model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity #model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting # model_id = '1m79039j' # pre-trained, 6h forecasting - model_id='ezi4shmb' + model_id='s3wwcc3j' # supported modes: test, forecast, fixed_location, temporal_interpolation, global_forecast, # global_forecast_range # options can be used to over-write parameters in config; some modes also have specific options, # e.g. global_forecast where a start date can be specified + + #Add 'attention' : True to options to store the attention maps. NB. supported only for single field runs. 
# BERT masked token model - mode, options = 'BERT', {'years_test' : [2021], 'attention' : False} + mode, options = 'BERT', {'years_test' : [2021], 'num_samples_validate' : 128, 'with_pytest' : True } # BERT forecast mode - #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'attention' : False } + #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'num_samples_validate' : 10, 'with_pytest' : True } #temporal interpolation #idx_time_mask: list of relative time positions of the masked tokens within the cube wrt num_tokens[0] - #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'attention' : False } + #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'num_samples_validate' : 10, 'with_pytest' : True} # BERT forecast with patching to obtain global forecast # mode, options = 'global_forecast', { @@ -63,7 +65,6 @@ # # ], # 'token_overlap' : [0, 0], # 'forecast_num_tokens' : 2, -# 'attention' : False, # 'with_pytest' : True } file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index 1177f0f..a83fcce 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -111,14 +111,15 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : if not hasattr(cf, 'with_mixed_precision'): cf.with_mixed_precision = False if not hasattr(cf, 'with_pytest'): - cf.pytest = False + cf.with_pytest = False + func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) if cf.with_pytest: fields = [field[0] for field in cf.fields_prediction] for field in fields: - pytest.main(["-x", "./atmorep/tests/validation_test.py", "--field", field, "--model_id", cf.wandb_id]) + pytest.main(["-x", "./atmorep/tests/validation_test.py", "--field", field, "--model_id", cf.wandb_id, "--strategy", cf.BERT_strategy]) ############################################## @staticmethod @@ -127,8 +128,7 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : cf.lat_sampling_weighted = False cf.BERT_strategy = 'BERT' cf.log_test_num_ranks = 4 - if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 128 #1472 + cf.num_samples_validate = 128 #1472 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) @@ -140,8 +140,7 @@ def forecast( cf, model_id, model_epoch, devices, args = {}) : cf.BERT_strategy = 'forecast' cf.log_test_num_ranks = 4 cf.forecast_num_tokens = 1 # will be overwritten when user specified - if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 128 + cf.num_samples_validate = 128 #128 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) @@ -218,8 +217,7 @@ def temporal_interpolation( cf, model_id, model_epoch, devices, args = {}) : # set/over-write options cf.BERT_strategy = 'temporal_interpolation' cf.log_test_num_ranks = 4 - if not hasattr(cf, 'num_samples_validate'): - cf.num_samples_validate = 128 + cf.num_samples_validate = 10 #128 Evaluator.parse_args( cf, args) Evaluator.run( cf, model_id, model_epoch, devices) diff --git a/atmorep/tests/conftest.py b/atmorep/tests/conftest.py index 248e16e..62daba8 100644 --- a/atmorep/tests/conftest.py +++ b/atmorep/tests/conftest.py @@ -2,4 +2,7 @@ def pytest_addoption(parser): parser.addoption("--field", action="store", help="field to run the test on") parser.addoption("--model_id", action="store", help="wandb ID of the atmorep model") parser.addoption("--epoch", action="store", help="field to run the test on", 
default = "0") + parser.addoption("--strategy", action="store", help="BERT or forecast") + + diff --git a/atmorep/tests/test_utils.py b/atmorep/tests/test_utils.py index 6552601..9f3cfd4 100644 --- a/atmorep/tests/test_utils.py +++ b/atmorep/tests/test_utils.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd def era5_fname(): return "/gpfs/scratch/ehpc03/data/{}/ml{}/era5_{}_y{}_m{}_ml{}.grib" @@ -21,6 +22,24 @@ def grib_index(field): return grib_idxs[field] +################################################################## + +def get_BERT(atmorep, field, sample, level): + atmorep_sample = atmorep[f"{field}/sample={sample:05d}/ml={level:05d}"] + data = atmorep_sample.data[0,0] + datetime = pd.Timestamp(atmorep_sample.datetime[0,0]) + lats = atmorep_sample.lat[0] + lons = atmorep_sample.lon[0] + return data, datetime, lats, lons + +def get_forecast(atmorep, field, sample,level_idx): + atmorep_sample = atmorep[f"{field}/sample={sample:05d}"] + data = atmorep_sample.data[level_idx, 0] + datetime = pd.Timestamp(atmorep_sample.datetime[0]) + lats = atmorep_sample.lat + lons = atmorep_sample.lon + return data, datetime, lats, lons + ###################################### def check_lats(lats_pred, lats_target): @@ -45,7 +64,7 @@ def compute_RMSE(pred, target): def get_max_RMSE(field): #TODO: optimize thresholds - values = {"temperature" : 1.8, + values = {"temperature" : 3, "velocity_u" : 0.005, #???? "velocity_v": 0.005, #???? "velocity_z": 0.005, #???? diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py index f0156e9..3a1656e 100644 --- a/atmorep/tests/validation_test.py +++ b/atmorep/tests/validation_test.py @@ -2,7 +2,6 @@ import zarr import cfgrib import xarray as xr -import pandas as pd import numpy as np import random as rnd import warnings @@ -10,7 +9,7 @@ from atmorep.tests.test_utils import * -# run it with e.g. pytest -s atmorep/tests/validation_test.py --field temperature --model_id ztsut0mr +# run it with e.g. pytest -s atmorep/tests/validation_test.py --field temperature --model_id ztsut0mr --strategy BERT @pytest.fixture def field(request): @@ -20,13 +19,16 @@ def field(request): def model_id(request): return request.config.getoption("model_id") -@pytest.fixture +@pytest.fixture def epoch(request): request.config.getoption("epoch") -################################################################## +@pytest.fixture(autouse = True) +def BERT(request): + strategy = request.config.getoption("strategy") + return (strategy == 'BERT' or strategy == 'temporal_interpolation') -def test_datetime(field, model_id, epoch = 0): +def test_datetime(field, model_id, BERT, epoch = 0): """ Check against ERA5 timestamps. 
@@ -38,18 +40,16 @@ def test_datetime(field, model_id, epoch = 0): nsamples = min(len(atmorep[field]), 50) samples = rnd.sample(range(len(atmorep[field])), nsamples) - levels = atmorep[f"{field}/sample=00000"].ml[:] + levels = [int(f.split("=")[1]) for f in atmorep[f"{field}/sample=00000"]] if BERT else atmorep[f"{field}/sample=00000"].ml[:] + + get_data = get_BERT if BERT else get_forecast for level in levels: #TODO: make it more elegant - ml_idx = np.where(levels == level)[0].tolist()[0] - for s in samples: - - data = atmorep[f"{field}/sample={s:05d}"].data[ml_idx, 0] - datetime = pd.Timestamp(atmorep[f"{field}/sample={s:05d}"].datetime[0]) - lats = atmorep[f"{field}/sample={s:05d}"].lat - lons = atmorep[f"{field}/sample={s:05d}"].lon + level_idx = level if BERT else np.where(levels == level)[0].tolist()[0] + for s in samples: + data, datetime, lats, lons = get_data(atmorep, field, s, level_idx) year, month = datetime.year, str(datetime.month).zfill(2) era5_path = era5_fname().format(field, level, field, year, month, level) @@ -62,7 +62,7 @@ def test_datetime(field, model_id, epoch = 0): ############################################################################# -def test_coordinates(field, model_id, epoch = 0): +def test_coordinates(field, model_id, BERT, epoch = 0): """ Check that coordinates match between target and prediction. Check also that latitude and longitudes are in geographical coordinates @@ -77,23 +77,23 @@ def test_coordinates(field, model_id, epoch = 0): nsamples = min(len(target[field]), 50) samples = rnd.sample(range(len(target[field])), nsamples) - - for s in samples: - datetime_target = [pd.Timestamp(i) for i in target[f"{field}/sample={s:05d}"].datetime] - lats_target = target[f"{field}/sample={s:05d}"].lat - lons_target = target[f"{field}/sample={s:05d}"].lon + levels = [int(f.split("=")[1]) for f in target[f"{field}/sample=00000"]] if BERT else target[f"{field}/sample=00000"].ml[:] + + get_data = get_BERT if BERT else get_forecast - datetime_pred = [pd.Timestamp(i) for i in pred[f"{field}/sample={s:05d}"].datetime] - lats_pred = pred[f"{field}/sample={s:05d}"].lat - lons_pred = pred[f"{field}/sample={s:05d}"].lon + for level in levels: + level_idx = level if BERT else np.where(levels == level)[0].tolist()[0] + for s in samples: + _, datetime_target, lats_target, lons_target = get_data(target,field, s, level_idx) + _, datetime_pred, lats_pred, lons_pred = get_data(pred, field, s, level_idx) - check_lats(lats_pred, lats_target) - check_lons(lons_pred, lons_target) - check_datetimes(datetime_pred, datetime_target) + check_lats(lats_pred, lats_target) + check_lons(lons_pred, lons_target) + check_datetimes(datetime_pred, datetime_target) ######################################################################### -def test_rmse(field, model_id, epoch = 0): +def test_rmse(field, model_id, BERT, epoch = 0): """ Test that for each field the RMSE does not exceed a certain value. 50 random samples. 
@@ -103,12 +103,17 @@ def test_rmse(field, model_id, epoch = 0): store_p = zarr.ZipStore(atmorep_pred().format(model_id, model_id, str(epoch).zfill(5))) pred = zarr.group(store_p) - + nsamples = min(len(target[field]), 50) samples = rnd.sample(range(len(target[field])), nsamples) + levels = [int(f.split("=")[1]) for f in target[f"{field}/sample=00000"]] if BERT else target[f"{field}/sample=00000"].ml[:] - for s in samples: - sample_target = target[f"{field}/sample={s:05d}"].data[:] - sample_pred = pred[f"{field}/sample={s:05d}"].data[:] + get_data = get_BERT if BERT else get_forecast + + for level in levels: + level_idx = level if BERT else np.where(levels == level)[0].tolist()[0] + for s in samples: + sample_target, _, _, _ = get_data(target,field, s, level_idx) + sample_pred, _, _, _ = get_data(pred,field, s, level_idx) - assert compute_RMSE(sample_target, sample_pred).mean() < get_max_RMSE(field) + assert compute_RMSE(sample_target, sample_pred).mean() < get_max_RMSE(field) From 3868ab41a354c82de5a5ac7c71877c46f8016278 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 25 Jul 2024 19:46:52 +0200 Subject: [PATCH 49/66] restore values in evaluate --- atmorep/core/evaluate.py | 11 ++++++----- atmorep/tests/test_utils.py | 10 +++++----- atmorep/tests/validation_test.py | 6 ++++++ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 1dcd729..8a2583f 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -24,19 +24,20 @@ #model_id = 'oxpycr7w' # divergence #model_id = '1565pb1f' # specific_humidity #model_id = '3kdutwqb' # total precip - #model_id = 'dys79lgw' # velocity_u + model_id = 'dys79lgw' # velocity_u #model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z - model_id = '3qou60es' # temperature (also 2147fkco) + #model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) - + #model_id='s3wwcc3j' + # multi-field configurations with either velocity or voritcity+divergence #model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity #model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting # model_id = '1m79039j' # pre-trained, 6h forecasting - model_id='s3wwcc3j' + #model_id='34niv2nu' # supported modes: test, forecast, fixed_location, temporal_interpolation, global_forecast, # global_forecast_range # options can be used to over-write parameters in config; some modes also have specific options, @@ -45,7 +46,7 @@ #Add 'attention' : True to options to store the attention maps. NB. supported only for single field runs. # BERT masked token model - mode, options = 'BERT', {'years_test' : [2021], 'num_samples_validate' : 128, 'with_pytest' : True } + mode, options = 'BERT', {'years_test' : [2021], 'num_samples_validate' : 10, 'with_pytest' : True } # BERT forecast mode #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'num_samples_validate' : 10, 'with_pytest' : True } diff --git a/atmorep/tests/test_utils.py b/atmorep/tests/test_utils.py index 9f3cfd4..60176ee 100644 --- a/atmorep/tests/test_utils.py +++ b/atmorep/tests/test_utils.py @@ -65,13 +65,13 @@ def compute_RMSE(pred, target): def get_max_RMSE(field): #TODO: optimize thresholds values = {"temperature" : 3, - "velocity_u" : 0.005, #???? - "velocity_v": 0.005, #???? - "velocity_z": 0.005, #???? + "velocity_u" : 0.2, #???? + "velocity_v": 0.2, #???? + "velocity_z": 0.2, #???? "vorticity" : 0.2, #???? 
"divergence": 0.2, #???? - "specific_humidity": 0.7, #???? - "total_precip": 9999, #????? + "specific_humidity": 0.2, #???? + "total_precip": 1, #????? } return values[field] \ No newline at end of file diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py index 3a1656e..e4e30d2 100644 --- a/atmorep/tests/validation_test.py +++ b/atmorep/tests/validation_test.py @@ -28,6 +28,12 @@ def BERT(request): strategy = request.config.getoption("strategy") return (strategy == 'BERT' or strategy == 'temporal_interpolation') +@pytest.fixture(autouse = True) +def strategy(request): + return request.config.getoption("strategy") + +#TODO: add test for global_forecast vs ERA5 + def test_datetime(field, model_id, BERT, epoch = 0): """ From 7863acc3482fb4f931bba221624f30216caf47d7 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 26 Jul 2024 10:05:23 +0000 Subject: [PATCH 50/66] Reenabled packages. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 38e5bb4..012c5b1 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ description='AtmoRep', packages=find_packages(), # if packages are available in a native form fo the host system then these should be used - #install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo'], + install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo'], data_files=[('./output', []), ('./logs', []), ('./results',[])], ) From 4c6111cd151fb50c4c2db00817a3daf424a5fd23 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 26 Jul 2024 10:06:53 +0000 Subject: [PATCH 51/66] Added version requirement for torch. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 065d97f..522b71a 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ description='AtmoRep', packages=find_packages(), # if packages are available in a native form fo the host system then these should be used - install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo', 'pytest', 'cfgrib'], + install_requires=['torch>=2.3', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo', 'pytest', 'cfgrib'], >>>>>>> 3868ab41a354c82de5a5ac7c71877c46f8016278 data_files=[('./output', []), ('./logs', []), ('./results',[])], ) From 56a6a39d9c0f5f526e9e3429e9bc161eb35204a5 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 10:03:29 +0000 Subject: [PATCH 52/66] Factored MLP into a separate file. 
--- atmorep/transformer/mlp.py | 67 +++++++++++++++++ atmorep/transformer/transformer.py | 6 +- atmorep/transformer/transformer_attention.py | 4 - atmorep/transformer/transformer_base.py | 78 +------------------- atmorep/transformer/transformer_decoder.py | 5 +- atmorep/transformer/transformer_encoder.py | 3 +- 6 files changed, 73 insertions(+), 90 deletions(-) create mode 100644 atmorep/transformer/mlp.py diff --git a/atmorep/transformer/mlp.py b/atmorep/transformer/mlp.py new file mode 100644 index 0000000..5908650 --- /dev/null +++ b/atmorep/transformer/mlp.py @@ -0,0 +1,67 @@ + +import torch + +from atmorep.utils.utils import identity +from atmorep.transformer.transformer_base import checkpoint_wrapper + +#################################################################################################### +class MLP(torch.nn.Module): + + def __init__(self, dim_embed, num_layers = 2, with_lnorm = True, dim_embed_out = None, + nonlin = torch.nn.GELU(), dim_internal_factor = 2, dropout_rate = 0., + grad_checkpointing = False, with_residual = True) : + """ + Multi-layer perceptron + + dim_embed : embedding dimension + num_layers : number of layers + nonlin : nonlinearity + dim_internal_factor : factor for number of hidden dimension relative to input / output + """ + super(MLP, self).__init__() + + if not dim_embed_out : + dim_embed_out = dim_embed + + self.with_residual = with_residual + + dim_internal = int( dim_embed * dim_internal_factor) + if with_lnorm : + self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) + else : + self.lnorm = torch.nn.Identity() + + self.blocks = torch.nn.ModuleList() + self.blocks.append( torch.nn.Linear( dim_embed, dim_internal)) + self.blocks.append( nonlin) + self.blocks.append( torch.nn.Dropout( p = dropout_rate)) + + for _ in range( num_layers-2) : + self.blocks.append( torch.nn.Linear( dim_internal, dim_internal)) + self.blocks.append( nonlin) + self.blocks.append( torch.nn.Dropout( p = dropout_rate)) + + self.blocks.append( torch.nn.Linear( dim_internal, dim_embed_out)) + self.blocks.append( nonlin) + + if dim_embed == dim_embed_out : + self.proj_residual = torch.nn.Identity() + else : + self.proj_residual = torch.nn.Linear( dim_embed, dim_embed_out) + + self.checkpoint = identity + if grad_checkpointing : + self.checkpoint = checkpoint_wrapper + + def forward( self, x, y = None) : + + x_in = x + x = self.lnorm( x) + + for block in self.blocks: + x = self.checkpoint( block, x) + + if self.with_residual : + x += x_in + + return x \ No newline at end of file diff --git a/atmorep/transformer/transformer.py b/atmorep/transformer/transformer.py index 5793ae7..c3f5ee9 100644 --- a/atmorep/transformer/transformer.py +++ b/atmorep/transformer/transformer.py @@ -15,12 +15,12 @@ #################################################################################################### import torch -import numpy as np -import math -from atmorep.transformer.transformer_base import MLP, prepare_token +from atmorep.transformer.mlp import MLP +from atmorep.transformer.transformer_base import prepare_token from atmorep.transformer.transformer_attention import MultiSelfAttentionHead + class Transformer(torch.nn.Module) : def __init__(self, num_layers, dim_input, dim_embed = 2048, num_heads = 8, num_mlp_layers = 2, diff --git a/atmorep/transformer/transformer_attention.py b/atmorep/transformer/transformer_attention.py index e312ea4..7aa99a3 100644 --- a/atmorep/transformer/transformer_attention.py +++ b/atmorep/transformer/transformer_attention.py @@ -15,12 +15,8 
@@ #################################################################################################### import torch -import numpy as np -import math from enum import Enum -import code -from atmorep.transformer.axial_attention import AxialAttention from atmorep.utils.utils import identity from atmorep.transformer.transformer_base import checkpoint_wrapper diff --git a/atmorep/transformer/transformer_base.py b/atmorep/transformer/transformer_base.py index 42f9b30..71c59d0 100644 --- a/atmorep/transformer/transformer_base.py +++ b/atmorep/transformer/transformer_base.py @@ -16,10 +16,6 @@ import torch import numpy as np -import math -import code - -from atmorep.utils.utils import identity #################################################################################################### @@ -95,68 +91,6 @@ def positional_encoding_harmonic( x, num_levels, num_tokens, with_cls = False) : return x -#################################################################################################### -class MLP(torch.nn.Module): - - def __init__(self, dim_embed, num_layers = 2, with_lnorm = True, dim_embed_out = None, - nonlin = torch.nn.GELU(), dim_internal_factor = 2, dropout_rate = 0., - grad_checkpointing = False, with_residual = True) : - """ - Multi-layer perceptron - - dim_embed : embedding dimension - num_layers : number of layers - nonlin : nonlinearity - dim_internal_factor : factor for number of hidden dimension relative to input / output - """ - super(MLP, self).__init__() - - if not dim_embed_out : - dim_embed_out = dim_embed - - self.with_residual = with_residual - - dim_internal = int( dim_embed * dim_internal_factor) - if with_lnorm : - self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) - else : - self.lnorm = torch.nn.Identity() - - self.blocks = torch.nn.ModuleList() - self.blocks.append( torch.nn.Linear( dim_embed, dim_internal)) - self.blocks.append( nonlin) - self.blocks.append( torch.nn.Dropout( p = dropout_rate)) - - for _ in range( num_layers-2) : - self.blocks.append( torch.nn.Linear( dim_internal, dim_internal)) - self.blocks.append( nonlin) - self.blocks.append( torch.nn.Dropout( p = dropout_rate)) - - self.blocks.append( torch.nn.Linear( dim_internal, dim_embed_out)) - self.blocks.append( nonlin) - - if dim_embed == dim_embed_out : - self.proj_residual = torch.nn.Identity() - else : - self.proj_residual = torch.nn.Linear( dim_embed, dim_embed_out) - - self.checkpoint = identity - if grad_checkpointing : - self.checkpoint = checkpoint_wrapper - - def forward( self, x, y = None) : - - x_in = x - x = self.lnorm( x) - - for block in self.blocks: - x = self.checkpoint( block, x) - - if self.with_residual : - x += x_in - - return x - #################################################################################################### def prepare_token_info( cf, token_info) : @@ -173,7 +107,7 @@ def prepare_token_info( cf, token_info) : return token_info #################################################################################################### -def prepare_token( xin, embed, embed_token_info, with_cls = True) : +def prepare_token( xin, embed, embed_token_info) : (token_seq, token_info) = xin num_tokens = token_seq.shape[-6:-3] @@ -187,18 +121,8 @@ def prepare_token( xin, embed, embed_token_info, with_cls = True) : # token_info = prepare_token_info( cf, token_info) token_info = token_info.reshape([-1] + list(token_seq_embed.shape[1:-1])+[token_info.shape[-1]]) token_seq_embed = torch.cat( [token_seq_embed, token_info], -1) - - # class token - if 
with_cls : - # initialize to zero (mean of data) - tts = token_seq_embed.shape - cls_token = torch.zeros( (tts[0], 1, tts[2]), device=token_seq_embed.device) # add positional encoding token_seq_embed = positional_encoding_harmonic( token_seq_embed, num_levels, num_tokens) - # add class token after positional encoding - if with_cls : - token_seq_embed = torch.cat( [ cls_token, token_seq_embed ], 1) - return token_seq_embed diff --git a/atmorep/transformer/transformer_decoder.py b/atmorep/transformer/transformer_decoder.py index 11c622b..31720ba 100644 --- a/atmorep/transformer/transformer_decoder.py +++ b/atmorep/transformer/transformer_decoder.py @@ -15,11 +15,8 @@ #################################################################################################### import torch -import torch.utils.checkpoint as checkpoint -import numpy as np -import math -from atmorep.transformer.transformer_base import MLP, prepare_token +from atmorep.transformer.mlp import MLP from atmorep.transformer.transformer_attention import MultiSelfAttentionHead, MultiCrossAttentionHead from atmorep.transformer.axial_attention import MultiFieldAxialAttention from atmorep.utils.utils import identity diff --git a/atmorep/transformer/transformer_encoder.py b/atmorep/transformer/transformer_encoder.py index 054615e..f88c7f6 100644 --- a/atmorep/transformer/transformer_encoder.py +++ b/atmorep/transformer/transformer_encoder.py @@ -16,9 +16,8 @@ import torch import numpy as np -import math -from atmorep.transformer.transformer_base import MLP +from atmorep.transformer.mlp import MLP from atmorep.transformer.transformer_attention import MultiInterAttentionHead from atmorep.transformer.axial_attention import MultiFieldAxialAttention From 7d8e24cf99869de08b3ddad3401b9aedfbec1b2d Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 10:04:30 +0000 Subject: [PATCH 53/66] Adding logging code but still needs to be used. --- atmorep/utils/logger.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 atmorep/utils/logger.py diff --git a/atmorep/utils/logger.py b/atmorep/utils/logger.py new file mode 100644 index 0000000..2ab1811 --- /dev/null +++ b/atmorep/utils/logger.py @@ -0,0 +1,22 @@ + +import logging +import pathlib +import os + +class RelPathFormatter(logging.Formatter): + def __init__(self, fmt, datefmt=None): + super().__init__(fmt, datefmt) + self.root_path = pathlib.Path(__file__).parent.parent.parent.resolve() + + def format(self, record): + # Replace the full pathname with the relative path + record.pathname = os.path.relpath(record.pathname, self.root_path) + return super().format(record) + +logger = logging.getLogger('atmorep') +logger.setLevel(logging.DEBUG) +ch = logging.StreamHandler() +formatter = RelPathFormatter('%(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s') +ch.setFormatter(formatter) +logger.handlers.clear() +logger.addHandler(ch) From 9f69a5dd3edab4992c2251462dd95f3e707683aa Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 10:08:22 +0000 Subject: [PATCH 54/66] Removing used file and functionality (which has moved to multifield_data_sampler.py). 
--- atmorep/datasets/data_loader.py | 164 -------------------------------- 1 file changed, 164 deletions(-) delete mode 100644 atmorep/datasets/data_loader.py diff --git a/atmorep/datasets/data_loader.py b/atmorep/datasets/data_loader.py deleted file mode 100644 index 9c2638f..0000000 --- a/atmorep/datasets/data_loader.py +++ /dev/null @@ -1,164 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import pathlib -import numpy as np -import xarray as xr -from functools import partial - -import atmorep.config.config as config -import atmorep.utils.utils as utils -from atmorep.config.config import year_base -from atmorep.utils.utils import tokenize -from atmorep.datasets.file_io import grib_file_loader, netcdf_file_loader, bin_file_loader - -# TODO, TODO, TODO: replace with torch functonality -# import cv2 as cv - -class DataLoader: - - def __init__(self, path, file_shape, data_type, field_info, - file_format = 'grib', level_type = 'pl', - fname_base = '{}/{}/{}/{}{}/{}_{}_y{}_m{}_{}{}', - smoothing = 0, - log_transform = False, - partial_load = 0): - - self.path = path - self.data_type = data_type - self.field_info = field_info - self.file_format = file_format - self.file_shape = file_shape - self.fname_base = fname_base - self.smoothing = smoothing - self.log_transform = log_transform - self.partial_load = partial_load - - if 'grib' == file_format : - self.file_ext = '.grib' - self.file_loader = grib_file_loader - elif 'binary' == file_format : - self.file_ext = '_fp32.dat' - self.file_loader = bin_file_loader - elif 'netcdf4' == file_format : - self.file_ext = '.nc4' - self.file_loader = netcdf_file_loader - elif 'netcdf' == file_format : - self.file_ext = '.nc' - self.file_loader = netcdf_file_loader - else : - raise ValueError('Unsupported file format.') - - self.fname_base = fname_base + self.file_ext - - self.grib_index = config.grib_index - - def get_field( self, year, month, field, level_type, vl, - token_size = [-1, -1], t_pad = [-1, -1, 1]): - - t_srate = t_pad[2] - data_ym = torch.zeros( (0, self.file_shape[1], self.file_shape[2])) - - # pre-fill fixed values - fname_base = self.fname_base.format( self.path, self.data_type, field, level_type, vl, - self.data_type, field, {},{},{},{}) - - # padding pre - if t_pad[0] > 0 : - if month > 1: - month_p = str(month-1).zfill(2) - days_month = utils.days_in_month( year, month-1) - fname = fname_base.format( year, month_p, level_type, vl) - else: - assert(year >= year_base) - year_p = str(year-1).zfill(2) - days_month = utils.days_in_month( year, 12) - fname = fname_base.format( year-1, 12, level_type, vl) - x = self.file_loader( fname, self.grib_index[field], [t_pad[0], 0, t_srate], - days_month ) - - data_ym = torch.cat((data_ym,x),0) - - # data - fname = fname_base.format( year, str(month).zfill(2), level_type, vl) - days_month = utils.days_in_month( year, month) - x = self.file_loader(fname,self.grib_index[field], [0, self.partial_load, t_srate],days_month) - - data_ym = torch.cat((data_ym,x),0) - - # padding post - if t_pad[1] > 0 : - if month > 1: - month_p = str(month+1).zfill(2) - days_month = utils.days_in_month( year, month+1) - 
fname = fname_base.format( year, month_p, level_type, vl) - else: - assert(year >= year_base) - year_p = str(year+1).zfill(2) - days_month = utils.days_in_month( year+1, 12) - fname = fname_base.format( year_p, 12, level_type, vl) - x = self.file_loader( fname, self.grib_index[field], [0, t_pad[1], t_srate], - days_month) - - data_ym = torch.cat((data_ym,x),0) - - if self.smoothing > 0 : - sm = self.smoothing - mask_nan = torch.isnan( data_ym) - data_ym[ mask_nan ] = 0. - blur = partial( cv.blur, ksize=(sm,sm), borderType=cv.BORDER_REFLECT_101) - data_ym = [torch.from_numpy( blur( data_ym[k].numpy()) ).unsqueeze(0) - for k in range(data_ym.shape[0])] - data_ym = torch.cat( data_ym, 0) - data_ym[ mask_nan ] = torch.nan - - # tokenize - data_ym = tokenize( data_ym, token_size) - - return data_ym - - def get_single_field( self, years_months, field = 'vorticity', level_type = 'pl', vl = 975, - token_size = [-1, -1], t_pad = [-1, -1, 1]): - - data_field = [] - extent_t = config.datasets[self.data_type]['extent'][0] - for year, month in years_months : - # skip loading when the year is not available for the dataset - if year < extent_t[0] or year > extent_t[1] : - data_field.append( []) - continue - data_field.append( self.get_field( year, month, field, level_type, vl, token_size, t_pad)) - - return data_field - - def get_static_field( self, field, token_size = [-1, -1]): - - #support for static fields from other data types - data_type = self.data_type - f = self.path + '/' + data_type + '/static/' +self.data_type + '_' + field + self.file_ext - - x = self.file_loader(f, self.grib_index[field], static=True) - - if self.smoothing > 0 : - sm = self.smoothing - blur = partial( cv.blur, ksize=(sm,sm), borderType=cv.BORDER_REFLECT_101) - # x = torch.from_numpy( cv.blur( x.numpy(), (self.smoothing,self.smoothing))) - x = torch.from_numpy( blur( x.numpy() )) - - x = tokenize( x, token_size) - - return x From c8b180da780e6e2e9ee06d4749d32fb291aee60a Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 10:09:28 +0000 Subject: [PATCH 55/66] - Added optimization suggested by BSC (although no performance improvement can be seen on ECMWF-ATOS) - Minor clenaup. 
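
The data extraction in __iter__ now uses a single broadcasted indexing expression instead
of two nested np.take calls, both for the field data and for the local normalizer. A small,
self-contained sketch of the equivalence (grid size and index ranges are made up for
illustration; only the indexing pattern matches the change below):

    import numpy as np

    data_t  = np.random.rand( 5, 721, 1440)      # illustrative (time, lat, lon), ERA5-like grid
    lat_ran = np.arange( 100, 136)               # latitude indices of the local patch
    lon_ran = np.arange( 1420, 1456) % 1440      # longitude indices, wrapped periodically

    old = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1)
    new = data_t[ :, lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]

    assert np.array_equal( old, new)             # same (5, 36, 36) block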
--- atmorep/datasets/multifield_data_sampler.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index b34d063..eb5a14c 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -21,6 +21,7 @@ from datetime import datetime import time import os +import code # from atmorep.datasets.normalizer_global import NormalizerGlobal # from atmorep.datasets.normalizer_local import NormalizerLocal @@ -75,6 +76,7 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, # lon: no change for periodic case if self.ds_global < 1.: self.range_lon += np.array([n_size[2]/2., -n_size[2]/2.]) + # data normalizers self.normalizers = [] for ifield, field_info in enumerate(fields) : @@ -84,19 +86,20 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, for vl in field_info[2]: if vl == 0: field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) - self.normalizers[ifield] += [self.ds[f'normalization/{nf_name}_sfc'].oindex[ :, :, field_idx]] + n_name = f'normalization/{nf_name}_sfc' + self.normalizers[ifield] += [self.ds[n_name].oindex[ :, :, field_idx]] else: vl_idx = self.ds.attrs['levels'].index(vl) field_idx = self.ds.attrs['fields'].index( field_info[0]) - self.normalizers[ifield] += [self.ds[f'normalization/{nf_name}'].oindex[ :, :, field_idx, vl_idx]] + n_name = f'normalization/{nf_name}' + self.normalizers[ifield] += [self.ds[n_name].oindex[ :, :, field_idx, vl_idx]] + # extract indices for selected years self.times = pd.DatetimeIndex( self.ds['time']) idxs_years = self.times.year == years[0] for year in years[1:] : idxs_years = np.logical_or( idxs_years, self.times.year == year) self.idxs_years = np.where( idxs_years)[0] - # logging.getLogger('atmorep').info( f'Dataset size for years {years}: {len(self.idxs_years)}.') - print( f'Dataset size for years {years}: {len(self.idxs_years)}.', flush=True) ################################################### def shuffle( self) : @@ -178,17 +181,19 @@ def __iter__(self): source_data, tok_info = [], [] # extract data, normalize and tokenize - cdata = np.take( np.take( data_t, lat_ran, -2), lon_ran, -1) + cdata = data_t[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] normalizer = self.normalizers[ifield][ilevel] if corr_type != 'global': - normalizer = np.take( np.take( normalizer, lat_ran, -2), lon_ran, -1) + normalizer = normalizer[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) + source_data = tokenize( torch.from_numpy( cdata), tok_size ) # token_infos uses center of the token: *last* datetime and center in space dates = self.ds['time'][ idxs_t ].astype(datetime) cdates = dates[tok_size[0]-1::tok_size[0]] - dates = [(d.year, d.timetuple().tm_yday-1, d.hour) for d in cdates] #-1 is to start days from 0 + # use -1 is to start days from 0 + dates = [(d.year, d.timetuple().tm_yday-1, d.hour) for d in cdates] lats_sidx = self.lats[lat_ran][ tok_size[1]//2 :: tok_size[1] ] lons_sidx = self.lons[lon_ran][ tok_size[2]//2 :: tok_size[2] ] # tensor product for token_infos From 0dd99a7536b39a90c6803519e74d23db9bf3a0a8 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 10:10:22 +0000 Subject: [PATCH 56/66] Cleaned up unused parameters and dependencies. 
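
train.py keeps a single active single-field configuration and drops the unused training
parameters. For orientation, one cf.fields entry as it is set below, annotated following the
format comment in train.py (the meaning of the third entry of the second sub-list, the list
of coupled fields, and the time/lat/lon ordering of the token settings are inferred from the
commented multi-field examples, not stated explicitly there):

    cf_field = [ 'temperature',                # field name
                 [ 1, 1024, [ ], 0 ],          # dynamic (1) / static (0), embedding dim, coupled fields, device id
                 [ 96, 105, 114, 123, 137 ],   # vertical model levels
                 [ 12, 6, 12 ],                # number of tokens per dimension (time, lat, lon)
                 [ 3, 9, 9 ],                  # token size per dimension (time, lat, lon)
                 [ 0.25, 0.9, 0.1, 0.05 ] ]    # total masking rate, masking, noising, multi-res distortion rates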
--- atmorep/core/atmorep_model.py | 23 +---- atmorep/core/train.py | 166 ++++++++++++++-------------------- atmorep/core/trainer.py | 12 +-- atmorep/utils/utils.py | 4 +- 4 files changed, 78 insertions(+), 127 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 06c4843..4673fcc 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -137,11 +137,9 @@ def mode( self, mode : NetMode) : if mode == NetMode.train : self.data_loader_iter = iter(self.data_loader_train) - #self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : self.data_loader_iter = iter(self.data_loader_test) - #self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -188,11 +186,12 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, cf.batch_size, pre_batch, cf.n_size, cf.num_samples_per_epoch, - with_shuffle = (cf.BERT_strategy != 'global_forecast'), with_source_idxs = True ) + with_shuffle = (cf.BERT_strategy != 'global_forecast'), + with_source_idxs = True ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) - self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_test, + self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val, cf.batch_size_validation, pre_batch, cf.n_size, cf.num_samples_validate, with_shuffle = (cf.BERT_strategy != 'global_forecast'), @@ -250,17 +249,9 @@ def create( self, devices, load_pretrained=True) : self.embeds = torch.nn.ModuleList() self.encoders = torch.nn.ModuleList() - self.masks = torch.nn.ParameterList() for field_idx, field_info in enumerate(cf.fields) : - # learnabl class token - if cf.learnable_mask : - mask = torch.nn.Parameter( 0.1 * torch.randn( np.prod( field_info[4]), requires_grad=True)) - self.masks.append( mask.to(devices[0])) - else : - self.masks.append( None) - # encoder self.encoders.append( TransformerEncoder( cf, field_idx, True).create()) # load pre-trained model if specified @@ -316,8 +307,6 @@ def create( self, devices, load_pretrained=True) : assert field_info[1][3] < len(devices), 'Per field device id larger than max devices' device = self.devices[ field_info[1][3] ] # set device - if self.masks[field_idx] != None : - self.masks[field_idx].to(device) self.embeds[field_idx].to(device) self.encoders[field_idx].to(device) @@ -551,16 +540,14 @@ def forward_encoder_block( self, iblock, fields_embed) : return fields_embed_cur, atts ################################################### - def get_fields_embed( self, xin ) : - cf = self.cf if 0 == len(self.embeds_token_info) : # TODO: only for backward compatibility, remove emb_net_ti = self.embed_token_info - return [prepare_token( field_data, emb_net, emb_net_ti, cf.with_cls ) + return [prepare_token( field_data, emb_net, emb_net_ti ) for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))] else : embs_net_ti = self.embeds_token_info - return [prepare_token( field_data, emb_net, embs_net_ti[fidx], cf.with_cls ) + return [prepare_token( field_data, emb_net, embs_net_ti[fidx] ) for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))] ################################################### diff --git a/atmorep/core/train.py b/atmorep/core/train.py index be33954..2f9bcc6 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -17,6 +17,9 @@ import torch 
import numpy as np import os +import sys +import pdb +import traceback import wandb @@ -34,7 +37,6 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) - #device = ['cuda'] with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -45,14 +47,6 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : cf.par_size = par_size cf.optimizer_zero = False cf.attention = False - - cf.batch_size = 96 #16 #4 # 32 - cf.lr_max = 0.00005*3 - cf.num_samples_per_epoch = 4096*12 - cf.num_samples_validate = 128*12 - - cf.losses = ['weighted_mse', 'stats'] - # name has changed but ensure backward compatibility if hasattr( cf, 'loader_num_workers') : cf.num_loader_workers = cf.loader_num_workers @@ -64,9 +58,8 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : cf.num_samples_validate = 128 if not hasattr(cf, 'with_mixed_precision'): cf.with_mixed_precision = True - - # cf.years_train = [2021] # list( range( 1980, 2018)) - # cf.years_test = [2021] #[2018] + if not hasattr(cf, 'years_val'): + cf.years_val = cf.years_test # any parameter in cf can be overwritten when training is continued, e.g. we can increase the # masking rate @@ -95,8 +88,6 @@ def train() : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) - #device = ['cuda'] - with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -109,11 +100,6 @@ def train() : cf.num_accs_per_task = num_accs_per_task # number of GPUs / accelerators per task cf.par_rank = par_rank cf.par_size = par_size - cf.back_passes_per_step = 4 - # general - cf.comment = '' - cf.file_format = 'grib' - cf.level_type = 'ml' # format: list of fields where for each field the list is # [ name , @@ -124,61 +110,44 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'specific_humidity', [ 1, 1024, [ ], 0 ], - [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local'] ] - [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] + cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + cf.fields_prediction = [ [cf.fields[0][0], 1.] ] - # cf.fields = [ [ 'temperature',[ 1, 512, [ ], 0 ], + # cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0], # [ 96, 105, 114, 123, 137 ], - # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.2, 0.05] ] ] - - # cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], + # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ] ] + + # cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] + # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] - - # cf.fields = [ [ 'velocity_u', [ 1, 2048, ['velocity_v', 'temperature'], 0 ], + # cf.fields = [ [ 'velocity_z', [ 1, 1024, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - # [ 'velocity_v', [ 1, 2048, ['velocity_u', 'temperature'], 1 ], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + + # cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - # [ 'specific_humidity', [ 1, 2048, ['velocity_u', 'velocity_v', 'temperature'], 2 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - # [ 'velocity_z', [ 1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 3 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'global' ], - # [ 'temperature', [ 1, 1024, ['velocity_u', 'velocity_v', 'specific_humidity'], 3 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.2, 0.05], 'local' ], - # ['total_precip', [1, 1536, ['velocity_u', 'velocity_v', 'velocity_z', 'specific_humidity'], 3], - # [0], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05]] ] + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ] cf.fields_targets = [] - - cf.years_train = list( range( 2010, 2021)) - cf.years_test = [2021] #[2018] + + cf.years_train = [2021] # list( range( 1980, 2018)) + cf.years_val = [2021] #[2018] cf.month = None cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]] cf.time_sampling = 1 # sampling rate for time steps - # file and data parameter parameter - cf.data_smoothing = 0 - cf.file_shape = (-1, 721, 1440) - cf.num_t_samples = 31*24 - cf.num_files_train = 5 - cf.num_files_test = 2 - cf.num_patches_per_t_train = 8 - cf.num_patches_per_t_test = 4 # random seeds cf.torch_seed = torch.initial_seed() # training params - cf.batch_size_validation = 1 #64 - cf.batch_size = 96 #16 #4 # 32 + cf.batch_size_validation = 64 + cf.batch_size = 32 cf.num_epochs = 128 + cf.num_samples_per_epoch = 4096*12 + cf.num_samples_validate = 128*12 + cf.num_loader_workers = 8 # additional infos cf.size_token_info = 8 @@ -186,18 +155,18 @@ def train() : cf.grad_checkpointing = True cf.with_cls = False # network config + cf.with_mixed_precision = True cf.with_layernorm = True cf.coupling_num_heads_per_field = 1 cf.dropout_rate = 0.05 - cf.learnable_mask = False cf.with_qk_lnorm = False # encoder - cf.encoder_num_layers = 6 #10 #4 + cf.encoder_num_layers = 4 cf.encoder_num_heads = 16 cf.encoder_num_mlp_layers = 2 cf.encoder_att_type = 'dense' # decoder - cf.decoder_num_layers = 6 #10 #4 + cf.decoder_num_layers = 4 cf.decoder_num_heads = 16 cf.decoder_num_mlp_layers = 2 cf.decoder_self_att = False @@ -208,24 +177,23 @@ def train() : cf.net_tail_num_nets = 16 cf.net_tail_num_layers = 0 # loss - # supported: see Trainer for supported losses - cf.losses = ['mse_ensemble', 'stats'] + cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps # training cf.optimizer_zero = False cf.lr_start = 5. 
* 10e-7 - cf.lr_max = 0.00005*3 - cf.lr_min = 0.00004 - cf.weight_decay = 0.05 + cf.lr_max = 0.00005 + cf.lr_min = 0.00002 + cf.weight_decay = 0.1 cf.lr_decay_rate = 1.025 cf.lr_start_epochs = 3 - # BERT - # strategies: 'BERT', 'forecast', 'temporal_interpolation', 'identity' - cf.BERT_strategy = 'BERT' + # strategies: 'BERT', 'forecast', 'temporal_interpolation' + cf.BERT_strategy = 'BERT' + cf.forecast_num_tokens = 1 # only needed / used for BERT_strategy 'forecast cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution - cf.forecast_num_tokens = 2 #only when training in forecast mode + # debug / output cf.log_test_num_ranks = 0 cf.save_grads = False @@ -239,25 +207,23 @@ def train() : cf.with_wandb = True setup_wandb( cf.with_wandb, cf, par_rank, 'train', mode='offline') - cf.with_mixed_precision = True - cf.num_samples_per_epoch = 4096*12 - cf.num_samples_validate = 128*12 - cf.num_loader_workers = 6 - - #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' - # # in steps x lat_degrees x lon_degrees - #cf.n_size = [36, 1*9*6, 1.*9*12] - - # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' - # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' - #cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' - # - #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' - # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' - cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' - # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' - # in steps x lat_degrees x lon_degrees - cf.n_size = [36, 0.25*9*6, 0.25*9*12] + # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk32.zarr' + # # # in steps x lat_degrees x lon_degrees + # cf.n_size = [36, 1*9*6, 1.*9*12] + + # # # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' + # # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' + # # # + # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' + # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' + # # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' + # # # in steps x lat_degrees x lon_degrees + # cf.n_size = [36, 0.25*9*6, 0.25*9*12] + + cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' + cf.n_size = [36, 1*9*6, 1.*9*12] if cf.with_wandb and 0 == cf.par_rank : cf.write_json( wandb) @@ -268,14 +234,18 @@ def train() : #################################################################################################### if __name__ == '__main__': + + try : + + train() -# train() + # wandb_id, epoch, epoch_continue = '1jh2qvrx', 392, 392 + # Trainer = Trainer_BERT + # train_continue( wandb_id, epoch, Trainer, epoch_continue) - #wandb_id, epoch = '66zlffty', 26 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 - wandb_id, epoch = 'h7orvjna', 82 - #wandb_id, epoch = 'ocpn87si', 103 - #wandb_id, 
epoch = 'fc5o31h2', 27 - epoch_continue = epoch + except : + + extype, value, tb = sys.exc_info() + traceback.print_exc() + pdb.post_mortem(tb) - Trainer = Trainer_BERT - train_continue( wandb_id, epoch, Trainer, epoch_continue) diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 02b5b3e..ad13392 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -147,7 +147,7 @@ def run( self, epoch = -1) : lr=cf.lr_start ) else : self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=cf.lr_start, - weight_decay=cf.weight_decay) + weight_decay=cf.weight_decay) self.grad_scaler = torch.cuda.amp.GradScaler(enabled=cf.with_mixed_precision) @@ -571,15 +571,7 @@ def prepare_batch( self, xin) : cdev = devs[cf.fields[i][1][3]] tmi_out += [ [torch.cat(tmi_l).to( cdev, non_blocking=True) for tmi_l in tmi] ] self.tokens_masked_idx = tmi_out - - # learnable class token (cannot be done in the data loader since this is running in parallel) - if cf.learnable_mask : - for ifield, (source, _) in enumerate(batch_data) : - source = torch.flatten( torch.flatten( torch.flatten( source, 1, 4), 2, 4), 0, 1) - assert len(cf.fields[ifield][2]) == 1 - tmidx = self.tokens_masked_idx[ifield][0] - source[ tmidx ] = self.model.net.masks[ifield].to( source.device) - + return batch_data ################################################### diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index fb2c40c..bdc8412 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -163,7 +163,9 @@ def setup_ddp( with_ddp = True) : rank = 0 size = 1 - if with_ddp : + master_node = os.environ.get('MASTER_ADDR', '-1') + + if with_ddp and (master_node != '-1'): local_rank = int(os.environ.get("SLURM_LOCALID")) ranks_per_node = int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] ) From 3dd417c6168cf351c6a679dfa0a8a4742c7a14a8 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 13:10:38 +0200 Subject: [PATCH 57/66] Fixed partial merge --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 522b71a..dd43a96 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,6 @@ packages=find_packages(), # if packages are available in a native form fo the host system then these should be used install_requires=['torch>=2.3', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo', 'pytest', 'cfgrib'], ->>>>>>> 3868ab41a354c82de5a5ac7c71877c46f8016278 data_files=[('./output', []), ('./logs', []), ('./results',[])], ) From 179b95fb9f76af4c25ea965b638b0afac39f53dd Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 29 Jul 2024 13:13:46 +0200 Subject: [PATCH 58/66] Cleaned up code. 
--- atmorep/core/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 2f9bcc6..adf13eb 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -222,7 +222,8 @@ def train() : # # # in steps x lat_degrees x lon_degrees # cf.n_size = [36, 0.25*9*6, 0.25*9*12] - cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' cf.n_size = [36, 1*9*6, 1.*9*12] if cf.with_wandb and 0 == cf.par_rank : From 679adb3b093e87a7262cf9ac935d32804801c88f Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Tue, 30 Jul 2024 19:57:39 +0200 Subject: [PATCH 59/66] validated code --- atmorep/core/evaluate.py | 21 ++++++++++----------- atmorep/core/evaluator.py | 8 +++++--- atmorep/tests/validation_test.py | 3 ++- atmorep/utils/utils.py | 4 ++++ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 8a2583f..90bb3a1 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -29,8 +29,7 @@ # model_id = '15oisw8d' # velocity_z #model_id = '3qou60es' # temperature (also 2147fkco) #model_id = '2147fkco' # temperature (also 2147fkco) - #model_id='s3wwcc3j' - + # multi-field configurations with either velocity or voritcity+divergence #model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity @@ -46,24 +45,24 @@ #Add 'attention' : True to options to store the attention maps. NB. supported only for single field runs. # BERT masked token model - mode, options = 'BERT', {'years_test' : [2021], 'num_samples_validate' : 10, 'with_pytest' : True } + mode, options = 'BERT', {'years_test' : [2021], 'num_samples_validate' : 128, 'with_pytest' : True } # BERT forecast mode - #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'num_samples_validate' : 10, 'with_pytest' : True } + #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'num_samples_validate' : 128, 'with_pytest' : True } #temporal interpolation #idx_time_mask: list of relative time positions of the masked tokens within the cube wrt num_tokens[0] - #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'num_samples_validate' : 10, 'with_pytest' : True} + #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'num_samples_validate' : 128, 'with_pytest' : True} # BERT forecast with patching to obtain global forecast # mode, options = 'global_forecast', { # 'dates' : [[2021, 1, 10, 18]], -# # 'dates' : [ #[2021, 1, 10, 18] -# # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], -# # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], -# # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], -# # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] -# # ], +# # # 'dates' : [ #[2021, 1, 10, 18] +# # # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], +# # # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], +# # # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 
12, 12], [2021, 7, 13, 0], +# # # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] +# # # ], # 'token_overlap' : [0, 0], # 'forecast_num_tokens' : 2, # 'with_pytest' : True } diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index a83fcce..d5e778b 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -128,8 +128,9 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : cf.lat_sampling_weighted = False cf.BERT_strategy = 'BERT' cf.log_test_num_ranks = 4 - cf.num_samples_validate = 128 #1472 + cf.num_samples_validate = 10 #28 #1472 Evaluator.parse_args( cf, args) + utils.check_num_samples(cf.num_samples_validate, cf.batch_size) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## @@ -142,7 +143,7 @@ def forecast( cf, model_id, model_epoch, devices, args = {}) : cf.forecast_num_tokens = 1 # will be overwritten when user specified cf.num_samples_validate = 128 #128 Evaluator.parse_args( cf, args) - + utils.check_num_samples(cf.num_samples_validate, cf.batch_size) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## @@ -217,8 +218,9 @@ def temporal_interpolation( cf, model_id, model_epoch, devices, args = {}) : # set/over-write options cf.BERT_strategy = 'temporal_interpolation' cf.log_test_num_ranks = 4 - cf.num_samples_validate = 10 #128 + cf.num_samples_validate = 128 Evaluator.parse_args( cf, args) + utils.check_num_samples(cf.num_samples_validate, cf.batch_size) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py index e4e30d2..4dbea9a 100644 --- a/atmorep/tests/validation_test.py +++ b/atmorep/tests/validation_test.py @@ -64,7 +64,8 @@ def test_datetime(field, model_id, BERT, epoch = 0): continue era5 = xr.open_dataset(era5_path, engine = "cfgrib")[grib_index(field)].sel(time = datetime, latitude = lats, longitude = lons) - assert (data[0] == era5.values[0]).all(), "Mismatch between ERA5 and AtmoRep Timestamps" + #assert (data[0] == era5.values[0]).all(), "Mismatch between ERA5 and AtmoRep Timestamps" + assert np.isclose(data[0], era5.values[0],rtol=1e-04, atol=1e-07).all(), "Mismatch between ERA5 and AtmoRep Timestamps" ############################################################################# diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index bdc8412..866e2d3 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -405,3 +405,7 @@ def get_weights(lats_idx, lat_min = -90., lat_max = 90., reso = 0.25): def weighted_mse(x, target, weights): return torch.sum(weights * (x - target) **2 )/torch.sum(weights) +######################################## + +def check_num_samples(num_samples_validate, batch_size): + assert num_samples_validate // batch_size > 0, f"Num samples validate: {num_samples_validate} is smaller than batch size: {batch_size}. Please increase it." From 224bf918123bbbced1a97564e9c716d01338f503 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 5 Aug 2024 17:48:25 +0200 Subject: [PATCH 60/66] Fixed bug when number of samples is larger than available samples. 
--- atmorep/datasets/multifield_data_sampler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index eb5a14c..6088e0c 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -101,6 +101,8 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, idxs_years = np.logical_or( idxs_years, self.times.year == year) self.idxs_years = np.where( idxs_years)[0] + self.num_samples = min( self.num_samples, self.idxs_years.shape[0]) + ################################################### def shuffle( self) : @@ -124,6 +126,7 @@ def shuffle( self) : ################################################### def __iter__(self): + if self.with_shuffle : self.shuffle() @@ -133,8 +136,9 @@ def __iter__(self): res = self.res iter_start, iter_end = self.worker_workset() - + for bidx in range( iter_start, iter_end) : + sources, token_infos = [[] for _ in self.fields], [[] for _ in self.fields] sources_infos, source_idxs = [], [] From 5c77311ddbdebdb95799e9a14de050b250e88814 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 7 Aug 2024 14:09:35 +0200 Subject: [PATCH 61/66] restore settings for batch size 96 --- atmorep/core/train.py | 49 ++++++++++----------- atmorep/datasets/multifield_data_sampler.py | 2 +- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index adf13eb..601a2ac 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -15,21 +15,17 @@ #################################################################################################### import torch -import numpy as np import os import sys -import pdb import traceback import wandb -import atmorep.config.config as config from atmorep.core.trainer import Trainer_BERT from atmorep.utils.utils import Config from atmorep.utils.utils import setup_ddp from atmorep.utils.utils import setup_wandb from atmorep.utils.utils import init_torch -import atmorep.utils.utils as utils #################################################################################################### @@ -86,7 +82,7 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : #################################################################################################### def train() : - num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) + num_accs_per_task = 1 #int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -110,15 +106,17 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # cf.fields = [ [ 'temperature', [ 1, 512, [ ], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + + cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] + cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] - # cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ] ] - # cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] @@ -134,7 +132,7 @@ def train() : cf.fields_targets = [] - cf.years_train = [2021] # list( range( 1980, 2018)) + cf.years_train = list( range( 2010, 2021)) cf.years_val = [2021] #[2018] cf.month = None cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]] @@ -142,8 +140,8 @@ def train() : # random seeds cf.torch_seed = torch.initial_seed() # training params - cf.batch_size_validation = 64 - cf.batch_size = 32 + cf.batch_size_validation = 1 #64 + cf.batch_size = 96 cf.num_epochs = 128 cf.num_samples_per_epoch = 4096*12 cf.num_samples_validate = 128*12 @@ -161,12 +159,12 @@ def train() : cf.dropout_rate = 0.05 cf.with_qk_lnorm = False # encoder - cf.encoder_num_layers = 4 + cf.encoder_num_layers = 6 cf.encoder_num_heads = 16 cf.encoder_num_mlp_layers = 2 cf.encoder_att_type = 'dense' # decoder - cf.decoder_num_layers = 4 + cf.decoder_num_layers = 6 cf.decoder_num_heads = 16 cf.decoder_num_mlp_layers = 2 cf.decoder_self_att = False @@ -177,19 +175,19 @@ def train() : cf.net_tail_num_nets = 16 cf.net_tail_num_layers = 0 # loss - cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps + cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps, weighted_mse # training cf.optimizer_zero = False cf.lr_start = 5. * 10e-7 - cf.lr_max = 0.00005 - cf.lr_min = 0.00002 - cf.weight_decay = 0.1 + cf.lr_max = 0.00005*3 + cf.lr_min = 0.00004 #0.00002 + cf.weight_decay = 0.05 #0.1 cf.lr_decay_rate = 1.025 cf.lr_start_epochs = 3 # BERT # strategies: 'BERT', 'forecast', 'temporal_interpolation' cf.BERT_strategy = 'BERT' - cf.forecast_num_tokens = 1 # only needed / used for BERT_strategy 'forecast + cf.forecast_num_tokens = 2 # only needed / used for BERT_strategy 'forecast cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution @@ -219,12 +217,13 @@ def train() : # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' # # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' + cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/' # # # in steps x lat_degrees x lon_degrees - # cf.n_size = [36, 0.25*9*6, 0.25*9*12] + cf.n_size = [36, 0.25*9*6, 0.25*9*12] # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' - cf.n_size = [36, 1*9*6, 1.*9*12] + # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' + # cf.n_size = [36, 1*9*6, 1.*9*12] if cf.with_wandb and 0 == cf.par_rank : cf.write_json( wandb) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index eb5a14c..0a90076 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -184,7 +184,7 @@ def __iter__(self): cdata = data_t[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] normalizer = self.normalizers[ifield][ilevel] - if corr_type != 'global': + if corr_type != 'global': 
normalizer = normalizer[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) From 11f5d7ed340b7ab01433201a5ea73f52ce7aa4c4 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 7 Aug 2024 14:45:16 +0200 Subject: [PATCH 62/66] restore paths --- atmorep/core/train.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 601a2ac..73dccda 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -82,7 +82,7 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : #################################################################################################### def train() : - num_accs_per_task = 1 #int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) + num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) device = init_torch( num_accs_per_task) with_ddp = True par_rank, par_size = setup_ddp( with_ddp) @@ -106,15 +106,9 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - # cf.fields = [ [ 'temperature', [ 1, 512, [ ], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields_prediction = [ [cf.fields[0][0], 1.] ] - - cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], - [ 96, 105, 114, 123, 137 ], - [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] - + cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] cf.fields_prediction = [ [cf.fields[0][0], 1.] ] # cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ], @@ -217,13 +211,13 @@ def train() : # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' # # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' - cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/' + # cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/' # # # in steps x lat_degrees x lon_degrees - cf.n_size = [36, 0.25*9*6, 0.25*9*12] + # cf.n_size = [36, 0.25*9*6, 0.25*9*12] # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' - # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' - # cf.n_size = [36, 1*9*6, 1.*9*12] + cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' + cf.n_size = [36, 1*9*6, 1.*9*12] if cf.with_wandb and 0 == cf.par_rank : cf.write_json( wandb) From 2afc62b38c053dc8ff6b7bfbeca99159aeb7de13 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 7 Aug 2024 16:42:01 +0200 Subject: [PATCH 63/66] add BSC Epicure team suggestions --- atmorep/core/atmorep_model.py | 6 ++- atmorep/core/train.py | 33 +++++++++------ atmorep/datasets/multifield_data_sampler.py | 45 +++++++++++---------- 3 files changed, 47 insertions(+), 37 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 4673fcc..2d22638 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -137,9 +137,11 @@ def mode( self, mode : NetMode) : if mode == NetMode.train : self.data_loader_iter = iter(self.data_loader_train) + #self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : self.data_loader_iter 
= iter(self.data_loader_test) + #self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -187,7 +189,7 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non cf.batch_size, pre_batch, cf.n_size, cf.num_samples_per_epoch, with_shuffle = (cf.BERT_strategy != 'global_forecast'), - with_source_idxs = True ) + with_source_idxs = True, compute_weights = (cf.losses.count('weighted_mse') > 0) ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) @@ -195,7 +197,7 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non cf.batch_size_validation, pre_batch, cf.n_size, cf.num_samples_validate, with_shuffle = (cf.BERT_strategy != 'global_forecast'), - with_source_idxs = True ) + with_source_idxs = True, compute_weights = (cf.losses.count('weighted_mse') > 0) ) self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params, sampler = None) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 73dccda..9a9b5ea 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -18,7 +18,7 @@ import os import sys import traceback - +import pdb import wandb from atmorep.core.trainer import Trainer_BERT @@ -106,22 +106,29 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], + cf.fields = [ [ 'temperature', [ 1, 512, [ ], 0 ], [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] cf.fields_prediction = [ [cf.fields[0][0], 1.] ] - # cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ], + # cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] + + # cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] - # cf.fields = [ [ 'velocity_z', [ 1, 1024, [ ], 0 ], + + # cf.fields = [ [ 'velocity_v', [ 1, 1024, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ], + # cf.fields = [ [ 'velocity_z', [ 1, 512, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + + # cf.fields = [ [ 'specific_humidity', [ 1, 1024, [ ], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] # [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ] cf.fields_targets = [] @@ -135,7 +142,7 @@ def train() : cf.torch_seed = torch.initial_seed() # training params cf.batch_size_validation = 1 #64 - cf.batch_size = 96 + cf.batch_size = 96 cf.num_epochs = 128 cf.num_samples_per_epoch = 4096*12 cf.num_samples_validate = 128*12 @@ -211,13 +218,13 @@ def train() : # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' # # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' - # cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/' + cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/' # # # in steps x lat_degrees x lon_degrees - # cf.n_size = [36, 0.25*9*6, 0.25*9*12] + cf.n_size = [36, 0.25*9*6, 0.25*9*12] # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' - cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' - cf.n_size = [36, 1*9*6, 1.*9*12] + #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' + #cf.n_size = [36, 1*9*6, 1.*9*12] if cf.with_wandb and 0 == cf.par_rank : cf.write_json( wandb) diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index df8a7a9..fce0202 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -32,7 +32,7 @@ class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, - num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False, + num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False, compute_weights = False, fields_targets = None, pre_batch_targets = None ) : ''' Data set for single dynamic field at an arbitrary number of vertical levels @@ -46,6 +46,7 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, self.n_size = n_size self.num_samples = num_samples self.with_source_idxs = with_source_idxs + self.compute_weights = compute_weights self.with_shuffle = with_shuffle self.pre_batch = pre_batch @@ -185,11 +186,11 @@ def __iter__(self): source_data, tok_info = [], [] # extract data, normalize and tokenize - cdata = data_t[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] - + cdata = data_t[... , lat_ran[:,np.newaxis], lon_ran] + normalizer = self.normalizers[ifield][ilevel] if corr_type != 'global': - normalizer = normalizer[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] + normalizer = normalizer[ ... 
, lat_ran[:,np.newaxis], lon_ran] cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) source_data = tokenize( torch.from_numpy( cdata), tok_size ) @@ -217,29 +218,29 @@ def __iter__(self): tmidx_list = sources[-1] weights_idx_list = [] + if self.compute_weights: + for ifield, field_info in enumerate(self.fields): + weights = [] + for ilevel, vl in enumerate(field_info[2]): + for ibatch in range(self.batch_size): + + lats_idx = source_idxs[ibatch][1] + lons_idx = source_idxs[ibatch][2] - for ifield, field_info in enumerate(self.fields): - weights = [] - for ilevel, vl in enumerate(field_info[2]): - for ibatch in range(self.batch_size): - - lats_idx = source_idxs[ibatch][1] - lons_idx = source_idxs[ibatch][2] - - idx_base = tmidx_list[ifield][ilevel][ibatch] - idx_loc = idx_base - np.prod(num_tokens) * ibatch - - grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 - grid = torch.from_numpy( np.array( np.broadcast_to( grid, - shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) + idx_base = tmidx_list[ifield][ilevel][ibatch] + idx_loc = idx_base - np.prod(num_tokens) * ibatch + + grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + grid = torch.from_numpy( np.array( np.broadcast_to( grid, + shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) - grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2) + grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2) - lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) - weights.append([get_weights(la) for la in lats_mskd_b]) + weights.append([get_weights(la) for la in lats_mskd_b]) - weights_idx_list.append(weights) + weights_idx_list.append(weights) sources = (*sources, weights_idx_list) # TODO: implement (only required when prediction target comes from different data stream) From df7b23c0a8d37045ceea60ffbb57e5965efe01bb Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 7 Aug 2024 17:40:34 +0200 Subject: [PATCH 64/66] fix local norm --- atmorep/core/train.py | 2 +- atmorep/datasets/multifield_data_sampler.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 9a9b5ea..5574ec1 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -106,7 +106,7 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'temperature', [ 1, 512, [ ], 0 ], + cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], [ 96, 105, 114, 123, 137 ], [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] cf.fields_prediction = [ [cf.fields[0][0], 1.] ] diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index fce0202..5e2c6f8 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -186,11 +186,11 @@ def __iter__(self): source_data, tok_info = [], [] # extract data, normalize and tokenize - cdata = data_t[... , lat_ran[:,np.newaxis], lon_ran] - + cdata = data_t[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] + normalizer = self.normalizers[ifield][ilevel] if corr_type != 'global': - normalizer = normalizer[ ... , lat_ran[:,np.newaxis], lon_ran] + normalizer = normalizer[ ... 
, lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) source_data = tokenize( torch.from_numpy( cdata), tok_size ) From a73be45852b2273c114e8b320c44d5e8a38a478d Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 8 Aug 2024 15:45:21 +0200 Subject: [PATCH 65/66] fix evaluate.py for years_val --- atmorep/core/evaluate.py | 2 +- atmorep/core/evaluator.py | 15 +++++++-------- atmorep/core/train.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 90bb3a1..ea25dba 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -41,7 +41,7 @@ # global_forecast_range # options can be used to over-write parameters in config; some modes also have specific options, # e.g. global_forecast where a start date can be specified - + #Add 'attention' : True to options to store the attention maps. NB. supported only for single field runs. # BERT masked token model diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index d5e778b..efac76c 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -55,11 +55,6 @@ def parse_args( cf, args) : @staticmethod def run( cf, model_id, model_epoch, devices) : - if not hasattr(cf, 'batch_size'): - cf.batch_size = cf.batch_size_max - if not hasattr(cf, 'batch_size_validation'): - cf.batch_size_validation = cf.batch_size_max - cf.with_mixed_precision = True # set/over-write options as desired @@ -82,7 +77,7 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : else : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) devices = init_torch( num_accs_per_task) - #devices = ['cuda'] + devices = ['cuda:1'] par_rank, par_size = setup_ddp( with_ddp) cf = Config().load_json( model_id) @@ -112,6 +107,12 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : cf.with_mixed_precision = False if not hasattr(cf, 'with_pytest'): cf.with_pytest = False + if not hasattr(cf, 'batch_size'): + cf.batch_size = cf.batch_size_max + if not hasattr(cf, 'batch_size_validation'): + cf.batch_size_validation = cf.batch_size_max + if not hasattr(cf, 'years_val'): + cf.years_val = cf.years_test func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) @@ -159,8 +160,6 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : cf.batch_size = 196 #14 if not hasattr(cf, 'batch_size_validation'): cf.batch_size_validation = 1 #64 - if not hasattr(cf, 'batch_size_delta'): - cf.batch_size_delta = 8 if not hasattr(cf, 'num_samples_validate'): cf.num_samples_validate = 196 #if not hasattr(cf,'with_mixed_precision'): diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 5574ec1..c92c3e0 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -106,16 +106,16 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], - [ 96, 105, 114, 123, 137 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] - cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + # cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] + # cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] - # cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] + cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] - # cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + cf.fields_prediction = [ [cf.fields[0][0], 1.] ] # cf.fields = [ [ 'velocity_v', [ 1, 1024, [ ], 0 ], From ba60b71a78141fdb1082db2de431c0b06981fa41 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Mon, 12 Aug 2024 11:59:24 +0200 Subject: [PATCH 66/66] answer review comments --- atmorep/core/atmorep_model.py | 12 ++++++------ atmorep/core/evaluate.py | 2 +- atmorep/core/evaluator.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 2d22638..b26252c 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -137,11 +137,9 @@ def mode( self, mode : NetMode) : if mode == NetMode.train : self.data_loader_iter = iter(self.data_loader_train) - #self.data_loader_iter = iter(self.dataset_train) self.net.train() elif mode == NetMode.test : self.data_loader_iter = iter(self.data_loader_test) - #self.data_loader_iter = iter(self.dataset_test) self.net.eval() else : assert False @@ -188,16 +186,18 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, cf.batch_size, pre_batch, cf.n_size, cf.num_samples_per_epoch, - with_shuffle = (cf.BERT_strategy != 'global_forecast'), - with_source_idxs = True, compute_weights = (cf.losses.count('weighted_mse') > 0) ) + with_shuffle = (cf.BERT_strategy != 'global_forecast'), + with_source_idxs = True, + compute_weights = (cf.losses.count('weighted_mse') > 0) ) self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, sampler = None) self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val, cf.batch_size_validation, pre_batch, cf.n_size, cf.num_samples_validate, - with_shuffle = (cf.BERT_strategy != 'global_forecast'), - with_source_idxs = True, compute_weights = (cf.losses.count('weighted_mse') > 0) ) + with_shuffle = (cf.BERT_strategy != 'global_forecast'), + with_source_idxs = True, + compute_weights = (cf.losses.count('weighted_mse') > 0) ) self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params, sampler = None) diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index ea25dba..90bb3a1 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -41,7 +41,7 @@ # global_forecast_range # options can be used to over-write parameters in config; some modes also have specific options, # e.g. global_forecast where a start date can be specified - + #Add 'attention' : True to options to store the attention maps. NB. supported only for single field runs. # BERT masked token model diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index efac76c..b5fab37 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -77,7 +77,7 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : else : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) devices = init_torch( num_accs_per_task) - devices = ['cuda:1'] + #devices = ['cuda:1'] par_rank, par_size = setup_ddp( with_ddp) cf = Config().load_json( model_id)