diff --git a/atmorep/config/config.py b/atmorep/config/config.py index 53755f2..6c0fe01 100644 --- a/atmorep/config/config.py +++ b/atmorep/config/config.py @@ -3,12 +3,8 @@ fpath = os.path.dirname(os.path.realpath(__file__)) -year_base = 1979 -year_last = 2022 - path_models = Path( fpath, '../../models/') -path_results = Path( fpath, '../../results/') -path_data = Path( fpath, '../../data/') +path_results = Path( fpath, '../../results') path_plots = Path( fpath, '../results/plots/') grib_index = { 'vorticity' : 'vo', 'divergence' : 'd', 'geopotential' : 'z', diff --git a/atmorep/core/atmorep_model.py b/atmorep/core/atmorep_model.py index 3c4de6a..b26252c 100644 --- a/atmorep/core/atmorep_model.py +++ b/atmorep/core/atmorep_model.py @@ -35,6 +35,7 @@ from atmorep.transformer.transformer_decoder import TransformerDecoder from atmorep.transformer.tail_ensemble import TailEnsemble + #################################################################################################### class AtmoRepData( torch.nn.Module) : @@ -53,37 +54,6 @@ def __init__( self, net) : self.rng_seed = net.cf.rng_seed if not self.rng_seed : self.rng_seed = int(torch.randint( 100000000, (1,))) - - ################################################### - def load_data( self, mode : NetMode, batch_size = -1, num_loader_workers = -1) : - '''Load data''' - - cf = self.net.cf - - if batch_size < 0 : - batch_size = cf.batch_size_max - if num_loader_workers < 0 : - num_loader_workers = cf.num_loader_workers - - if mode == NetMode.train : - self.data_loader_train = self._load_data( self.dataset_train, batch_size, num_loader_workers) - elif mode == NetMode.test : - batch_size = cf.batch_size_test - self.data_loader_test = self._load_data( self.dataset_test, batch_size, num_loader_workers) - else : - assert False - - ################################################### - def _load_data( self, dataset, batch_size, num_loader_workers) : - '''Private implementation for load''' - - dataset.load_data( batch_size) - - loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False, - 'num_workers': num_loader_workers, 'pin_memory': True} - data_loader = torch.utils.data.DataLoader( dataset, **loader_params, sampler = None) - - return data_loader ################################################### def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_workers = -1) : @@ -94,7 +64,7 @@ def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_worke dataset = self.dataset_train if mode == NetMode.train else self.dataset_test dataset.set_data( times_pos, batch_size) - + self._set_data( dataset, mode, batch_size, num_loader_workers) ################################################### @@ -103,7 +73,6 @@ def set_global( self, mode : NetMode, times, batch_size = -1, num_loader_workers cf = self.net.cf if batch_size < 0 : batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test - dataset = self.dataset_train if mode == NetMode.train else self.dataset_test dataset.set_global( times, batch_size, cf.token_overlap) @@ -143,7 +112,7 @@ def _set_data( self, dataset, mode : NetMode, batch_size = -1, loader_workers = assert False ################################################### - def normalizer( self, field, vl_idx) : + def normalizer( self, field, vl_idx, lats_idx, lons_idx ) : if isinstance( field, str) : for fidx, field_info in enumerate(self.cf.fields) : @@ -153,12 +122,15 @@ def normalizer( self, field, vl_idx) : normalizer = self.dataset_train.datasets[fidx].normalizer 
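# --- Illustrative sketch (not part of the patch) -------------------------------------------
# The patch widens the normalizer interface to normalizer( field, vl_idx, lats_idx, lons_idx)
# and returns the statistics together with year_base. A minimal sketch of the idea, assuming
# a per-grid-point statistics layout of (months, mean/std, lat, lon); the helper names
# slice_normalizer/denormalize_local are hypothetical and only illustrate the np.take-based
# slicing used in the integer-field branch that follows.
import numpy as np

def slice_normalizer( stats, lats_idx, lons_idx) :
  '''Restrict global per-grid-point statistics to the sampled lat/lon window.'''
  if stats.ndim > 2 :   # per-grid-point ("local") statistics
    stats = np.take( np.take( stats, lats_idx, -2), lons_idx, -1)
  return stats          # global statistics pass through unchanged

def denormalize_local( data, stats, month_idx) :
  '''Undo x -> (x - mean) / std for one month of per-grid-point statistics.'''
  mean, std = stats[month_idx, 0], stats[month_idx, 1]
  return data * std + mean

if __name__ == '__main__' :
  stats = np.random.rand( 12, 2, 721, 1440)          # months x (mean, std) x lat x lon
  window = slice_normalizer( stats, np.arange( 54), np.arange( 108))
  field = np.random.rand( 12, 54, 108)               # time x lat x lon sample
  print( denormalize_local( field, window, month_idx=0).shape)
# --------------------------------------------------------------------------------------------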
elif isinstance( field, int) : - normalizer = self.dataset_train.datasets[field][vl_idx].normalizer - + normalizer = self.dataset_train.normalizers[field][vl_idx] + if len(normalizer.shape) > 2: + normalizer = np.take( np.take( normalizer, lats_idx, -2), lons_idx, -1) else : assert False, 'invalid argument type (has to be index to cf.fields or field name)' + + year_base = self.dataset_train.year_base - return normalizer + return normalizer, year_base ################################################### def mode( self, mode : NetMode) : @@ -193,8 +165,8 @@ def forward( self, xin) : return pred ################################################### - def get_attention( self, xin): #, field_idx) : - attn = self.net.get_attention( xin) #, field_idx) + def get_attention( self, xin) : + attn = self.net.get_attention( xin) return attn ################################################### @@ -208,40 +180,26 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non self.pre_batch_targets = pre_batch_targets cf = self.net.cf - self.dataset_train = MultifieldDataSampler( cf.data_dir, cf.years_train, cf.fields, - batch_size = cf.batch_size_start, - num_t_samples = cf.num_t_samples, - num_patches_per_t = cf.num_patches_per_t_train, - num_load = cf.num_files_train, - pre_batch = self.pre_batch, - rng_seed = self.rng_seed, - file_shape = cf.file_shape, - smoothing = cf.data_smoothing, - level_type = cf.level_type, - file_format = cf.file_format, - month = cf.month, - time_sampling = cf.time_sampling, - geo_range = cf.geo_range_sampling, - fields_targets = cf.fields_targets, - pre_batch_targets = self.pre_batch_targets ) - - self.dataset_test = MultifieldDataSampler( cf.data_dir, cf.years_test, cf.fields, - batch_size = cf.batch_size_test, - num_t_samples = cf.num_t_samples, - num_patches_per_t = cf.num_patches_per_t_test, - num_load = cf.num_files_test, - pre_batch = self.pre_batch, - rng_seed = self.rng_seed, - file_shape = cf.file_shape, - smoothing = cf.data_smoothing, - level_type = cf.level_type, - file_format = cf.file_format, - month = cf.month, - time_sampling = cf.time_sampling, - geo_range = cf.geo_range_sampling, - lat_sampling_weighted = cf.lat_sampling_weighted, - fields_targets = cf.fields_targets, - pre_batch_targets = self.pre_batch_targets ) + loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False, + 'num_workers': cf.num_loader_workers, 'pin_memory': True} + + self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train, + cf.batch_size, + pre_batch, cf.n_size, cf.num_samples_per_epoch, + with_shuffle = (cf.BERT_strategy != 'global_forecast'), + with_source_idxs = True, + compute_weights = (cf.losses.count('weighted_mse') > 0) ) + self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params, + sampler = None) + + self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val, + cf.batch_size_validation, + pre_batch, cf.n_size, cf.num_samples_validate, + with_shuffle = (cf.BERT_strategy != 'global_forecast'), + with_source_idxs = True, + compute_weights = (cf.losses.count('weighted_mse') > 0) ) + self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params, + sampler = None) return self @@ -261,7 +219,6 @@ def create( self, devices, load_pretrained=True) : cf = self.cf self.devices = devices - size_token_info = 6 self.fields_coupling_idx = [] self.fields_index = {} @@ -294,17 +251,9 @@ def create( self, devices, load_pretrained=True) : self.embeds = 
torch.nn.ModuleList() self.encoders = torch.nn.ModuleList() - self.masks = torch.nn.ParameterList() for field_idx, field_info in enumerate(cf.fields) : - # learnabl class token - if cf.learnable_mask : - mask = torch.nn.Parameter( 0.1 * torch.randn( np.prod( field_info[4]), requires_grad=True)) - self.masks.append( mask.to(devices[0])) - else : - self.masks.append( None) - # encoder self.encoders.append( TransformerEncoder( cf, field_idx, True).create()) # load pre-trained model if specified @@ -356,11 +305,10 @@ def create( self, devices, load_pretrained=True) : device = self.devices[0] if len(field_info[1]) > 3 : assert field_info[1][3] < 4, 'Only single node model parallelism supported' + print(devices, field_info[1][3]) assert field_info[1][3] < len(devices), 'Per field device id larger than max devices' device = self.devices[ field_info[1][3] ] # set device - if self.masks[field_idx] != None : - self.masks[field_idx].to(device) self.embeds[field_idx].to(device) self.encoders[field_idx].to(device) @@ -418,6 +366,68 @@ def load_block( self, field_info, block_name, block ) : print( 'Loaded {} for {} from id = {} (ignoring/missing {} elements).'.format( block_name, field_info[0], field_info[1][4][0], len(mkeys) ) ) + ################################################### + def translate_weights(self, mloaded, mkeys, ukeys): + ''' + Function used for backward compatibility + ''' + cf = self.cf + + #encoder: + for layer in range(cf.encoder_num_layers) : + + #shape([16, 3, 128, 2048]) + mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]]) + mloaded[f'encoders.0.heads.{layer}.proj_heads.weight'] = mw + + for head in range(cf.encoder_num_heads): + del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight'] + + #cross attention + if f'encoders.0.heads.{layer}.heads_other.0.proj_qs.weight' in ukeys: + mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]]) + + for i in range(cf.encoder_num_heads): + del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight'] + del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight'] + + else: + dim_mw = self.encoders[0].heads[0].proj_heads_other[0].weight.shape + mw = torch.tensor(np.zeros(dim_mw)) + + mloaded[f'encoders.0.heads.{layer}.proj_heads_other.0.weight'] = mw + + #decoder + for iblock in range(0, 19, 2) : + mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads.{head}.proj_{k}.weight'] for head in range(8) for k in ["qs", "ks", "vs"]]) + mloaded[f'decoders.0.blocks.{iblock}.proj_heads.weight'] = mw + + qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_qs.weight'] for head in range(8)] + mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_{k}.weight'] for head in range(8) for k in ["ks", "vs"]]) + + mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_q.weight'] = torch.cat([*qs]) + mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_kv.weight'] = mw + + #self.num_samples_validate + decoder_dim = self.decoders[0].blocks[iblock].ln_q.weight.shape #128 + mloaded[f'decoders.0.blocks.{iblock}.ln_q.weight'] = torch.tensor(np.ones(decoder_dim)) + 
mloaded[f'decoders.0.blocks.{iblock}.ln_k.weight'] = torch.tensor(np.ones(decoder_dim)) + mloaded[f'decoders.0.blocks.{iblock}.ln_q.bias'] = torch.tensor(np.ones(decoder_dim)) + mloaded[f'decoders.0.blocks.{iblock}.ln_k.bias'] = torch.tensor(np.ones(decoder_dim)) + + for i in range(8): + del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight'] + del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight'] + + return mloaded + ################################################### @staticmethod def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : @@ -429,7 +439,10 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : model = AtmoRep( cf).create( devices, load_pretrained=False) mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) ) - mkeys, _ = model.load_state_dict( mloaded, False ) + mkeys, ukeys = model.load_state_dict( mloaded, False ) + if (f'encoders.0.heads.0.proj_heads.weight') in mkeys: + mloaded = model.translate_weights(mloaded, mkeys, ukeys) + mkeys, ukeys = model.load_state_dict( mloaded, False ) if len(mkeys) > 0 : print( f'Loaded AtmoRep: ignoring {len(mkeys)} elements: {mkeys}') @@ -437,7 +450,7 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) : # TODO: remove, only for backward if model.embeds_token_info[0].weight.abs().max() == 0. : model.embeds_token_info = torch.nn.ModuleList() - + return model ################################################### @@ -474,8 +487,9 @@ def forward( self, xin) : # embedding cf = self.cf + fields_embed = self.get_fields_embed(xin) - + # attention maps (if requested) atts = [ [] for _ in cf.fields ] @@ -528,16 +542,14 @@ def forward_encoder_block( self, iblock, fields_embed) : return fields_embed_cur, atts ################################################### - def get_fields_embed( self, xin ) : - cf = self.cf if 0 == len(self.embeds_token_info) : # TODO: only for backward compatibility, remove emb_net_ti = self.embed_token_info - return [prepare_token( field_data, emb_net, emb_net_ti, cf.with_cls ) + return [prepare_token( field_data, emb_net, emb_net_ti ) for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))] else : embs_net_ti = self.embeds_token_info - return [prepare_token( field_data, emb_net, embs_net_ti[fidx], cf.with_cls ) + return [prepare_token( field_data, emb_net, embs_net_ti[fidx] ) for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))] ################################################### diff --git a/atmorep/core/evaluate.py b/atmorep/core/evaluate.py index 3033040..90bb3a1 100644 --- a/atmorep/core/evaluate.py +++ b/atmorep/core/evaluate.py @@ -15,43 +15,60 @@ #################################################################################################### from atmorep.core.evaluator import Evaluator +import time if __name__ == '__main__': # models for individual fields - model_id = '4nvwbetz' # vorticity - # model_id = 'oxpycr7w' # divergence - # model_id = '1565pb1f' # specific_humidity - # model_id = '3kdutwqb' # total precip - # model_id = 'dys79lgw' # velocity_u - # model_id = '22j6gysw' # velocity_v + #model_id = '4nvwbetz' # vorticity + #model_id = 'oxpycr7w' # divergence + #model_id = 
'1565pb1f' # specific_humidity + #model_id = '3kdutwqb' # total precip + model_id = 'dys79lgw' # velocity_u + #model_id = '22j6gysw' # velocity_v # model_id = '15oisw8d' # velocity_z - # model_id = '3qou60es' # temperature (also 2147fkco) - # model_id = '2147fkco' # temperature (also 2147fkco) - + #model_id = '3qou60es' # temperature (also 2147fkco) + #model_id = '2147fkco' # temperature (also 2147fkco) + # multi-field configurations with either velocity or voritcity+divergence - # model_id = '1jh2qvrx' # multiformer, velocity + #model_id = '1jh2qvrx' # multiformer, velocity # model_id = 'wqqy94oa' # multiformer, vorticity - # model_id = '3cizyl1q' # 3 field config: u,v,T + #model_id = '3cizyl1q' # 3 field config: u,v,T # model_id = '1v4qk0qx' # pre-trained, 3h forecasting # model_id = '1m79039j' # pre-trained, 6h forecasting - + #model_id='34niv2nu' # supported modes: test, forecast, fixed_location, temporal_interpolation, global_forecast, # global_forecast_range # options can be used to over-write parameters in config; some modes also have specific options, # e.g. global_forecast where a start date can be specified + + #Add 'attention' : True to options to store the attention maps. NB. supported only for single field runs. # BERT masked token model - # mode, options = 'BERT', {'years_test' : [2021], 'fields[0][2]' : [123, 137], 'attention' : False} - + mode, options = 'BERT', {'years_test' : [2021], 'num_samples_validate' : 128, 'with_pytest' : True } + # BERT forecast mode - # mode, options = 'forecast', {'forecast_num_tokens' : 1} #, 'fields[0][2]' : [123, 137], 'attention' : False } - + #mode, options = 'forecast', {'forecast_num_tokens' : 2, 'num_samples_validate' : 128, 'with_pytest' : True } + + #temporal interpolation + #idx_time_mask: list of relative time positions of the masked tokens within the cube wrt num_tokens[0] + #mode, options = 'temporal_interpolation', {'idx_time_mask': [5,6,7], 'num_samples_validate' : 128, 'with_pytest' : True} + # BERT forecast with patching to obtain global forecast - mode, options = 'global_forecast', { 'fields[0][2]' : [123, 137], - 'dates' : [[2021, 2, 10, 12]], - 'token_overlap' : [0, 0], - 'forecast_num_tokens' : 1, - 'attention' : False } +# mode, options = 'global_forecast', { +# 'dates' : [[2021, 1, 10, 18]], +# # # 'dates' : [ #[2021, 1, 10, 18] +# # # [2021, 1, 10, 12] , [2021, 1, 11, 0], [2021, 1, 11, 12], [2021, 1, 12, 0], [2021, 1, 12, 12], [2021, 1, 13, 0], +# # # [2021, 4, 10, 12], [2021, 4, 11, 0], [2021, 4, 11, 12], [2021, 4, 12, 0], [2021, 4, 12, 12], [2021, 4, 13, 0], +# # # [2021, 7, 10, 12], [2021, 7, 11, 0], [2021, 7, 11, 12], [2021, 7, 12, 0], [2021, 7, 12, 12], [2021, 7, 13, 0], +# # # [2021, 10, 10, 12], [2021, 10, 11, 0], [2021, 10, 11, 12], [2021, 10, 12, 0], [2021, 10, 12, 12], [2021, 10, 13, 0] +# # # ], +# 'token_overlap' : [0, 0], +# 'forecast_num_tokens' : 2, +# 'with_pytest' : True } - Evaluator.evaluate( mode, model_id, options) + file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr' + + now = time.time() + Evaluator.evaluate( mode, model_id, file_path, options) + print("time", time.time() - now) \ No newline at end of file diff --git a/atmorep/core/evaluator.py b/atmorep/core/evaluator.py index a1c0b75..541af5a 100644 --- a/atmorep/core/evaluator.py +++ b/atmorep/core/evaluator.py @@ -17,7 +17,7 @@ import numpy as np import os import code - +import pytest import datetime import wandb @@ -54,10 +54,12 @@ def parse_args( cf, args) : ############################################## @staticmethod def 
run( cf, model_id, model_epoch, devices) : + + cf.with_mixed_precision = True # set/over-write options as desired evaluator = Evaluator.load( cf, model_id, model_epoch, devices) - evaluator.model.load_data( NetMode.test) + if 0 == cf.par_rank : cf.print() cf.write_json( wandb) @@ -65,7 +67,7 @@ def run( cf, model_id, model_epoch, devices) : ############################################## @staticmethod - def evaluate( mode, model_id, args = {}, model_epoch=-2) : + def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) : # SLURM_TASKS_PER_NODE is controlled by #SBATCH --ntasks-per-node=1; should be 1 for multiformer with_ddp = True @@ -75,25 +77,50 @@ def evaluate( mode, model_id, args = {}, model_epoch=-2) : else : num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] )) devices = init_torch( num_accs_per_task) - par_rank, par_size = setup_ddp( with_ddp) + #devices = ['cuda:1'] + par_rank, par_size = setup_ddp( with_ddp) cf = Config().load_json( model_id) + cf.file_path = file_path cf.with_wandb = True cf.with_ddp = with_ddp cf.par_rank = par_rank cf.par_size = par_size + cf.losses = cf.losses # overwrite old config - cf.data_dir = str(config.path_data) cf.attention = False setup_wandb( cf.with_wandb, cf, par_rank, '', mode='offline') if 0 == cf.par_rank : print( 'Running Evaluate.evaluate with mode =', mode) - cf.num_loader_workers = cf.loader_num_workers - cf.data_dir = config.path_data + # if not hasattr( cf, 'num_loader_workers'): + cf.num_loader_workers = 12 #cf.loader_num_workers + cf.rng_seed = None + + #backward compatibility + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + #cf.n_size = [36, 0.25*27*2, 0.25*27*4] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = False + if not hasattr(cf, 'with_pytest'): + cf.with_pytest = False + if not hasattr(cf, 'batch_size'): + cf.batch_size = cf.batch_size_max + if not hasattr(cf, 'batch_size_validation'): + cf.batch_size_validation = cf.batch_size_max + if not hasattr(cf, 'years_val'): + cf.years_val = cf.years_test func = getattr( Evaluator, mode) func( cf, model_id, model_epoch, devices, args) + + if cf.with_pytest: + fields = [field[0] for field in cf.fields_prediction] + for field in fields: + pytest.main(["-x", "./atmorep/tests/validation_test.py", "--field", field, "--model_id", cf.wandb_id, "--strategy", cf.BERT_strategy]) ############################################## @staticmethod @@ -102,9 +129,9 @@ def BERT( cf, model_id, model_epoch, devices, args = {}) : cf.lat_sampling_weighted = False cf.BERT_strategy = 'BERT' cf.log_test_num_ranks = 4 - + cf.num_samples_validate = 10 #28 #1472 Evaluator.parse_args( cf, args) - + utils.check_num_samples(cf.num_samples_validate, cf.batch_size) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## @@ -115,24 +142,32 @@ def forecast( cf, model_id, model_epoch, devices, args = {}) : cf.BERT_strategy = 'forecast' cf.log_test_num_ranks = 4 cf.forecast_num_tokens = 1 # will be overwritten when user specified - + cf.num_samples_validate = 128 #128 Evaluator.parse_args( cf, args) - + utils.check_num_samples(cf.num_samples_validate, cf.batch_size) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## @staticmethod def global_forecast( cf, model_id, model_epoch, devices, args = {}) : - - cf.BERT_strategy = 'forecast' + + cf.BERT_strategy = 
'global_forecast' cf.batch_size_test = 24 - cf.num_loader_workers = 1 + cf.num_loader_workers = 12 #1 cf.log_test_num_ranks = 1 + + if not hasattr(cf, 'batch_size'): + cf.batch_size = 196 #14 + if not hasattr(cf, 'batch_size_validation'): + cf.batch_size_validation = 1 #64 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 196 + #if not hasattr(cf,'with_mixed_precision'): + cf.with_mixed_precision = True Evaluator.parse_args( cf, args) dates = args['dates'] - evaluator = Evaluator.load( cf, model_id, model_epoch, devices) evaluator.model.set_global( NetMode.test, np.array( dates)) if 0 == cf.par_rank : @@ -145,13 +180,16 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) : def global_forecast_range( cf, model_id, model_epoch, devices, args = {}) : cf.forecast_num_tokens = 2 - cf.BERT_strategy = 'forecast' - cf.token_overlap = [2, 6] + cf.BERT_strategy = 'global_forecast' + cf.token_overlap = [0, 0] cf.batch_size_test = 24 cf.num_loader_workers = 1 cf.log_test_num_ranks = 1 - + cf.batch_size_start = 14 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 196 + Evaluator.parse_args( cf, args) if 0 == cf.par_rank : @@ -179,7 +217,9 @@ def temporal_interpolation( cf, model_id, model_epoch, devices, args = {}) : # set/over-write options cf.BERT_strategy = 'temporal_interpolation' cf.log_test_num_ranks = 4 - + cf.num_samples_validate = 128 + Evaluator.parse_args( cf, args) + utils.check_num_samples(cf.num_samples_validate, cf.batch_size) Evaluator.run( cf, model_id, model_epoch, devices) ############################################## diff --git a/atmorep/core/train.py b/atmorep/core/train.py index 51f290e..c92c3e0 100644 --- a/atmorep/core/train.py +++ b/atmorep/core/train.py @@ -15,18 +15,17 @@ #################################################################################################### import torch -import numpy as np import os - +import sys +import traceback +import pdb import wandb -import atmorep.config.config as config from atmorep.core.trainer import Trainer_BERT from atmorep.utils.utils import Config from atmorep.utils.utils import setup_ddp from atmorep.utils.utils import setup_wandb from atmorep.utils.utils import init_torch -import atmorep.utils.utils as utils #################################################################################################### @@ -47,7 +46,17 @@ def train_continue( wandb_id, epoch, Trainer, epoch_continue = -1) : # name has changed but ensure backward compatibility if hasattr( cf, 'loader_num_workers') : cf.num_loader_workers = cf.loader_num_workers - + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 'num_samples_validate'): + cf.num_samples_validate = 128 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = True + if not hasattr(cf, 'years_val'): + cf.years_val = cf.years_test + # any parameter in cf can be overwritten when training is continued, e.g. 
we can increase the # masking rate # cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ], @@ -87,12 +96,6 @@ def train() : cf.num_accs_per_task = num_accs_per_task # number of GPUs / accelerators per task cf.par_rank = par_rank cf.par_size = par_size - cf.back_passes_per_step = 4 - # general - cf.comment = '' - cf.file_format = 'grib' - cf.data_dir = str(config.path_data) - cf.level_type = 'ml' # format: list of fields where for each field the list is # [ name , @@ -103,83 +106,66 @@ def train() : # [ total masking rate, rate masking, rate noising, rate for multi-res distortion] # ] - cf.fields = [ [ 'vorticity', [ 1, 2048, [ ], 0 ], - [ 123 ], - [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ], + # [ 96, 105, 114, 123, 137 ], + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] + # cf.fields_prediction = [ [cf.fields[0][0], 1.] ] - # cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0], - # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ] ] - - # cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ], - # [ 96, 105, 114, 123, 137 ], - # [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ], + [ 96, 105, 114, 123, 137 ], + [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ] - # cf.fields = [ [ 'velocity_z', [ 1, 1024, [ ], 0 ], + cf.fields_prediction = [ [cf.fields[0][0], 1.] ] + + + # cf.fields = [ [ 'velocity_v', [ 1, 1024, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] + # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ], + # cf.fields = [ [ 'velocity_z', [ 1, 512, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields = [ [ 'temperature', [ 1, 1536, [ ], 0 ], + # cf.fields = [ [ 'specific_humidity', [ 1, 1024, [ ], 0 ], # [ 96, 105, 114, 123, 137 ], - # [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ] - - # cf.fields = [ [ 'total_precip', [ 1, 2048, [ ], 0 ], - # [ 0 ], - # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ] - - # cf.fields = [ [ 'geopotential', [ 1, 1024, [], 0 ], - # [ 0 ], - # [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ] - # cf.fields_prediction = [ ['geopotential', 1.] ] - - - cf.fields_prediction = [ [cf.fields[0][0], 1.] 
] + # [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ] + # [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ] cf.fields_targets = [] - cf.years_train = [2021] # list( range( 1980, 2018)) - cf.years_test = [2021] #[2018] + + cf.years_train = list( range( 2010, 2021)) + cf.years_val = [2021] #[2018] cf.month = None cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]] cf.time_sampling = 1 # sampling rate for time steps - # file and data parameter parameter - cf.data_smoothing = 0 - cf.file_shape = (-1, 721, 1440) - cf.num_t_samples = 31*24 - cf.num_files_train = 5 - cf.num_files_test = 2 - cf.num_patches_per_t_train = 8 - cf.num_patches_per_t_test = 4 # random seeds cf.torch_seed = torch.initial_seed() # training params - cf.batch_size_test = 64 - cf.batch_size_start = 16 - cf.batch_size_max = 32 - cf.batch_size_delta = 8 + cf.batch_size_validation = 1 #64 + cf.batch_size = 96 cf.num_epochs = 128 + cf.num_samples_per_epoch = 4096*12 + cf.num_samples_validate = 128*12 cf.num_loader_workers = 8 + # additional infos cf.size_token_info = 8 cf.size_token_info_net = 16 cf.grad_checkpointing = True cf.with_cls = False # network config + cf.with_mixed_precision = True cf.with_layernorm = True cf.coupling_num_heads_per_field = 1 cf.dropout_rate = 0.05 - cf.learnable_mask = False - cf.with_qk_lnorm = True + cf.with_qk_lnorm = False # encoder - cf.encoder_num_layers = 10 + cf.encoder_num_layers = 6 cf.encoder_num_heads = 16 cf.encoder_num_mlp_layers = 2 cf.encoder_att_type = 'dense' # decoder - cf.decoder_num_layers = 10 + cf.decoder_num_layers = 6 cf.decoder_num_heads = 16 cf.decoder_num_mlp_layers = 2 cf.decoder_self_att = False @@ -190,33 +176,28 @@ def train() : cf.net_tail_num_nets = 16 cf.net_tail_num_layers = 0 # loss - # supported: see Trainer for supported losses - # cf.losses = ['mse', 'stats'] - cf.losses = ['mse_ensemble', 'stats'] - # cf.losses = ['mse'] - # cf.losses = ['stats'] - # cf.losses = ['crps'] + cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps, weighted_mse # training cf.optimizer_zero = False cf.lr_start = 5. 
* 10e-7 - cf.lr_max = 0.00005 - cf.lr_min = 0.00004 - cf.weight_decay = 0.05 + cf.lr_max = 0.00005*3 + cf.lr_min = 0.00004 #0.00002 + cf.weight_decay = 0.05 #0.1 cf.lr_decay_rate = 1.025 cf.lr_start_epochs = 3 - cf.lat_sampling_weighted = True # BERT - # strategies: 'BERT', 'forecast', 'temporal_interpolation', 'identity' - cf.BERT_strategy = 'BERT' - cf.BERT_window = False # sample sub-region + # strategies: 'BERT', 'forecast', 'temporal_interpolation' + cf.BERT_strategy = 'BERT' + cf.forecast_num_tokens = 2 # only needed / used for BERT_strategy 'forecast cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution + # debug / output cf.log_test_num_ranks = 0 cf.save_grads = False cf.profile = False - cf.test_initial = True + cf.test_initial = False cf.attention = False cf.rng_seed = None @@ -225,6 +206,26 @@ def train() : cf.with_wandb = True setup_wandb( cf.with_wandb, cf, par_rank, 'train', mode='offline') + # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk32.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk32.zarr' + # # # in steps x lat_degrees x lon_degrees + # cf.n_size = [36, 1*9*6, 1.*9*12] + + # # # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk16.zarr' + # # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk32.zarr' + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk32.zarr' + # # # + # # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr' + # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr' + # # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr' + cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/' + # # # in steps x lat_degrees x lon_degrees + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + + # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr' + #cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr' + #cf.n_size = [36, 1*9*6, 1.*9*12] + if cf.with_wandb and 0 == cf.par_rank : cf.write_json( wandb) cf.print() @@ -234,11 +235,18 @@ def train() : #################################################################################################### if __name__ == '__main__': + + try : - train() + train() + + # wandb_id, epoch, epoch_continue = '1jh2qvrx', 392, 392 + # Trainer = Trainer_BERT + # train_continue( wandb_id, epoch, Trainer, epoch_continue) + + except : + + extype, value, tb = sys.exc_info() + traceback.print_exc() + pdb.post_mortem(tb) -# wandb_id, epoch = '1jh2qvrx', -2 #'4nvwbetz', -2 #392 #'4nvwbetz', -2 -# epoch_continue = epoch -# -# Trainer = Trainer_BERT -# train_continue( wandb_id, epoch, Trainer, epoch_continue) diff --git a/atmorep/core/train_multi.py b/atmorep/core/train_multi.py index 3f057cd..1bb8fcf 100644 --- a/atmorep/core/train_multi.py +++ b/atmorep/core/train_multi.py @@ -45,6 +45,14 @@ def train_continue( model_id, model_epoch, Trainer, model_epoch_continue = -1) : cf.optimizer_zero = False if hasattr( cf, 'loader_num_workers') : cf.num_loader_workers = cf.loader_num_workers + if not hasattr( cf, 'n_size'): + cf.n_size = [36, 0.25*9*6, 0.25*9*12] + if not hasattr(cf, 'num_samples_per_epoch'): + cf.num_samples_per_epoch = 1024 + if not hasattr(cf, 
'num_samples_validate'): + cf.num_samples_validate = 128 + if not hasattr(cf, 'with_mixed_precision'): + cf.with_mixed_precision = True setup_wandb( cf.with_wandb, cf, par_rank, 'train-multi', mode='offline') @@ -78,7 +86,7 @@ def train_multi() : # general cf.comment = '' cf.file_format = 'grib' - cf.data_dir = str(config.path_data) + cf.file_path = str(config.path_data) cf.level_type = 'ml' cf.fields = [ [ 'vorticity', [ 1, 2048, ['divergence', 'temperature'], 0 ], @@ -194,7 +202,6 @@ def train_multi() : cf.lat_sampling_weighted = False # BERT cf.BERT_strategy = 'BERT' # 'BERT', 'forecast', 'identity', 'totalmask' - cf.BERT_window = False # sample sub-region cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields # (fields need to have same BERT params for this to have effect) cf.BERT_mr_max = 2 # maximum reduction rate for resolution diff --git a/atmorep/core/trainer.py b/atmorep/core/trainer.py index 61f2574..ad13392 100644 --- a/atmorep/core/trainer.py +++ b/atmorep/core/trainer.py @@ -17,17 +17,14 @@ import torch import torchinfo import numpy as np +import time import code -# code.interact(local=locals()) from pathlib import Path import os import datetime -from typing import TypeVar import functools -import pandas as pd - import wandb # import horovod.torch as hvd import torch.distributed as dist @@ -43,15 +40,9 @@ import atmorep.utils.token_infos_transformations as token_infos_transformations -import atmorep.utils.utils as utils -from atmorep.utils.utils import shape_to_str -from atmorep.utils.utils import relMSELoss -from atmorep.utils.utils import Gaussian -from atmorep.utils.utils import CRPS -from atmorep.utils.utils import NetMode -from atmorep.utils.utils import sgn_exp - +from atmorep.utils.utils import Gaussian, CRPS, kernel_crps, weighted_mse, NetMode, tokenize, detokenize from atmorep.datasets.data_writer import write_forecast, write_BERT, write_attention +from atmorep.datasets.normalizer import denormalize #################################################################################################### class Trainer_Base() : @@ -91,7 +82,6 @@ def __init__( self, cf, devices ) : ################################################### def create( self, load_embeds=True) : - net = AtmoRep( self.cf) self.model = AtmoRepData( net) @@ -100,13 +90,11 @@ def create( self, load_embeds=True) : # TODO: pass the properly to model / net self.model.net.encoder_to_decoder = self.encoder_to_decoder self.model.net.decoder_to_tail = self.decoder_to_tail - return self ################################################### @classmethod def load( Typename, cf, model_id, epoch, devices) : - trainer = Typename( cf, devices).create( load_embeds=False) trainer.model.net = trainer.model.net.load( model_id, devices, cf, epoch) @@ -116,7 +104,6 @@ def load( Typename, cf, model_id, epoch, devices) : str = 'Loaded model id = {}{}.'.format( model_id, f' at epoch = {epoch}' if epoch> -2 else '') print( str) - return trainer ################################################### @@ -160,16 +147,16 @@ def run( self, epoch = -1) : lr=cf.lr_start ) else : self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=cf.lr_start, - weight_decay=cf.weight_decay) + weight_decay=cf.weight_decay) + self.grad_scaler = torch.cuda.amp.GradScaler(enabled=cf.with_mixed_precision) + if 0 == cf.par_rank : # print( self.model.net) model_parameters = filter(lambda p: p.requires_grad, self.model_ddp.parameters()) num_params = sum([np.prod(p.size()) for p in model_parameters]) print( f'Number of 
trainable parameters: {num_params:,}') - - # test at the beginning as reference - self.model.load_data( NetMode.test, batch_size=cf.batch_size_test) + if cf.test_initial : cur_test_loss = self.validate( epoch, cf.BERT_strategy).cpu().numpy() test_loss = np.array( [cur_test_loss]) @@ -177,14 +164,11 @@ def run( self, epoch = -1) : # generic value based on data normalization test_loss = np.array( [1.0]) epoch += 1 - - batch_size = cf.batch_size_start - cf.batch_size_delta - + if cf.profile : lr = learn_rates[epoch] for g in self.optimizer.param_groups: g['lr'] = lr - self.model.load_data( NetMode.train, batch_size = cf.batch_size_max) self.profile() # training loop @@ -197,12 +181,9 @@ def run( self, epoch = -1) : for g in self.optimizer.param_groups: g['lr'] = lr - batch_size = min( cf.batch_size_max, batch_size + cf.batch_size_delta) - tstr = datetime.datetime.now().strftime("%H:%M:%S") - print( '{} : {} :: batch_size = {}, lr = {}'.format( epoch, tstr, batch_size, lr) ) + print( '{} : {} :: batch_size = {}, lr = {}'.format( epoch, tstr, cf.batch_size, lr) ) - self.model.load_data( NetMode.train, batch_size = batch_size) self.train( epoch) if cf.with_wandb and 0 == cf.par_rank : @@ -240,17 +221,24 @@ def train( self, epoch): grad_loss_total = [] ctr = 0 + self.optimizer.zero_grad() + time_start = time.time() + for batch_idx in range( model.len( NetMode.train)) : + batch_data = self.model.next() - - batch_data = self.prepare_batch( batch_data) - preds, _ = self.model_ddp( batch_data) - - loss, mse_loss, losses = self.loss( preds, batch_idx) + _, _, _, tmksd_list, weight_list = batch_data[0] + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=cf.with_mixed_precision): + batch_data = self.prepare_batch( batch_data) + preds, _ = self.model_ddp( batch_data) + #breakpoint() + loss, mse_loss, losses = self.loss( preds, batch_idx, tmksd_list, weight_list) + + self.grad_scaler.scale(loss).backward() + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() [loss_total[idx].append( losses[key]) for idx, key in enumerate(losses)] mse_loss_total.append( mse_loss.detach().cpu() ) @@ -259,7 +247,7 @@ def train( self, epoch): # logging - if int((batch_idx * cf.batch_size_max) / 4) > ctr : + if int((batch_idx * cf.batch_size) / 8) > ctr : # wandb logging if cf.with_wandb and (0 == cf.par_rank) : @@ -277,14 +265,16 @@ def train( self, epoch): wandb.log( loss_dict ) # console output - print('train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:1.5f} : {:1.5f} :: {:1.5f}'.format( - epoch, batch_idx, model.len( NetMode.train), - 100. * batch_idx/model.len(NetMode.train), - torch.mean( torch.tensor( grad_loss_total)), torch.mean(torch.tensor(mse_loss_total)), - torch.mean( preds[0][1]) ), flush=True) + samples_sec = cf.batch_size / (time.time() - time_start) + str = 'epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:1.5f} : {:1.5f} :: {:1.5f} ({:2.2f} s/sec)' + print( str.format( epoch, batch_idx, model.len( NetMode.train), + 100. 
* batch_idx/model.len(NetMode.train), + torch.mean( torch.tensor( grad_loss_total)), + torch.mean(torch.tensor(mse_loss_total)), + torch.mean( preds[0][1]), samples_sec ), flush=True) # save model (use -2 as epoch to indicate latest, stored without epoch specification) - # self.save( -2) + self.save( -2) # reset loss_total = [[] for i in range(len(cf.losses)) ] @@ -293,6 +283,7 @@ def train( self, epoch): std_dev_total = [[] for i in range(len(self.fields_prediction_idx)) ] ctr += 1 + time_start = time.time() # save gradients if cf.save_grads and cf.with_wandb and (0 == cf.par_rank) : @@ -345,14 +336,15 @@ def profile( self): batch_data = self.model.next() - batch_data = self.prepare_batch( batch_data) - preds, _ = self.model_ddp( batch_data) - - loss, mse_loss, losses = self.loss( preds, batch_idx) + with torch.autocast(device_type='cuda',dtype=torch.float16, enabled=cf.with_mixed_precision): + batch_data = self.prepare_batch( batch_data) + preds, _ = self.model_ddp( batch_data) + loss, mse_loss, losses = self.loss( preds, batch_idx) + self.grad_scaler.scale(loss).backward() + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() self.optimizer.zero_grad() - # loss.backward() - # self.optimizer.step() prof.step() @@ -371,71 +363,49 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): test_len = 0 self.mode_test = True - - # run in training mode - offset = 0 - if -1 == epoch and 0 == cf.par_rank : - if 1 == cf.num_accs_per_task : # bug in torchinfo; fixed in v1.8.0 - offset += 1 - print( 'Network size:') - batch_data = self.model.next() - batch_data = self.prepare_batch( batch_data) - torchinfo.summary( self.model, input_data=[batch_data]) # run test set evaluation - with torch.no_grad() : - for it in range( self.model.len( NetMode.test) - offset) : - + for it in range( self.model.len( NetMode.test)) : batch_data = self.model.next() if cf.par_rank < cf.log_test_num_ranks : # keep on cpu since it will otherwise clog up GPU memory - (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] - # targets - if len(batch_data[1]) > 0 : - if type(batch_data[1][0][0]) is list : - targets = [batch_data[1][i][0][0] for i in range( len(batch_data[1]))] - else : - targets = batch_data[1][0] - # store on cpu + (sources, _ , targets, tmis_list, _) = batch_data[0] log_sources = ( [source.detach().clone().cpu() for source in sources ], - [ti.detach().clone().cpu() for ti in token_infos], [target.detach().clone().cpu() for target in targets ], - tmis, tmis_list ) - - batch_data = self.prepare_batch( batch_data) - - preds, atts = self.model( batch_data) + tmis_list) + with torch.autocast(device_type='cuda',dtype=torch.float16,enabled=cf.with_mixed_precision): + batch_data = self.prepare_batch( batch_data) + preds, atts = self.model( batch_data) loss = torch.tensor( 0.) 
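# --- Illustrative sketch (not part of the patch) -------------------------------------------
# The reworked training loop wraps the forward pass in torch.autocast and drives the update
# through torch.cuda.amp.GradScaler, as in the train() hunk above. A minimal, standalone
# sketch of that pattern; model, optimizer and data here are placeholders for illustration.
import torch

def amp_step( model, optimizer, scaler, x, y, enabled=True) :
  with torch.autocast( device_type='cuda', dtype=torch.float16, enabled=enabled) :
    loss = torch.nn.functional.mse_loss( model( x), y)
  scaler.scale( loss).backward()   # backward on the scaled loss
  scaler.step( optimizer)          # unscales gradients, skips the step on inf/nan
  scaler.update()                  # adjust the scale factor for the next iteration
  optimizer.zero_grad()
  return loss.detach()

if __name__ == '__main__' and torch.cuda.is_available() :
  model = torch.nn.Linear( 8, 8).cuda()
  optimizer = torch.optim.AdamW( model.parameters(), lr=1e-4, weight_decay=0.05)
  scaler = torch.cuda.amp.GradScaler( enabled=True)
  x, y = torch.randn( 4, 8, device='cuda'), torch.randn( 4, 8, device='cuda')
  print( amp_step( model, optimizer, scaler, x, y).item())
# --------------------------------------------------------------------------------------------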
ifield = 0 for pred, idx in zip( preds, self.fields_prediction_idx) : - + target = self.targets[idx] # hook for custom test loss self.test_loss( pred, target) # base line loss cur_loss = self.MSELoss( pred[0], target = target ).cpu().item() - + loss += cur_loss total_losses[ifield] += cur_loss ifield += 1 - + total_loss += loss test_len += 1 - + # store detailed results on current test set for book keeping if cf.par_rank < cf.log_test_num_ranks : log_preds = [[p.detach().clone().cpu() for p in pred] for pred in preds] self.log_validate( epoch, it, log_sources, log_preds) - if cf.attention: - self.log_attention( epoch, it, [atts, - [ti.detach().clone().cpu() for ti in token_infos]]) - + self.log_attention( epoch, it, atts) + # average over all nodes total_loss /= test_len * len(self.cf.fields_prediction) total_losses /= test_len + if cf.with_ddp : total_loss_cuda = total_loss.cuda() total_losses_cuda = total_losses.cuda() @@ -456,7 +426,6 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): loss_dict[idx_name] = total_losses[i] print( 'validation loss for {} : {}'.format( field[0], total_losses[i] )) wandb.log( loss_dict) - batch_data = [] torch.cuda.empty_cache() @@ -465,77 +434,13 @@ def validate( self, epoch, BERT_test_strategy = 'BERT'): return total_loss - ################################################### - def evaluate( self, data_idx = 0, log = True): - - cf = self.cf - self.model.mode( NetMode.test) - - log_sources = [] - test_len = 0 - - # evaluate - - loss = torch.tensor( 0.) - - with torch.no_grad() : - for it in range( self.model.len( NetMode.test)) : - - batch_data = self.model.next() - - if cf.par_rank < cf.log_test_num_ranks : - # keep on cpu since it will otherwise clog up GPU memory - (sources, token_infos, targets, tmis, tmis_list) = batch_data[0] - # targets - if len(batch_data[1]) > 0 : - targets = [] - for target_field in batch_data[1] : - targets.append(torch.cat([target_vl[0].unsqueeze(1) for target_vl in target_field],1)) - # store on cpu - log_sources = ( [source.detach().clone().cpu() for source in sources ], - [ti.detach().clone().cpu() for ti in token_infos], - [target.detach().clone().cpu() for target in targets ], - tmis, tmis_list ) - - batch_data = self.prepare_batch( batch_data) - - preds, atts = self.model( batch_data) - - ifield = 0 - for pred, idx in zip( preds, self.fields_prediction_idx) : - - target = self.targets[idx] - cur_loss = self.MSELoss( pred[0], target = target ).cpu() - loss += cur_loss - ifield += 1 - - test_len += 1 - - # logging - if cf.par_rank < cf.log_test_num_ranks : - self.log_validate( data_idx, it, log_sources, preds) - - if cf.attention: - self.log_attention( data_idx , it, [atts, - [ti.detach().clone().cpu() for ti in token_infos]]) - - # average over all nodes - loss /= test_len * len(self.cf.fields_prediction) - if cf.with_ddp : - loss_cuda = loss.cuda() - dist.all_reduce( loss_cuda, op=torch.distributed.ReduceOp.AVG ) - loss = loss_cuda.cpu() - - if 0 == cf.par_rank : - print( 'Loss {}'.format( loss)) - ################################################### def test_loss( self, pred, target) : '''Hook for custom test loss''' pass ################################################### - def loss( self, preds, batch_idx = 0) : + def loss( self, preds, batch_idx = 0, tmidx_list = None, weights_list = None) : # TODO: move implementations to individual files @@ -544,7 +449,6 @@ def loss( self, preds, batch_idx = 0) : losses = dict(zip(cf.losses,[[] for loss in cf.losses ])) for pred, idx in zip( preds, self.fields_prediction_idx) 
: - target = self.targets[idx] mse_loss = self.MSELoss( pred[0], target = target) @@ -559,9 +463,22 @@ def loss( self, preds, batch_idx = 0) : loss_en = torch.tensor( 0., device=target.device) for en in torch.transpose( pred[2], 1, 0) : loss_en += self.MSELoss( en, target = target) - # losses['mse_ensemble'].append( 50. * loss_en / pred[2].shape[1]) losses['mse_ensemble'].append( loss_en / pred[2].shape[1]) + if 'weighted_mse' in self.cf.losses : + loss_en = torch.tensor( 0., device=target.device) + field_info = cf.fields[idx] + token_size = field_info[4] + + weights = torch.Tensor(np.array([w for batch in weights_list[idx] for w in batch])) + weights = weights.view(*weights.shape, 1, 1).repeat(1, 1, token_size[0], token_size[2]).swapaxes(1, 2) + weights = weights.reshape([weights.shape[0], -1]).to(target.get_device()) + + for en in torch.transpose( pred[2], 1, 0) : + loss_en += weighted_mse( en, target, weights) + + losses['weighted_mse'].append( loss_en / pred[2].shape[1]) + # Generalized cross entroy loss for continuous distributions if 'stats' in self.cf.losses : stats_loss = Gaussian( target, pred[0], pred[1]) @@ -581,12 +498,19 @@ def loss( self, preds, batch_idx = 0) : crps_loss = torch.mean( CRPS( target, pred[0], pred[1])) losses['crps'].append( crps_loss) + if 'kernel_crps' in self.cf.losses : + kcrps_loss = torch.mean( kernel_crps( target,torch.transpose( pred[2], 1, 0))) + losses['kernel_crps'].append( kcrps_loss) + + loss = torch.tensor( 0., device=self.device_out) + tot_weight = torch.tensor( 0., device=self.device_out) for key in losses : - # print( 'LOSS : {} :: {}'.format( key, losses[key])) + #print( 'LOSS : {} :: {}'.format( key, losses[key])) for ifield, val in enumerate(losses[key]) : loss += self.loss_weights[ifield] * val.to( self.device_out) - loss /= len(self.cf.fields_prediction) * len( self.cf.losses) + tot_weight += self.loss_weights[ifield] + loss /= tot_weight mse_loss = mse_loss_total / len(self.cf.fields_prediction) return loss, mse_loss, losses @@ -623,8 +547,9 @@ def prepare_batch( self, xin) : # unpack loader output # xin[0] since BERT does not have targets - (sources, token_infos, targets, fields_tokens_masked_idx,fields_tokens_masked_idx_list) = xin[0] - + (sources, token_infos, targets, fields_tokens_masked_idx_list, _) = xin[0] + (self.sources_idxs, self.sources_info) = xin[2] + # network input batch_data = [ ( sources[i].to( devs[ cf.fields[i][1][3] ], non_blocking=True), self.tok_infos_trans(token_infos[i]).to( self.devices[0], non_blocking=True)) @@ -641,22 +566,12 @@ def prepare_batch( self, xin) : self.targets.append( targets[ifield].to( devs[cf.fields[ifield][1][3]], non_blocking=True )) # idxs of masked tokens - tmi_out = [[] for _ in range(len(fields_tokens_masked_idx))] - for i,tmi in enumerate(fields_tokens_masked_idx) : - tmi_out[i] = [tmi_l.to( devs[cf.fields[i][1][3]], non_blocking=True) for tmi_l in tmi] + tmi_out = [ ] + for i,tmi in enumerate(fields_tokens_masked_idx_list) : + cdev = devs[cf.fields[i][1][3]] + tmi_out += [ [torch.cat(tmi_l).to( cdev, non_blocking=True) for tmi_l in tmi] ] self.tokens_masked_idx = tmi_out - - # idxs of masked tokens per batch entry - self.fields_tokens_masked_idx_list = fields_tokens_masked_idx_list - - # learnable class token (cannot be done in the data loader since this is running in parallel) - if cf.learnable_mask : - for ifield, (source, _) in enumerate(batch_data) : - source = torch.flatten( torch.flatten( torch.flatten( source, 1, 4), 2, 4), 0, 1) - assert len(cf.fields[ifield][2]) == 1 - tmidx = 
self.tokens_masked_idx[ifield][0] - source[ tmidx ] = self.model.net.masks[ifield].to( source.device) - + return batch_data ################################################### @@ -674,48 +589,15 @@ def decoder_to_tail( self, idx_pred, pred) : # select "fixed" masked tokens for loss computation - # recover vertical level dimension - num_tokens = self.num_tokens[field_idx] - num_vlevels = len(self.cf.fields[field_idx][2]) # flatten token dimensions: remove space-time separation pred = torch.flatten( pred, 2, 3).to( dev) - # extract masked token level by level pred_masked = [] for lidx, level in enumerate(self.cf.fields[field_idx][2]) : - # select masked tokens, flattened along batch dimension for easier indexing and processing pred_l = torch.flatten( pred[:,lidx], 0, 1) - pred_masked_l = pred_l[ target_idx[lidx] ] - target_idx_l = target_idx[lidx] - - # add positional encoding of masked tokens - - # # TODO: do we need the positional encoding? - - # compute space time indices of all tokens - target_idxs_v = level * torch.ones( target_idx_l.shape[0], device=dev) - num_tokens_space = num_tokens[1] * num_tokens[2] - # remove offset introduced by linearization - target_idx_l = torch.remainder( target_idx_l, np.prod(num_tokens)) - target_idxs_t = (target_idx_l / num_tokens_space).int() - temp = torch.remainder( target_idx_l, num_tokens_space) - target_idxs_x = (temp / num_tokens[1]).int() - target_idxs_y = torch.remainder( temp, num_tokens[2]) - - # apply harmonic positional encoding - dim_embed = pred.shape[-1] - pe = torch.zeros( pred_masked_l.shape[0], dim_embed, device=dev) - xs = (2. * np.pi / dim_embed) * torch.arange( 0, dim_embed, 2, device=dev) - pe[:, 0::2] = 0.5 * torch.sin( torch.outer( 8 * target_idxs_x, xs) ) \ - + torch.sin( torch.outer( target_idxs_t, xs) ) - pe[:, 1::2] = 0.5 * torch.cos( torch.outer( 8 * target_idxs_y, xs) ) \ - + torch.cos( torch.outer( target_idxs_v, xs) ) - - # TODO: with or without final positional encoding? - # pred_masked.append( pred_masked_l + pe) - pred_masked.append( pred_masked_l) - + pred_masked.append( pred_l[ target_idx[lidx] ]) + # flatten along level dimension, for loss evaluation we effectively have level, batch, ... 
# as ordering of dimensions pred_masked = torch.cat( pred_masked, 0) @@ -729,9 +611,9 @@ def log_validate( self, epoch, bidx, log_sources, log_preds) : if not hasattr( self.cf, 'wandb_id') : return - if 'forecast' == self.cf.BERT_strategy : + if 'forecast' in self.cf.BERT_strategy : self.log_validate_forecast( epoch, bidx, log_sources, log_preds) - elif 'BERT' == self.cf.BERT_strategy : + elif 'BERT' in self.cf.BERT_strategy or 'temporal_interpolation' == self.cf.BERT_strategy : self.log_validate_BERT( epoch, bidx, log_sources, log_preds) else : assert False @@ -739,119 +621,125 @@ def log_validate( self, epoch, bidx, log_sources, log_preds) : ################################################### def log_validate_forecast( self, epoch, batch_idx, log_sources, log_preds) : '''Logging for BERT_strategy=forecast.''' - cf = self.cf - detok = utils.detokenize - # TODO, TODO: for 6h forecast we need to iterate over predicted token slices + cf = self.cf # save source: remains identical so just save ones - (sources, token_infos, targets, _, _) = log_sources + (sources, targets, _) = log_sources sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] - + batch_size = len(self.sources_info) # reconstruct geo-coords (identical for all fields) forecast_num_tokens = 1 if hasattr( cf, 'forecast_num_tokens') : forecast_num_tokens = cf.forecast_num_tokens - - num_tokens = cf.fields[0][3] - token_size = cf.fields[0][4] - lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) - lats, lons = [ ], [ ] - for tinfo in token_infos[0] : - lat_min, lat_max = tinfo[0][4], tinfo[ num_tokens[1]*num_tokens[2]-1 ][4] - lon_min, lon_max = tinfo[0][5], tinfo[ num_tokens[1]*num_tokens[2]-1 ][5] - res = tinfo[0][-1] - lat = torch.arange( lat_min - lat_d_h*res, lat_max + lat_d_h*res + 0.001, res) - if lon_max < lon_min : - lon = torch.arange( lon_min - lon_d_h*res, 360. 
+ lon_max + lon_d_h*res + 0.001, res) - else : - lon = torch.arange( lon_min - lon_d_h*res, lon_max + lon_d_h*res + 0.001, res) - lats.append( lat.numpy()) - lons.append( torch.remainder( lon, 360.).numpy()) - - # check that last token (bottom right corner) has the expected coords - # assert np.allclose( ) - - # extract dates for each token entry, constant for each batch and field - dates_t = [] - for b_token_infos in token_infos[0] : - dates_t.append(utils.token_info_to_time(b_token_infos[0])-pd.Timedelta(hours=token_size[0]-1)) - - # TODO: check that last token matches first one - - # process input fields + + coords = [] for fidx, field_info in enumerate(cf.fields) : # reshape from tokens to contiguous physical field num_levels = len(field_info[2]) - source = detok( sources[fidx].cpu().detach().numpy()) + source = detokenize( sources[fidx].cpu().detach().numpy()) # recover tokenized shape - target = detok( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, -1, - forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) - # TODO: check that geo-coords match to general ones that have been pre-determined - for bidx in range(token_infos[fidx].shape[0]) : + target = detokenize( targets[fidx].cpu().detach().numpy().reshape( [ num_levels, -1, + forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) + + coords_b = [] + + for bidx in range(batch_size): + dates = self.sources_info[bidx][0] + lats = self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] + dates_t = self.sources_info[bidx][0][ -forecast_num_tokens*field_info[4][0] : ] + + lats_idx = self.sources_idxs[bidx][1] + lons_idx = self.sources_idxs[bidx][2] + for vidx, _ in enumerate(field_info[2]) : - denormalize = self.model.normalizer( fidx, vidx).denormalize - date, coords = dates_t[bidx], [lats[bidx], lons[bidx]] - source[bidx,vidx] = denormalize( date.year, date.month, source[bidx,vidx], coords) - target[bidx,vidx] = denormalize( date.year, date.month, target[bidx,vidx], coords) + normalizer, year_base = self.model.normalizer( fidx, vidx, lats_idx, lons_idx) + source[bidx,vidx] = denormalize( source[bidx,vidx], normalizer, dates, year_base) + target[bidx,vidx] = denormalize( target[bidx,vidx], normalizer, dates_t, year_base) + + coords_b += [[dates, 90.-lats, lons, dates_t]] + # append sources_out.append( [field_info[0], source]) targets_out.append( [field_info[0], target]) + coords.append(coords_b) # process predicted fields for fidx, fn in enumerate(cf.fields_prediction) : - # field_info = cf.fields[ self.fields_prediction_idx[fidx] ] num_levels = len(field_info[2]) # predictions pred = log_preds[fidx][0].cpu().detach().numpy() - pred = detok( pred.reshape( [ num_levels, -1, + pred = detokenize( pred.reshape( [ num_levels, -1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(0,1)) # ensemble ensemble = log_preds[fidx][2].cpu().detach().numpy().swapaxes(0,1) - ensemble = detok( ensemble.reshape( [ cf.net_tail_num_nets, num_levels, -1, + ensemble = detokenize( ensemble.reshape( [ cf.net_tail_num_nets, num_levels, -1, forecast_num_tokens, *field_info[3][1:], *field_info[4] ]).swapaxes(1, 2)).swapaxes(0,1) # denormalize - for bidx in range(token_infos[fidx].shape[0]) : + for bidx in range(batch_size) : + lats = self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] + dates_t = self.sources_info[bidx][0][ -forecast_num_tokens*field_info[4][0] : ] + for vidx, vl in enumerate(field_info[2]) : - denormalize = self.model.normalizer( 
self.fields_prediction_idx[fidx], vidx).denormalize - date, coords = dates_t[bidx], [lats[bidx], lons[bidx]] - pred[bidx,vidx] = denormalize( date.year, date.month, pred[bidx,vidx], coords) - ensemble[bidx,:,vidx] = denormalize(date.year, date.month, ensemble[bidx,:,vidx], coords) + normalizer, year_base = self.model.normalizer( self.fields_prediction_idx[fidx], vidx, lats_idx, lons_idx) + pred[bidx,vidx] = denormalize( pred[bidx,vidx], normalizer, dates_t, year_base) + ensemble[bidx,:,vidx] = denormalize(ensemble[bidx,:,vidx], normalizer, dates_t, year_base) + # append preds_out.append( [fn[0], pred]) ensembles_out.append( [fn[0], ensemble]) - # generate time range - dates_sources, dates_targets = [ ], [ ] - for bidx in range( source.shape[0]) : - r = pd.date_range( start=dates_t[bidx], periods=source.shape[2], freq='h') - dates_sources.append( r.to_pydatetime().astype( 'datetime64[s]') ) - dates_targets.append( dates_sources[-1][ -forecast_num_tokens*token_size[0] : ] ) - levels = np.array(cf.fields[0][2]) - lats = [90.-lat for lat in lats] - - write_forecast( cf.wandb_id, epoch, batch_idx, - levels, sources_out, [dates_sources, lats, lons], - targets_out, [dates_targets, lats, lons], - preds_out, ensembles_out ) + write_forecast( cf.wandb_id, epoch, batch_idx, + levels, sources_out, + targets_out, preds_out, + ensembles_out, coords) + + ################################################### + + def split_data(self, data, idx_list, token_size) : + lens_batches = [[len(t) for t in tt] for tt in idx_list] + lens_levels = [torch.tensor( tt).sum() for tt in lens_batches] + data_b = torch.split( data, lens_levels) + # split according to batch + return [torch.split( data_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] + + def get_masked_data(self, field_info, data, idx_list, ensemble = False): + + cf = self.cf + batch_size = len(self.sources_info) + num_levels = len(field_info[2]) + num_tokens = field_info[3] + token_size = field_info[4] + data_b = self.split_data(data, idx_list, token_size) + + # recover token shape + if ensemble: + return [[data_b[vidx][bidx].reshape([-1, cf.net_tail_num_nets, *token_size]) + for bidx in range(batch_size)] + for vidx in range(num_levels)] + else: + return [[data_b[vidx][bidx].reshape([-1, *token_size]) for bidx in range(batch_size)] + for vidx in range(num_levels)] + ################################################### def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : '''Logging for BERT_strategy=BERT.''' cf = self.cf - detok = utils.detokenize + batch_size = len(self.sources_info) # save source: remains identical so just save ones - (sources, token_infos, targets, tokens_masked_idx, tokens_masked_idx_list) = log_sources + (sources, targets, tokens_masked_idx_list) = log_sources sources_out, targets_out, preds_out, ensembles_out = [ ], [ ], [ ], [ ] - sources_dates_out, sources_lats_out, sources_lons_out = [ ], [ ], [ ] - targets_dates_out, targets_lats_out, targets_lons_out = [ ], [ ], [ ] + coords = [] for fidx, field_info in enumerate(cf.fields) : @@ -860,173 +748,107 @@ def log_validate_BERT( self, epoch, batch_idx, log_sources, log_preds) : num_levels = len(field_info[2]) num_tokens = field_info[3] token_size = field_info[4] - lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) - tinfos = token_infos[fidx].reshape( [-1, num_levels, *num_tokens, cf.size_token_info]) - res = tinfos[0,0,0,0,0][-1].item() - batch_size = tinfos.shape[0] - - sources_b = detok( sources[fidx].numpy()) - + sources_b = 
detokenize( sources[fidx].numpy()) + if is_predicted : - # split according to levels - lens_levels = [t.shape[0] for t in tokens_masked_idx[fidx]] - targets_b = torch.split( targets[fidx], lens_levels) - preds_mu_b = torch.split( log_preds[fidx][0], lens_levels) - preds_ens_b = torch.split( log_preds[fidx][2], lens_levels) - # split according to batch - lens_batches = [ [bv.shape[0] for bv in b] for b in tokens_masked_idx_list[fidx] ] - targets_b = [torch.split( targets_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] - preds_mu_b = [torch.split(preds_mu_b[vidx], lens) for vidx,lens in enumerate(lens_batches)] - preds_ens_b =[torch.split(preds_ens_b[vidx],lens) for vidx,lens in enumerate(lens_batches)] - # recover token shape - targets_b = [[targets_b[vidx][bidx].reshape([-1, *token_size]) - for bidx in range(batch_size)] - for vidx in range(num_levels)] - preds_mu_b = [[preds_mu_b[vidx][bidx].reshape([-1, *token_size]) - for bidx in range(batch_size)] - for vidx in range(num_levels)] - preds_ens_b = [[preds_ens_b[vidx][bidx].reshape( [-1, cf.net_tail_num_nets, *token_size]) - for bidx in range(batch_size)] - for vidx in range(num_levels)] + targets_b = self.get_masked_data(field_info, targets[fidx], tokens_masked_idx_list[fidx]) + preds_mu_b = self.get_masked_data(field_info, log_preds[fidx][0], tokens_masked_idx_list[fidx]) + preds_ens_b = self.get_masked_data(field_info, log_preds[fidx][2], tokens_masked_idx_list[fidx], ensemble = True) # for all batch items coords_b = [] - for bidx, tinfo in enumerate(tinfos) : - - # use first vertical levels since a column is considered - lats = np.arange(tinfo[0,0,0,0,4]-lat_d_h*res, tinfo[0,0,-1,0,4]+lat_d_h*res+0.001,res) - if tinfo[0,0,0,-1,5] < tinfo[0,0,0,0,5] : - lons = np.remainder( np.arange( tinfo[0,0,0,0,5] - lon_d_h*res, - 360. + tinfo[0,0,0,-1,5] + lon_d_h*res + 0.001, res), 360.) - else : - lons = np.arange(tinfo[0,0,0,0,5]-lon_d_h*res, tinfo[0,0,0,-1,5]+lon_d_h*res+0.001,res) - lons = np.remainder( lons, 360.) 
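The split_data / get_masked_data helpers introduced above replace the previous inline double torch.split. Below is a minimal, self-contained sketch of the reshaping they perform, with toy sizes; the variable names are hypothetical and the snippet is an editorial illustration, not part of the patch:

import torch

# one token covers (t, lat, lon) = (2, 3, 3) grid points
token_size = [2, 3, 3]
tok_len = token_size[0] * token_size[1] * token_size[2]

# masked-token indices per vertical level and per batch item (2 levels, 2 batch items)
idx_list = [[torch.arange(4), torch.arange(5)],
            [torch.arange(2), torch.arange(3)]]

# flat tensor of all masked tokens: level-major, then batch-major, each token flattened
num_masked = sum(len(t) for tt in idx_list for t in tt)
data = torch.randn(num_masked, tok_len)

# split by vertical level, then by batch item (what split_data does)
lens_batches = [[len(t) for t in tt] for tt in idx_list]            # [[4, 5], [2, 3]]
lens_levels = [int(torch.tensor(tt).sum()) for tt in lens_batches]  # [9, 5]
data_lvl = torch.split(data, lens_levels)
data_b = [torch.split(data_lvl[v], lens) for v, lens in enumerate(lens_batches)]

# recover the token shape per (level, batch item) (what get_masked_data does)
tokens = [[data_b[v][b].reshape([-1, *token_size]) for b in range(2)] for v in range(2)]
print(tokens[0][1].shape)   # torch.Size([5, 2, 3, 3]): 5 masked tokens of shape token_size

With ensemble=True, get_masked_data performs the same splits but reshapes to [-1, cf.net_tail_num_nets, *token_size], inserting an ensemble-member axis.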
- - # time stamp in token_infos is at start time so needs to be advanced by token_size[0]-1 - s = utils.token_info_to_time( tinfo[0,0,0,0,:3] ) - pd.Timedelta(hours=token_size[0]-1) - e = utils.token_info_to_time( tinfo[0,-1,0,0,:3] ) - dates = pd.date_range( start=s, end=e, freq='h') + for bidx in range(batch_size): + dates = self.sources_info[bidx][0] + lats = self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] + + lats_idx = self.sources_idxs[bidx][1] + lons_idx = self.sources_idxs[bidx][2] # target etc are aliasing targets_b which simplifies bookkeeping below if is_predicted : - target = [targets_b[vidx][bidx] for vidx in range(num_levels)] - pred_mu = [preds_mu_b[vidx][bidx] for vidx in range(num_levels)] + target = [targets_b[vidx][bidx] for vidx in range(num_levels)] + pred_mu = [preds_mu_b[vidx][bidx] for vidx in range(num_levels)] pred_ens = [preds_ens_b[vidx][bidx] for vidx in range(num_levels)] - dates_masked_l, lats_masked_l, lons_masked_l = [], [], [] + coords_mskd_l = [] for vidx, _ in enumerate(field_info[2]) : - normalizer = self.model.normalizer( fidx, vidx) - y, m = dates[0].year, dates[0].month - sources_b[bidx,vidx] = normalizer.denormalize( y, m, sources_b[bidx,vidx], [lats, lons]) + normalizer, year_base = self.model.normalizer( fidx, vidx, lats_idx, lons_idx) + sources_b[bidx,vidx] = denormalize(sources_b[bidx,vidx], normalizer, dates, year_base = 2021) if is_predicted : - - # TODO: make sure normalizer_local / normalizer_global is used in data_loader idx = tokens_masked_idx_list[fidx][vidx][bidx] - tinfo_masked = tinfos[bidx,vidx].flatten( 0,2) - tinfo_masked = tinfo_masked[idx] - lad, lod = lat_d_h*res, lon_d_h*res - lats_masked, lons_masked, dates_masked = [], [], [] - for t in tinfo_masked : - - lats_masked.append( np.expand_dims( np.arange(t[4]-lad, t[4]+lad+0.001,res), 0)) - lons_masked.append( np.expand_dims( np.arange(t[5]-lod, t[5]+lod+0.001,res), 0)) - - r = pd.date_range( start=utils.token_info_to_time(t), periods=token_size[0], freq='h') - dates_masked.append( np.expand_dims(r.to_pydatetime().astype( 'datetime64[s]'), 0) ) - - lats_masked = np.concatenate( lats_masked, 0) - lons_masked = np.remainder( np.concatenate( lons_masked, 0), 360.) - dates_masked = np.concatenate( dates_masked, 0) - - for ii,(t,p,e,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], - lats_masked, lons_masked)) : - targets_b[vidx][bidx][ii] = normalizer.denormalize( y, m, t, [la, lo]) - preds_mu_b[vidx][bidx][ii] = normalizer.denormalize( y, m, p, [la, lo]) - preds_ens_b[vidx][bidx][ii] = normalizer.denormalize( y, m, e, [la, lo]) - - dates_masked_l += [ dates_masked ] - lats_masked_l += [ [90.-lat for lat in lats_masked] ] - lons_masked_l += [ lons_masked ] - - dates = dates.to_pydatetime().astype( 'datetime64[s]') - - coords_b += [ [dates, 90.-lats, lons, dates_masked_l, lats_masked_l, lons_masked_l] ] - + grid = np.flip(np.array( np.meshgrid( lons, lats)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + grid_idx = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + + # recover time dimension since idx assumes the full space-time cube + grid = torch.from_numpy( np.array( np.broadcast_to( grid, + shape = [token_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) + grid_lats_toked = tokenize( grid[0], token_size).flatten( 0, 2) + grid_lons_toked = tokenize( grid[1], token_size).flatten( 0, 2) + + idx_loc = idx - np.prod(num_tokens) * bidx + #save only useful info for each bidx. 
shape e.g. [n_bidx, lat_token_size*lat_num_tokens] + lats_mskd = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + lons_mskd = np.array([np.unique(t) for t in grid_lons_toked[ idx_loc ].numpy()]) + + #time: idx ranges from 0->863 12x6x12 + t_idx = (idx_loc // (num_tokens[1]*num_tokens[2])) * token_size[0] + #create range from t_idx-2 to t_idx + t_idx = np.array([np.arange(t, t + token_size[0]) for t in t_idx]) + dates_mskd = dates[t_idx] + + for ii,(t,p,e,da,la,lo) in enumerate(zip( target[vidx], pred_mu[vidx], pred_ens[vidx], + dates_mskd, lats_mskd, lons_mskd)) : + normalizer_ii = normalizer + if len(normalizer.shape) > 2: #local normalization + lats_mskd_idx = np.where(np.isin(lats,la))[0] + lons_mskd_idx = np.where(np.isin(lons,lo))[0] + #normalizer_ii = normalizer[:, :, lats_mskd_idx, lons_mskd_idx] problems in python 3.9 + normalizer_ii = normalizer[:, :, lats_mskd_idx[0]:lats_mskd_idx[-1]+1, lons_mskd_idx[0]:lons_mskd_idx[-1]+1] + + targets_b[vidx][bidx][ii] = denormalize(t, normalizer_ii, da, year_base) + preds_mu_b[vidx][bidx][ii] = denormalize(p, normalizer_ii, da, year_base) + preds_ens_b[vidx][bidx][ii] = denormalize(e, normalizer_ii, da, year_base) + + coords_mskd_l += [[dates_mskd, 90.-lats_mskd, lons_mskd] ] + + coords_b += [ [dates, 90. - lats, lons] + coords_mskd_l ] + + coords += [ coords_b ] fn = field_info[0] sources_out.append( [fn, sources_b]) - if is_predicted : - targets_out.append([fn, [[t.numpy(force=True) for t in t_v] for t_v in targets_b]]) - preds_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_mu_b]]) - ensembles_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_ens_b]]) - else : - targets_out.append( [fn, []]) - preds_out.append( [fn, []]) - ensembles_out.append( [fn, []]) - sources_dates_out.append( [c[0] for c in coords_b]) - sources_lats_out.append( [c[1] for c in coords_b]) - sources_lons_out.append( [c[2] for c in coords_b]) - if is_predicted : - targets_dates_out.append( [c[3] for c in coords_b]) - targets_lats_out.append( [c[4] for c in coords_b]) - targets_lons_out.append( [c[5] for c in coords_b]) - else : - targets_dates_out.append( [ ]) - targets_lats_out.append( [ ]) - targets_lons_out.append( [ ]) + targets_out.append([fn, [[t.numpy(force=True) for t in t_v] for t_v in targets_b]] if is_predicted else [fn, []]) + preds_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_mu_b]] if is_predicted else [fn, []] ) + ensembles_out.append( [fn, [[p.numpy(force=True) for p in p_v] for p_v in preds_ens_b]] if is_predicted else [fn, []] ) levels = [[np.array(l) for l in field[2]] for field in cf.fields] - write_BERT( cf.wandb_id, epoch, batch_idx, - levels, sources_out, - [sources_dates_out, sources_lats_out, sources_lons_out], - targets_out, [targets_dates_out, targets_lats_out, targets_lons_out], - preds_out, ensembles_out ) + write_BERT( cf.wandb_id, epoch, batch_idx, + levels, sources_out, targets_out, + preds_out, ensembles_out, coords ) + +###################################################### - def log_attention( self, epoch, bidx, log) : + def log_attention( self, epoch, bidx, attention) : '''Hook for logging: output attention maps.''' cf = self.cf - attention, token_infos = log - attn_dates_out, attn_lats_out, attn_lons_out = [ ], [ ], [ ] attn_out = [] for fidx, field_info in enumerate(cf.fields) : - # reconstruct coordinates - is_predicted = fidx in self.fields_prediction_idx - num_levels = len(field_info[2]) - num_tokens = field_info[3] - token_size = field_info[4] 
- lat_d_h, lon_d_h = int(np.floor(token_size[1]/2.)), int(np.floor(token_size[2]/2.)) - tinfos = token_infos[fidx].reshape( [-1, num_levels, *num_tokens, cf.size_token_info]) + + # coordinates coords_b = [] - - for tinfo in tinfos : - # use first vertical levels since a column is considered - res = tinfo[0,0,0,0,-1] - lats = np.arange(tinfo[0,0,0,0,4]-lat_d_h*res, tinfo[0,0,-1,0,4]+lat_d_h*res+0.001,res*token_size[1]) - if tinfo[0,0,0,-1,5] < tinfo[0,0,0,0,5] : - lons = np.remainder( np.arange( tinfo[0,0,0,0,5] - lon_d_h*res, - 360. + tinfo[0,0,0,-1,5] + lon_d_h*res + 0.001, res*token_size[2]), 360.) - else : - lons = np.arange(tinfo[0,0,0,0,5]-lon_d_h*res, tinfo[0,0,0,-1,5]+lon_d_h*res+0.001,res*token_size[2]) - - lats = [90.-lat for lat in lats] - lons = np.remainder( lons, 360.) - - dates = np.array([(utils.token_info_to_time(tinfo[0,t,0,0,:3])) for t in range(tinfo.shape[1])], dtype='datetime64[s]') + for bidx in range(batch_size): + dates = self.sources_info[bidx][0] + lats = 90. - self.sources_info[bidx][1] + lons = self.sources_info[bidx][2] coords_b += [ [dates, lats, lons] ] - if is_predicted: - attn_out.append([field_info[0], attention[fidx]]) - attn_dates_out.append([c[0] for c in coords_b]) - attn_lats_out.append( [c[1] for c in coords_b]) - attn_lons_out.append( [c[2] for c in coords_b]) - else: - attn_dates_out.append( [] ) - attn_lats_out.append( [] ) - attn_lons_out.append( [] ) - + is_predicted = fidx in self.fields_prediction_idx + attn_out.append([field_info[0], attention[fidx]] if is_predicted else [fn, []]) + levels = [[np.array(l) for l in field[2]] for field in cf.fields] write_attention(cf.wandb_id, epoch, - bidx, levels, attn_out, [attn_dates_out,attn_lats_out,attn_lons_out]) + bidx, levels, attn_out, coords_b ) diff --git a/atmorep/datasets/data_loader.py b/atmorep/datasets/data_loader.py deleted file mode 100644 index 4ddeae3..0000000 --- a/atmorep/datasets/data_loader.py +++ /dev/null @@ -1,161 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import pathlib -import numpy as np -import xarray as xr -from functools import partial - -import atmorep.utils.utils as utils -from atmorep.config.config import year_base -from atmorep.utils.utils import tokenize -from atmorep.datasets.file_io import grib_file_loader, netcdf_file_loader, bin_file_loader - -# TODO, TODO, TODO: replace with torch functonality -# import cv2 as cv - -class DataLoader: - - def __init__(self, path, file_shape, data_type = 'reanalysis', - file_format = 'grib', level_type = 'pl', - fname_base = '{}/{}/{}{}/{}_{}_y{}_m{}_{}{}', - smoothing = 0, - log_transform = False): - - self.path = path - self.data_type = data_type - self.file_format = file_format - self.file_shape = file_shape - self.fname_base = fname_base - self.smoothing = smoothing - self.log_transform = log_transform - - if 'grib' == file_format : - self.file_ext = '.grib' - self.file_loader = grib_file_loader - elif 'binary' == file_format : - self.file_ext = '_fp32.dat' - self.file_loader = bin_file_loader - elif 'netcdf4' == file_format : - self.file_ext = '.nc4' - self.file_loader = netcdf_file_loader - elif 'netcdf' == file_format : 
- self.file_ext = '.nc' - self.file_loader = netcdf_file_loader - else : - raise ValueError('Unsupported file format.') - - self.fname_base = fname_base + self.file_ext - - self.grib_index = { 'vorticity' : 'vo', 'divergence' : 'd', 'geopotential' : 'z', - 'orography' : 'z', 'temperature': 't', 'specific_humidity' : 'q', - 'mean_top_net_long_wave_radiation_flux' : 'mtnlwrf', - 'velocity_u' : 'u', 'velocity_v': 'v', 'velocity_z' : 'w', - 'total_precip' : 'tp', 'radar_precip' : 'yw_hourly', - 't2m' : 't_2m', 'u_10m' : 'u_10m', 'v_10m' : 'v_10m', } - - def get_field( self, year, month, field, level_type, vl, - token_size = [-1, -1], t_pad = [-1, -1, 1]): - - t_srate = t_pad[2] - data_ym = torch.zeros( (0, self.file_shape[1], self.file_shape[2])) - - # pre-fill fixed values - # fname_base = self.fname_base.format( self.path, self.data_type, field, level_type, vl, - fname_base = self.fname_base.format( self.path, field, level_type, vl, - self.data_type, field, {},{},{},{}) - - # padding pre - if t_pad[0] > 0 : - if month > 1: - month_p = str(month-1).zfill(2) - days_month = utils.days_in_month( year, month-1) - fname = fname_base.format( year, month_p, level_type, vl) - else: - assert(year >= year_base) - year_p = str(year-1).zfill(2) - days_month = utils.days_in_month( year, 12) - fname = fname_base.format( year-1, 12, level_type, vl) - x = self.file_loader( fname, self.grib_index[field], [t_pad[0], 0, t_srate], - days_month ) - - data_ym = torch.cat((data_ym,x),0) - - # data - fname = fname_base.format( year, str(month).zfill(2), level_type, vl) - days_month = utils.days_in_month( year, month) - x = self.file_loader(fname, self.grib_index[field], [0, 0, t_srate], days_month) - - data_ym = torch.cat((data_ym,x),0) - - # padding post - if t_pad[1] > 0 : - if month > 1: - month_p = str(month+1).zfill(2) - days_month = utils.days_in_month( year, month+1) - fname = fname_base.format( year, month_p, level_type, vl) - else: - assert(year >= year_base) - year_p = str(year+1).zfill(2) - days_month = utils.days_in_month( year+1, 12) - fname = fname_base.format( year_p, 12, level_type, vl) - x = self.file_loader( fname, self.grib_index[field], [0, t_pad[1], t_srate], - days_month) - - data_ym = torch.cat((data_ym,x),0) - - if self.smoothing > 0 : - sm = self.smoothing - mask_nan = torch.isnan( data_ym) - data_ym[ mask_nan ] = 0. 
- blur = partial( cv.blur, ksize=(sm,sm), borderType=cv.BORDER_REFLECT_101) - data_ym = [torch.from_numpy( blur( data_ym[k].numpy()) ).unsqueeze(0) - for k in range(data_ym.shape[0])] - data_ym = torch.cat( data_ym, 0) - data_ym[ mask_nan ] = torch.nan - - # tokenize - data_ym = tokenize( data_ym, token_size) - - return data_ym - - def get_single_field( self, years_months, field = 'vorticity', level_type = 'pl', vl = 975, - token_size = [-1, -1], t_pad = [-1, -1, 1]): - - data_field = [] - for year, month in years_months : - data_field.append( self.get_field( year, month, field, level_type, vl, token_size, t_pad)) - - return data_field - - def get_static_field( self, field, token_size = [-1, -1]): - - #support for static fields from other data types - data_type = self.data_type - f = self.path + '/' + data_type + '/static/' +self.data_type + '_' + field + self.file_ext - - x = self.file_loader(f, self.grib_index[field], static=True) - - if self.smoothing > 0 : - sm = self.smoothing - blur = partial( cv.blur, ksize=(sm,sm), borderType=cv.BORDER_REFLECT_101) - # x = torch.from_numpy( cv.blur( x.numpy(), (self.smoothing,self.smoothing))) - x = torch.from_numpy( blur( x.numpy() )) - - x = tokenize( x, token_size) - - return x diff --git a/atmorep/datasets/data_writer.py b/atmorep/datasets/data_writer.py index 350c7b4..a63c3d8 100644 --- a/atmorep/datasets/data_writer.py +++ b/atmorep/datasets/data_writer.py @@ -17,38 +17,41 @@ import numpy as np import xarray as xr import zarr -import code -import datetime import atmorep.config.config as config +def write_item(ds_field, name_idx, data, levels, coords, name = 'sample' ): + ds_batch_item = ds_field.create_group( f'{name}={name_idx:05d}' ) + ds_batch_item.create_dataset( 'data', data=data) + ds_batch_item.create_dataset( 'ml', data=levels) + ds_batch_item.create_dataset( 'datetime', data=coords[0].astype('datetime64[ns]')) + ds_batch_item.create_dataset( 'lat', data=np.array(coords[1]).astype(np.float32)) + ds_batch_item.create_dataset( 'lon', data=np.array(coords[2]).astype(np.float32)) + return ds_batch_item + #################################################################################################### -def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, - targets, targets_coords, - preds, ensembles, - zarr_store_type = 'ZipStore' ) : +def write_forecast( model_id, epoch, batch_idx, levels, sources, + targets, preds, ensembles, coords, + zarr_store_type = 'ZipStore' ) : ''' sources : num_fields x [field name , data] targets : preds, ensemble share coords with targets ''' - + sources_coords = [[c[:3] for c in coord_field ] for coord_field in coords] + targets_coords = [[[c[-1], c[1], c[2]] for c in coord_field ] for coord_field in coords] fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' zarr_store = getattr( zarr, zarr_store_type) store_source = zarr_store( fname.format( 'source')) exp_source = zarr.group(store=store_source) + for fidx, field in enumerate(sources) : ds_field = exp_source.require_group( f'{field[0]}') batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=sources_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=sources_coords[1][bidx]) - 
ds_batch_item.create_dataset( 'lon', data=sources_coords[2][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, sources_coords[fidx][bidx]) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -58,12 +61,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -73,12 +71,7 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -88,26 +81,23 @@ def write_forecast( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels) - ds_batch_item.create_dataset( 'datetime', data=targets_coords[0][bidx]) - ds_batch_item.create_dataset( 'lat', data=targets_coords[1][bidx]) - ds_batch_item.create_dataset( 'lon', data=targets_coords[2][bidx]) + write_item(ds_field, sample, field[1][bidx], levels, targets_coords[fidx][bidx]) store_ens.close() #################################################################################################### -def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, - targets, targets_coords, - preds, ensembles, - zarr_store_type = 'ZipStore' ) : +def write_BERT( model_id, epoch, batch_idx, levels, sources, + targets, preds, ensembles, coords, + zarr_store_type = 'ZipStore' ) : + ''' sources : num_fields x [field name , data] targets : preds, ensemble share coords with targets ''' - # fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch}.zarr' + sources_coords = [[c[:3] for c in coord_field ] for coord_field in coords] + targets_coords = [[c[3:] for c in coord_field ] for coord_field in coords] + fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' zarr_store = getattr( zarr, zarr_store_type) @@ -119,12 +109,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, batch_size = field[1].shape[0] for bidx in range( field[1].shape[0]) : sample = batch_idx * batch_size + bidx - ds_batch_item = 
ds_field.create_group( f'sample={sample:05d}' ) - ds_batch_item.create_dataset( 'data', data=field[1][bidx]) - ds_batch_item.create_dataset( 'ml', data=levels[fidx]) - ds_batch_item.create_dataset( 'datetime', data=sources_coords[0][0][bidx]) - ds_batch_item.create_dataset( 'lat', data=sources_coords[1][0][bidx]) - ds_batch_item.create_dataset( 'lon', data=sources_coords[2][0][bidx]) + write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] ) store_source.close() store_target = zarr_store( fname.format( 'target')) @@ -138,12 +123,7 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, sample = batch_idx * batch_size + bidx ds_target_b = ds_field.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_target_b_l = ds_target_b.require_group( f'ml={levels[fidx][vidx]}') - ds_target_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - ds_target_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_target_b_l.create_dataset( 'datetime', data=targets_coords[0][fidx][bidx][vidx]) - ds_target_b_l.create_dataset( 'lat', data=targets_coords[1][fidx][bidx][vidx]) - ds_target_b_l.create_dataset( 'lon', data=targets_coords[2][fidx][bidx][vidx]) + write_item(ds_target_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], targets_coords[fidx][bidx][vidx], name = 'ml' ) store_target.close() store_pred = zarr_store( fname.format( 'pred')) @@ -157,13 +137,8 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, sample = batch_idx * batch_size + bidx ds_pred_b = ds_pred.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_pred_b_l = ds_pred_b.create_group( f'ml={levels[fidx][vidx]}') - ds_pred_b_l.create_dataset( 'data', data - =field[1][vidx][bidx]) - ds_pred_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_pred_b_l.create_dataset( 'datetime', data=targets_coords[0][fidx][bidx][vidx]) - ds_pred_b_l.create_dataset( 'lat', data=targets_coords[1][fidx][bidx][vidx]) - ds_pred_b_l.create_dataset( 'lon', data=targets_coords[2][fidx][bidx][vidx]) + write_item(ds_pred_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], + targets_coords[fidx][bidx][vidx], name = 'ml' ) store_pred.close() store_ens = zarr_store( fname.format( 'ens')) @@ -177,16 +152,12 @@ def write_BERT( model_id, epoch, batch_idx, levels, sources, sources_coords, sample = batch_idx * batch_size + bidx ds_ens_b = ds_ens.create_group( f'sample={sample:05d}') for vidx in range(len(levels[fidx])) : - ds_ens_b_l = ds_ens_b.create_group( f'ml={levels[fidx][vidx]}') - ds_ens_b_l.create_dataset( 'data', data=field[1][vidx][bidx]) - ds_ens_b_l.create_dataset( 'ml', data=levels[fidx][vidx]) - ds_ens_b_l.create_dataset( 'datetime', data=targets_coords[0][fidx][bidx][vidx]) - ds_ens_b_l.create_dataset( 'lat', data=targets_coords[1][fidx][bidx][vidx]) - ds_ens_b_l.create_dataset( 'lon', data=targets_coords[2][fidx][bidx][vidx]) + write_item(ds_ens_b, levels[fidx][vidx], field[1][vidx][bidx], levels[fidx][vidx], + targets_coords[fidx][bidx][vidx], name = 'ml' ) store_ens.close() #################################################################################################### -def write_attention(model_id, epoch, batch_idx, levels, attn, attn_coords, zarr_store_type = 'ZipStore' ) : +def write_attention(model_id, epoch, batch_idx, levels, attn, coords, zarr_store_type = 'ZipStore' ) : fname = f'{config.path_results}/id{model_id}/results_id{model_id}_epoch{epoch:05d}' + '_{}.zarr' zarr_store 
= getattr( zarr, zarr_store_type) @@ -200,9 +171,9 @@ def write_attention(model_id, epoch, batch_idx, levels, attn, attn_coords, zarr_ for lidx, atts_f_l in enumerate(atts_f[1]) : # layer in the network ds_f_l = ds_field_b.require_group( f'layer={lidx:05d}') ds_f_l.create_dataset( 'ml', data=levels[fidx]) - ds_f_l.create_dataset( 'datetime', data=attn_coords[0][fidx]) - ds_f_l.create_dataset( 'lat', data=attn_coords[1][fidx]) - ds_f_l.create_dataset( 'lon', data=attn_coords[2][fidx]) + ds_f_l.create_dataset( 'datetime', data=coords[0][fidx]) + ds_f_l.create_dataset( 'lat', data=coords[1][fidx]) + ds_f_l.create_dataset( 'lon', data=coords[2][fidx]) ds_f_l_h = ds_f_l.require_group('heads') for hidx, atts_f_l_head in enumerate(atts_f_l) : # number of attention head if atts_f_l_head != None : diff --git a/atmorep/datasets/dynamic_field_level.py b/atmorep/datasets/dynamic_field_level.py deleted file mode 100644 index 600a181..0000000 --- a/atmorep/datasets/dynamic_field_level.py +++ /dev/null @@ -1,242 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import pathlib -import numpy as np -import math -import os, sys -import time -import itertools -import gc -import code -# code.interact(local=locals()) - -import atmorep.utils.utils as utils -from atmorep.utils.utils import shape_to_str -from atmorep.utils.utils import days_until_month_in_year -from atmorep.utils.utils import days_in_month - -from atmorep.datasets.data_loader import DataLoader -from atmorep.datasets.normalizer_global import NormalizerGlobal -from atmorep.datasets.normalizer_local import NormalizerLocal - -class DynamicFieldLevel() : - - ################################################### - def __init__( self, file_path, years_data, field_info, - batch_size, data_type = 'era5', - file_shape = [-1, 721, 1440], file_geo_range = [[-90.,90.], [0.,360.]], - num_tokens = [3, 9, 9], token_size = [1, 9, 9], - level_type = 'pl', vl = 975, time_sampling = 1, - smoothing = 0, file_format = 'grib', corr_type = 'local', - log_transform_data = False ) : - ''' - Data set for single dynamic field at a single vertical level - ''' - - self.years_data = years_data - self.field_info = field_info - self.file_path = file_path - self.file_shape = file_shape - self.file_format = file_format - self.level_type = level_type - self.vl = vl - self.time_sampling = time_sampling - self.smoothing = smoothing - self.corr_type = corr_type - self.log_transform_data = log_transform_data - - self.years_months = [] - - # work internally with mathematical latitude coordinates in [0,180] - self.file_geo_range = [ -np.array(file_geo_range[0]) + 90. , np.array(file_geo_range[1]) ] - # enforce that georange is North to South - self.geo_range_flipped = False - if self.file_geo_range[0][0] > self.file_geo_range[0][1] : - self.file_geo_range[0] = np.flip( self.file_geo_range[0]) - self.geo_range_flipped = True - self.is_global = 0. == self.file_geo_range[0][0] and 0. == self.file_geo_range[1][0] \ - and 180. == self.file_geo_range[0][1] and 360. 
== self.file_geo_range[1][1] - - # resolution - # TODO: non-uniform resolution in latitude and longitude - self.res = (file_geo_range[1][1] - file_geo_range[1][0]) - self.res /= file_shape[2] if self.is_global else (file_shape[2]-1) - - self.batch_size = batch_size - self.num_tokens = torch.tensor( num_tokens, dtype=torch.int) - rem1 = (num_tokens[1]*token_size[1]) % 2 - rem2 = (num_tokens[2]*token_size[2]) % 2 - t1 = num_tokens[1]*token_size[1] - t2 = num_tokens[2]*token_size[2] - self.grid_delta = [ [int((t1+rem1)/2), int(t1/2)], [int((t2+rem2)/2), int(t2/2)] ] - assert( num_tokens[1] < file_shape[1]) - assert( num_tokens[2] < file_shape[2]) - self.tok_size = token_size - - self.data_field = None - - if self.corr_type == 'global' : - self.normalizer = NormalizerGlobal( field_info, vl, self.file_shape, data_type) - else : - self.normalizer = NormalizerLocal( field_info, vl, self.file_shape, data_type) - - self.loader = DataLoader( self.file_path, self.file_shape, data_type, - file_format = self.file_format, level_type = self.level_type, - smoothing = self.smoothing, log_transform=self.log_transform_data) - - ################################################### - def load_data( self, years_months, idxs_perm, batch_size = None) : - - self.idxs_perm = idxs_perm.copy() - - # nothing to be loaded - if set(years_months) in set(self.years_months): - return - - self.years_months = years_months - - if batch_size : - self.batch_size = batch_size - loader = self.loader - - self.files_offset_days = [] - for year, month in self.years_months : - self.files_offset_days.append( days_until_month_in_year( year, month) ) - - # load data - # self.data_field is a list of lists of torch tensors - # [i] : year/month - # [i][j] : field per year/month - # [i][j] : len_data_per_month x num_tokens_lat x num_tokens_lon x token_size x token_size - # this ensures coherence in the data access - del self.data_field - gc.collect() - self.data_field = loader.get_single_field( self.years_months, self.field_info[0], - self.level_type, self.vl, [-1, -1], - [self.num_tokens[0] * self.tok_size[0], 0, - self.time_sampling]) - - # apply normalization and log-transform for each year-month data - for j in range( len(self.data_field) ) : - - if self.corr_type == 'local' : - coords = [ np.linspace( 0., 180., num=180*4+1, endpoint=True), - np.linspace( 0., 360., num=360*4, endpoint=False) ] - else : - coords = None - - (year, month) = self.years_months[j] - self.data_field[j] = self.normalizer.normalize( year, month, self.data_field[j], coords) - - # basics statistics - print( 'INFO:: data stats {} : {} / {}'.format( self.field_info[0], - self.data_field[j].mean(), - self.data_field[j].std()) ) - - ############################################### - def __getitem__( self, bidx) : - - tn = self.grid_delta - num_tokens = self.num_tokens - tok_size = self.tok_size - tnt = self.num_tokens[0] * self.tok_size[0] - cat = torch.cat - geor = self.file_geo_range - - idx = bidx * self.batch_size - - # physical fields - patch_s = [nt*ts for nt,ts in zip(self.num_tokens,self.tok_size)] - x = torch.zeros( self.batch_size, patch_s[0], patch_s[1], patch_s[2] ) - cids = torch.zeros( self.batch_size, num_tokens.prod(), 8) - - # offset from previous month to be able to sample all time slices in current one - offset_t = int(num_tokens[0] * tok_size[0]) - # 721 etc have grid points at the beginning and end which leads to incorrect results in places - file_shape = np.array(self.file_shape) - file_shape = file_shape-1 if not self.is_global else 
np.array(self.file_shape)-np.array([0,1,0]) - - # for all items in batch - for jj in range( self.batch_size) : - - i_ym = int(self.idxs_perm[idx][0]) - # perform a deep copy to not overwrite cid for other fields - cid = np.array( self.idxs_perm[idx][1:]).copy() - cid_orig = cid.copy() - - # map to grid coordinates (first map to normalized [0,1] coords and then to grid coords) - cid[2] = np.mod( cid[2], 360.) if self.is_global else cid[2] - assert cid[1] >= geor[0][0] and cid[1] <= geor[0][1], 'invalid latitude for geo_range' - cid[1] = ( (cid[1] - geor[0][0]) / (geor[0][1] - geor[0][0]) ) * file_shape[1] - cid[2] = ( ((cid[2]) - geor[1][0]) / (geor[1][1] - geor[1][0]) ) * file_shape[2] - assert cid[1] >= 0 and cid[1] < self.file_shape[1] - assert cid[2] >= 0 and cid[2] < self.file_shape[2] - - # alignment when parent field has different resolution than this field - cid = np.round( cid).astype( np.int64) - - ran_t = list( range( cid[0]-tnt+1 + offset_t, cid[0]+1 + offset_t)) - if any(np.array(ran_t) >= self.data_field[i_ym].shape[0]) : - print( '{} : {} :: {}'.format( self.field_info[0], self.years_months[i_ym], ran_t )) - - # periodic boundary conditions around equator - ran_lon = np.array( list( range( cid[2]-tn[1][0], cid[2]+tn[1][1]))) - if self.is_global : - ran_lon = np.mod( ran_lon, self.file_shape[2]) - else : - # sanity check for indices for files with local window - # this should be controlled by georange_sampling for sampling - assert all( ran_lon >= 0) and all( ran_lon < self.file_shape[2]) - - ran_lat = np.array( list( range( cid[1]-tn[0][0], cid[1]+tn[0][1]))) - assert all( ran_lat >= 0) and all( ran_lat < self.file_shape[1]) - - # current data - # if self.geo_range_flipped : - # print( '{} : {} / {}'.format( self.field_info[0], ran_lat, ran_lon) ) - if np.max(ran_t) >= self.data_field[i_ym].shape[0] : - print( 'WARNING: {} : {} :: {}'.format( self.field_info[0], ran_t, self.years_months[i_ym]) ) - x[jj] = np.take( np.take( self.data_field[i_ym][ran_t], ran_lat, 1), ran_lon, 2) - - # set per token information - assert self.time_sampling == 1 - ran_tt = np.flip( np.arange( cid[0], cid[0]-tnt, -tok_size[0])) - years = self.years_months[i_ym][0] * np.ones( ran_tt.shape) - days_in_year = self.files_offset_days[i_ym] + (ran_tt / 24.) 
- # wrap year around - mask = days_in_year < 0 - years[ mask ] -= 1 - days_in_year[ mask ] += 365 - hours = np.mod( ran_tt, 24) - lats = ran_lat[int(tok_size[1]/2)::tok_size[1]] * self.res + self.file_geo_range[0][0] - lons = ran_lon[int(tok_size[2]/2)::tok_size[2]] * self.res + self.file_geo_range[1][0] - stencil = torch.tensor(list(itertools.product(lats,lons))) - tstencil = torch.tensor( [ [y, d, h, self.vl] for y,d,h in zip( years, days_in_year, hours)], - dtype=torch.float) - txlist = list( itertools.product( tstencil, stencil)) - cids[jj,:,:6] = torch.cat( [torch.cat(tx).unsqueeze(0) for tx in txlist], 0) - cids[jj,:,6] = self.vl - cids[jj,:,7] = self.res - - idx += 1 - - return (x, cids) - - ################################################### - def __len__(self): - return int(self.idxs_perm.shape[0] / self.batch_size) diff --git a/atmorep/datasets/file_io.py b/atmorep/datasets/file_io.py deleted file mode 100644 index 1974044..0000000 --- a/atmorep/datasets/file_io.py +++ /dev/null @@ -1,85 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import numpy as np -import xarray as xr - -#################################################################################################### - -def netcdf_file_loader(fname, field, time_padding = [0,0,1], days_in_month = 0, static=False) : - - ds = xr.open_dataset(fname, engine='netcdf4')[field] - - if not static: - # TODO: test that only time_padding[0] *or* time_padding[1] != 0 - if time_padding[0] != 0 : - ds = ds[-time_padding[0]*time_padding[2] : ] - elif time_padding[1] != 0 : - ds = ds[ : time_padding[1]*time_padding[2]] - if time_padding[2] > 1 : - ds = ds[::time_padding[2]] - - x = torch.from_numpy(np.array( ds, dtype=np.float32)) - ds.close() - - return x - -#################################################################################################### - -def grib_file_loader(fname, field, time_padding = [0,0,1], days_in_month = 0, static=False) : - - ds = xr.open_dataset(fname, engine='cfgrib', - backend_kwargs={'time_dims':('valid_time','indexing_time')})[field] - - # work-around for bug in download where for every month 31 days have been downloaded - if days_in_month > 0 : - ds = ds[:24*days_in_month] - - if not static: - # TODO: test that only time_padding[0] *or* time_padding[1] != 0 - if time_padding[0] != 0 : - ds = ds[-time_padding[0]*time_padding[2] : ] - elif time_padding[1] != 0 : - ds = ds[ : time_padding[1]*time_padding[2]] - if time_padding[2] > 1 : - ds = ds[::time_padding[2]] - - x = torch.from_numpy(np.array(ds, dtype=np.float32)) - ds.close() - - # assume grib files are clean and NaNs are introduced through handling of "missing values" - if np.isnan( x).any() : - x_shape = x.shape - x = x.flatten() - x[np.argwhere( np.isnan( x))] = 9999. 
- x = np.reshape( x, x_shape) - - return x - -#################################################################################################### - -def bin_file_loader( fname, field, time_padding = [0,0,1], static=False, file_shape = (-1, 721, 1440)) : - - ds = np.fromfile(fname, dtype=np.float32) - print("INFO:: reshaping binary file into {}".format(file_shape)) - if not static: - ds = np.reshape(ds, file_shape) - ds = ds[time_padding[0]:(ds.shape[0] - time_padding[0])] - else: - ds = np.reshape(ds, file_shape[1:]) - x = torch.from_numpy(ds) - return x diff --git a/atmorep/datasets/multifield_data_sampler.py b/atmorep/datasets/multifield_data_sampler.py index 7e0c629..5e2c6f8 100644 --- a/atmorep/datasets/multifield_data_sampler.py +++ b/atmorep/datasets/multifield_data_sampler.py @@ -16,328 +16,238 @@ import torch import numpy as np -import math -import itertools +import zarr +import pandas as pd +from datetime import datetime +import time +import os import code -# code.interact(local=locals()) -from atmorep.datasets.dynamic_field_level import DynamicFieldLevel -from atmorep.datasets.static_field import StaticField - -from atmorep.utils.utils import days_until_month_in_year -from atmorep.utils.utils import days_in_month - -import atmorep.config.config as config +# from atmorep.datasets.normalizer_global import NormalizerGlobal +# from atmorep.datasets.normalizer_local import NormalizerLocal +from atmorep.datasets.normalizer import normalize +from atmorep.utils.utils import tokenize, get_weights class MultifieldDataSampler( torch.utils.data.IterableDataset): ################################################### - def __init__( self, file_path, years_data, fields, batch_size, - num_t_samples, num_patches_per_t, num_load, pre_batch, - rng_seed = None, file_shape = (-1, 721, 1440), - level_type = 'ml', time_sampling = 1, - smoothing = 0, file_format = 'grib', month = None, lat_sampling_weighted = True, - geo_range = [[-90.,90.], [0.,360.]], - fields_targets = [], pre_batch_targets = None - ) : + def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size, + num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False, compute_weights = False, + fields_targets = None, pre_batch_targets = None ) : ''' Data set for single dynamic field at an arbitrary number of vertical levels + + nsize : neighborhood in (tsteps, deg_lat, deg_lon) ''' super( MultifieldDataSampler).__init__() self.fields = fields self.batch_size = batch_size - + self.n_size = n_size + self.num_samples = num_samples + self.with_source_idxs = with_source_idxs + self.compute_weights = compute_weights + self.with_shuffle = with_shuffle self.pre_batch = pre_batch + + assert os.path.exists(file_path), f"File path {file_path} does not exist" + self.ds = zarr.open( file_path) + + self.ds_global = self.ds.attrs['is_global'] + + self.lats = np.array( self.ds['lats']) + self.lons = np.array( self.ds['lons']) + + sh = self.ds['data'].shape + st = self.ds['time'].shape + self.ds_len = st[0] + print( f'self.ds[\'data\'] : {sh} :: {st}') + print( f'self.lats : {self.lats.shape}', flush=True) + print( f'self.lons : {self.lons.shape}', flush=True) + self.fields_idxs = [] - self.years_data = years_data self.time_sampling = time_sampling - self.month = month - self.range_lat = 90. 
- np.array( geo_range[0]) - self.range_lon = np.array( geo_range[1]) - self.geo_range = geo_range - - # order North to South - self.range_lat = np.flip(self.range_lat) if self.range_lat[1] < self.range_lat[0] \ - else self.range_lat - - # prepare range_lat and range_lon for sampling - self.is_global = 0 == self.range_lat[0] and self.range_lon[0] == 0. \ - and 180. == self.range_lat[1] and 360. == self.range_lon[1] + self.range_lat = np.array( self.lats[ [0,-1] ]) + self.range_lon = np.array( self.lons[ [0,-1] ]) + self.res = np.array(self.ds.attrs['res']) + self.year_base = self.ds['time'][0].astype(datetime).year + + # ensure neighborhood does not exceed domain (either at pole or for finite domains) + self.range_lat += np.array([n_size[1] / 2., -n_size[1] / 2.]) + # lon: no change for periodic case + if self.ds_global < 1.: + self.range_lon += np.array([n_size[2]/2., -n_size[2]/2.]) - # TODO: this assumes file_shape is set correctly and not just per field and it defines a - # reference grid, likely has to be the coarsest - self.res = 360. / file_shape[2] + # data normalizers + self.normalizers = [] + for ifield, field_info in enumerate(fields) : + corr_type = 'global' if len(field_info) <= 6 else field_info[6] + nf_name = 'global_norm' if corr_type == 'global' else 'norm' + self.normalizers.append( [] ) + for vl in field_info[2]: + if vl == 0: + field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) + n_name = f'normalization/{nf_name}_sfc' + self.normalizers[ifield] += [self.ds[n_name].oindex[ :, :, field_idx]] + else: + vl_idx = self.ds.attrs['levels'].index(vl) + field_idx = self.ds.attrs['fields'].index( field_info[0]) + n_name = f'normalization/{nf_name}' + self.normalizers[ifield] += [self.ds[n_name].oindex[ :, :, field_idx, vl_idx]] - # avoid wrap around at poles - pole_offset = np.ceil(fields[0][3][1] * fields[0][4][1] / 2) * self.res - self.range_lat[0] = pole_offset if self.range_lat[0] < pole_offset else self.range_lat[0] - self.range_lat[1] =180.-pole_offset if 180.-self.range_lat[1] South ordering - # shrink so that cookie cutting based on sampling does not exceed domain if it is not global - if not self.is_global : - # TODO: check that field data is consistent and covers the same spatial domain - # TODO: code below assumes that fields[0] is global - # TODO: code below does not handle anisotropic grids - finfo = self.fields[0] - # ensure that delta is a multiple of the coarse grid resolution - ngrid1 = finfo[3][1] * finfo[4][1] - ngrid2 = finfo[3][2] * finfo[4][2] - delta1 = 0.5 * self.res * (ngrid1-1 if ngrid1 % 2==0 else ngrid1+1) - delta2 = 0.5 * self.res * (ngrid2-1 if ngrid2 % 2==0 else ngrid2+1) - self.range_lat += np.array([delta1, -delta1]) - self.range_lon += np.array([delta2, -delta2]) - - # ensure all data loaders use same rng_seed and hence generate consistent data - if not rng_seed : - rng_seed = np.random.randint( 0, 100000, 1)[0] - self.rng = np.random.default_rng( rng_seed) - - # create (source) fields - self.datasets = self.create_loaders( fields) - - # create (target) fields - self.datasets_targets = self.create_loaders( fields_targets) - - ################################################### - def create_loaders( self, fields ) : - - datasets = [] - for field_idx, field_info in enumerate(fields) : - - datasets.append( []) - - # extract field info - (vls, num_tokens, token_size) = field_info[2:5] - - if len(field_info) > 6 : - corr_type = field_info[6] - else: - corr_type = 'global' - - smoothing = self.smoothing - log_transform_data = False - if 
len(field_info) > 7 : - (data_type, file_shape, file_geo_range, file_format) = field_info[7][:4] - if len( field_info[7]) > 6 : - smoothing = field_info[7][6] - print( '{} : smoothing = {}'.format( field_info[0], smoothing) ) - if len( field_info[7]) > 7 : - log_transform_data = field_info[7][7] - print( '{} : log_transform_data = {}'.format( field_info[0], log_transform_data) ) - else : - data_type = 'era5' - file_format = self.file_format - file_shape = self.file_shape - file_geo_range = [[90.,-90.], [0.,360.]] - - # static fields - if 0 == field_info[1][0] : - datasets[-1].append( StaticField( self.file_path, field_info, self.batch_size, data_type, - file_shape, file_geo_range, - num_tokens, token_size, smoothing, file_format, corr_type) ) - - # dynamic fields - elif 1 == field_info[1][0] : - for vlevel in vls : - datasets[-1].append( DynamicFieldLevel( self.file_path, self.years_data, field_info, - self.batch_size, data_type, - file_shape, file_geo_range, - num_tokens, token_size, - self.level_type, vlevel, self.time_sampling, - smoothing, file_format, corr_type, - log_transform_data ) ) - - else : - assert False + # extract indices for selected years + self.times = pd.DatetimeIndex( self.ds['time']) + idxs_years = self.times.year == years[0] + for year in years[1:] : + idxs_years = np.logical_or( idxs_years, self.times.year == year) + self.idxs_years = np.where( idxs_years)[0] - return datasets + self.num_samples = min( self.num_samples, self.idxs_years.shape[0]) ################################################### def shuffle( self) : - # ensure that different parallel loaders create independent random shuffles - delta = torch.randint( 0, 100000, (1,)).item() - self.rng.bit_generator.advance( delta) - - self.idxs_perm = np.zeros( (0, 4), dtype=np.int64) - - # latitude, first map to mathematical lat coords in [0,180.], then to [0,pi] then - # to z-value in [-1,1] - if self.lat_sampling_weighted : - lat_r = np.cos( self.range_lat/180. 
* np.pi) - else : - lat_r = self.range_lat - - # 1.00001 is a fudge factor since np.round(*.5) leads to flooring instead of proper up-rounding - res_inv = 1.0 / self.res * 1.00001 + worker_info = torch.utils.data.get_worker_info() + rng_seed = None + if worker_info is not None : + rng_seed = int(time.time()) // (worker_info.id+1) + worker_info.id - # loop over individual data year-month items - for i_ym in range( len(self.years_months)) : + rng = np.random.default_rng( rng_seed) + self.idxs_perm_t = rng.permutation( self.idxs_years)[ : self.num_samples // self.batch_size] - ym = self.years_months[i_ym] - - # ensure a constant size of work load of data loader independent of the month length - # factor of 128 is a fudge parameter to ensure that mod-ing leads to sufficiently - # random wrap-around (with 1 instead of 128 there is clustering on the first days) - hours_in_day = int( 24 / self.time_sampling) - time_slices = 128 * 31 * hours_in_day - time_slices_i_ym = hours_in_day * days_in_month( ym[0], ym[1]) - idxs_perm_temp = np.mod(self.rng.permutation(time_slices), time_slices_i_ym) - # fixed number of time samples independent of length of month - idxs_perm_temp = idxs_perm_temp[:self.num_t_samples] - idxs_perm = np.zeros( (self.num_patches_per_t *idxs_perm_temp.shape[0],4) ) - - # split up into file index and local index - idx = 0 - for it in idxs_perm_temp : - - idx_patches = self.rng.random( (self.num_patches_per_t, 2) ) - # for jj in idx_patches : - for jj in idx_patches : - # area consistent sampling on the sphere (with less patches close to the pole) - # see https://graphics.stanford.edu/courses/cs448-97-fall/notes.html , Lecture 7 - # for area preserving sampling of the sphere - # py \in [0,180], px \in [0,360] (possibly with negative values for lon) - if self.lat_sampling_weighted : - py = ((np.arccos(lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) / np.pi) * 180.) - else : - py = (lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) - px = jj[1] * (self.range_lon[1] - self.range_lon[0]) + self.range_lon[0] - - # align with grid - py = self.res * np.round( py * res_inv) - px = self.res * np.round( px * res_inv) - - idxs_perm[idx] = np.array( [i_ym, it, py, px]) - idx = idx + 1 - - self.idxs_perm = np.concatenate( (self.idxs_perm, idxs_perm[:idx])) - - # shuffle again to avoid clustering of patches by loop over idx_patches above - self.idxs_perm = self.idxs_perm[self.rng.permutation(self.idxs_perm.shape[0])] - self.idxs_perm = self.idxs_perm[self.rng.permutation(self.idxs_perm.shape[0])] - # restrict to multiples of batch size - lenbatch = int(math.floor(self.idxs_perm.shape[0] / self.batch_size)) * self.batch_size - self.idxs_perm = self.idxs_perm[:lenbatch] - # # DEBUG - # print( 'self.idxs_perm.shape = {}'.format(self.idxs_perm.shape )) - # rank = torch.distributed.get_rank() - # fname = 'idxs_perm_rank{}_{}.dat'.format( rank, shape_to_str( self.idxs_perm.shape)) - # self.idxs_perm.tofile( fname) + lats = rng.random(self.num_samples) * (self.range_lat[1] - self.range_lat[0]) +self.range_lat[0] + lons = rng.random(self.num_samples) * (self.range_lon[1] - self.range_lon[0]) +self.range_lon[0] - ################################################### - def set_full_time_range( self) : - - self.idxs_perm = np.zeros( (0, 4), dtype=np.int64) - - # latitude, first map to mathematical lat coords in [0,180.], then to [0,pi] then - # to z-value in [-1,1] - if self.lat_sampling_weighted : - lat_r = np.cos( self.range_lat/180. 
* np.pi) - else : - lat_r = self.range_lat - - # 1.00001 is a fudge factor since np.round(*.5) leads to flooring instead of proper up-rounding + # align with grid res_inv = 1.0 / self.res * 1.00001 + lats = self.res[0] * np.round( lats * res_inv[0]) + lons = self.res[1] * np.round( lons * res_inv[1]) - # loop over individual data year-month items - for i_ym in range( len(self.years_months)) : - - ym = self.years_months[i_ym] - - hours_in_day = int( 24 / self.time_sampling) - idxs_perm_temp = np.arange( hours_in_day * days_in_month( ym[0], ym[1])) - idxs_perm = np.zeros( (self.num_patches_per_t *idxs_perm_temp.shape[0],4) ) - - # split up into file index and local index - idx = 0 - for it in idxs_perm_temp : - - idx_patches = self.rng.random( (self.num_patches_per_t, 2) ) - for jj in idx_patches : - # area consistent sampling on the sphere (with less patches close to the pole) - # see https://graphics.stanford.edu/courses/cs448-97-fall/notes.html , Lecture 7 - # for area preserving sampling of the sphere - # py \in [0,180], px \in [0,360] (possibly with negative values for lon) - if self.lat_sampling_weighted : - py = ((np.arccos(lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) / np.pi) * 180.) - else : - py = (lat_r[0] + (lat_r[1]-lat_r[0]) * jj[0]) - px = jj[1] * (self.range_lon[1] - self.range_lon[0]) + self.range_lon[0] - - # align with grid - py = self.res * np.round( py * res_inv) - px = self.res * np.round( px * res_inv) + self.idxs_perm = np.stack( [lats, lons], axis=1) - idxs_perm[idx] = np.array( [i_ym, it, py, px]) - idx = idx + 1 + ################################################### + def __iter__(self): - self.idxs_perm = np.concatenate( (self.idxs_perm, idxs_perm[:idx])) + if self.with_shuffle : + self.shuffle() - # shuffle again to avoid clustering of patches by loop over idx_patches above - self.idxs_perm = self.idxs_perm[self.rng.permutation(self.idxs_perm.shape[0])] - # restrict to multiples of batch size - lenbatch = int(math.floor(self.idxs_perm.shape[0] / self.batch_size)) * self.batch_size - self.idxs_perm = self.idxs_perm[:lenbatch] + lats, lons = self.lats, self.lons + ts, n_size = self.time_sampling, self.n_size + ns_2 = np.array(self.n_size) / 2. 
+ res = self.res - # # DEBUG - # print( 'self.idxs_perm.shape = {}'.format(self.idxs_perm.shape )) - # fname = 'idxs_perm_{}_{}.dat'.format( self.epoch_counter, shape_to_str( self.idxs_perm.shape)) - # self.idxs_perm.tofile( fname) + iter_start, iter_end = self.worker_workset() - ################################################### - def load_data( self, batch_size = None) : + for bidx in range( iter_start, iter_end) : - years_data = self.years_data - - # ensure proper separation of different random samplers - delta = torch.randint( 0, 1000, (1,)).item() - self.rng.bit_generator.advance( delta) - - # select num_load random months and years - perms = np.concatenate( [self.rng.permutation( np.arange(len(years_data))) for i in range(64)]) - perms = perms[:self.num_load] - if self.month : - self.years_months = [ (years_data[iyear], self.month) for iyear in perms] - else : - # stratified sampling of month to ensure proper distribution, needs to be adapted for - # number of parallel workers not being divisible by 4 - # rank, ms = torch.distributed.get_rank() % 4, 3 - # perms_m = np.concatenate( [self.rng.permutation( np.arange( rank*ms+1, (rank+1)*ms+1)) - # for i in range(16)]) - perms_m = np.concatenate( [self.rng.permutation( np.arange( 1, 12+1)) for i in range(16)]) - self.years_months = [ ( years_data[iyear], perms_m[i]) for i,iyear in enumerate(perms)] - - # generate random permutations passed to the loaders for individual files - # to ensure consistent processing - self.shuffle() - - # perform actual loading of data - - for ds_field in self.datasets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) - - for ds_field in self.datasets_targets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) + sources, token_infos = [[] for _ in self.fields], [[] for _ in self.fields] + sources_infos, source_idxs = [], [] + + i_bidx = self.idxs_perm_t[bidx] + idxs_t = list(np.arange( i_bidx - n_size[0]*ts, i_bidx, ts, dtype=np.int64)) + data_tt_sfc = self.ds['data_sfc'].oindex[idxs_t] + data_tt = self.ds['data'].oindex[idxs_t] + for sidx in range(self.batch_size) : + + idx = self.idxs_perm[bidx*self.batch_size+sidx] + # slight asymetry with offset by res/2 is required to match desired token count + lat_ran = np.where(np.logical_and(lats>idx[0]-ns_2[1]-res[0]/2.,lats 360.) + il, ir = (idx[1]-ns_2[2]-res[1]/2., idx[1]+ns_2[2]) + if il < 0. : + lon_ran = np.concatenate( [np.where( lons > il+360)[0], np.where(lons < ir)[0]], 0) + elif ir > 360. 
: + lon_ran = np.concatenate( [np.where( lons > il)[0], np.where(lons < ir-360)[0]], 0) + else : + lon_ran = np.where(np.logical_and( lons > il, lons < ir))[0] + + sources_infos += [ [ self.ds['time'][ idxs_t ].astype(datetime), + self.lats[lat_ran], self.lons[lon_ran], self.res ] ] + + if self.with_source_idxs : + source_idxs += [ (idxs_t, lat_ran, lon_ran) ] + + # extract data + for ifield, field_info in enumerate(self.fields): + source_lvl, tok_info_lvl = [], [] + tok_size = field_info[4] + num_tokens = field_info[3] + corr_type = 'global' if len(field_info) <= 6 else field_info[6] + + for ilevel, vl in enumerate(field_info[2]): + if vl == 0 : #surface level + field_idx = self.ds.attrs['fields_sfc'].index( field_info[0]) + data_t = data_tt_sfc[ :, field_idx ] + else : + field_idx = self.ds.attrs['fields'].index( field_info[0]) + vl_idx = self.ds.attrs['levels'].index(vl) + data_t = data_tt[ :, field_idx, vl_idx ] + + source_data, tok_info = [], [] + # extract data, normalize and tokenize + cdata = data_t[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] + + normalizer = self.normalizers[ifield][ilevel] + if corr_type != 'global': + normalizer = normalizer[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]] + cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base) + + source_data = tokenize( torch.from_numpy( cdata), tok_size ) + # token_infos uses center of the token: *last* datetime and center in space + dates = self.ds['time'][ idxs_t ].astype(datetime) + cdates = dates[tok_size[0]-1::tok_size[0]] + # use -1 is to start days from 0 + dates = [(d.year, d.timetuple().tm_yday-1, d.hour) for d in cdates] + lats_sidx = self.lats[lat_ran][ tok_size[1]//2 :: tok_size[1] ] + lons_sidx = self.lons[lon_ran][ tok_size[2]//2 :: tok_size[2] ] + # tensor product for token_infos + tok_info += [[[[[ year, day, hour, vl, lat, lon, vl, self.res[0]] for lon in lons_sidx] + for lat in lats_sidx] + for (year, day, hour) in dates]] + + source_lvl += [ source_data ] + tok_info_lvl += [ torch.tensor(tok_info, dtype=torch.float32).flatten( 1, -2)] + sources[ifield] += [ torch.stack(source_lvl, 0) ] + token_infos[ifield] += [ torch.stack(tok_info_lvl, 0) ] + + # concatenate batches + sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources] + token_infos = [torch.stack(tis_field).transpose(1,0) for tis_field in token_infos] + sources = self.pre_batch( sources, token_infos ) + + tmidx_list = sources[-1] + weights_idx_list = [] + if self.compute_weights: + for ifield, field_info in enumerate(self.fields): + weights = [] + for ilevel, vl in enumerate(field_info[2]): + for ibatch in range(self.batch_size): + + lats_idx = source_idxs[ibatch][1] + lons_idx = source_idxs[ibatch][2] + + idx_base = tmidx_list[ifield][ilevel][ibatch] + idx_loc = idx_base - np.prod(num_tokens) * ibatch + + grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1 + grid = torch.from_numpy( np.array( np.broadcast_to( grid, + shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1)) + + grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2) + + lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()]) + + weights.append([get_weights(la) for la in lats_mskd_b]) + + weights_idx_list.append(weights) + sources = (*sources, weights_idx_list) + + # TODO: implement (only required when prediction target comes from different data stream) + targets, target_info = None, None + target_idxs = None + 
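The neighborhood extraction in __iter__ above handles the periodic longitude coordinate by concatenating two index ranges whenever the sampled window crosses the 0/360 seam. A small, self-contained sketch of that selection with toy values follows; the names are hypothetical and the snippet is an editorial illustration, not part of the patch:

import numpy as np

lons = np.arange(0., 360., 1.0)              # global 1-degree longitude grid
center_lon, half_width, res_lon = 1.0, 4.5, 1.0

il = center_lon - half_width - res_lon / 2.  # -4.0: window crosses the seam
ir = center_lon + half_width                 #  5.5

if il < 0.:
    lon_ran = np.concatenate([np.where(lons > il + 360.)[0], np.where(lons < ir)[0]], 0)
elif ir > 360.:
    lon_ran = np.concatenate([np.where(lons > il)[0], np.where(lons < ir - 360.)[0]], 0)
else:
    lon_ran = np.where(np.logical_and(lons > il, lons < ir))[0]

print(lons[lon_ran])   # [357. 358. 359.   0.   1.   2.   3.   4.   5.]

Keeping index arrays rather than the longitude values themselves lets the same lon_ran slice both the data arrays and, for local normalization, the per-grid-point normalizer statistics.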
+ yield ( sources, targets, (source_idxs, sources_infos), (target_idxs, target_info)) ################################################### def set_data( self, times_pos, batch_size = None) : @@ -347,55 +257,37 @@ def set_data( self, times_pos, batch_size = None) : - lon \in [0,360] - (year,month) pairs should be a limited number since all data for these is loaded ''' - - # extract required years and months - years_months_all = np.array( [ [it[0], it[1]] for it in times_pos ], dtype=np.int64) - self.years_months = list( zip( np.unique(years_months_all[:,0]), - np.unique( years_months_all[:,1] ))) - # generate all the data - self.idxs_perm = np.zeros( (len(times_pos), 4)) + self.idxs_perm = np.zeros( (len(times_pos), 2)) + self.idxs_perm_t = [] + self.num_samples = len(times_pos) for idx, item in enumerate( times_pos) : assert item[2] >= 1 and item[2] <= 31 assert item[3] >= 0 and item[3] < int(24 / self.time_sampling) assert item[4] >= -90. and item[4] <= 90. - # find year - for i_ym, ym in enumerate( self.years_months) : - if ym[0] == item[0] and ym[1] == item[1] : - break - - # last term: correct for window from last file that is loaded - it = (item[2] - 1) * (24./self.time_sampling) + item[3] - # it = item[2] * (24./self.time_sampling) + item[3] - idx_lat = item[4] - idx_lon = item[5] + tstamp = pd.to_datetime( f'{item[0]}-{item[1]}-{item[2]}-{item[3]}', format='%Y-%m-%d-%H') + + self.idxs_perm_t += [ np.where( self.times == tstamp)[0]+1 ] #The +1 assures that tsamp is included in the range # work with mathematical lat coordinates from here on - self.idxs_perm[idx] = np.array( [i_ym, it, 90. - idx_lat, idx_lon]) - - for ds_field in self.datasets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) - - for ds_field in self.datasets_targets : - for ds in ds_field : - ds.load_data( self.years_months, self.idxs_perm, batch_size) + self.idxs_perm[idx] = np.array( [90. - item[4], item[5]]) + + self.idxs_perm_t = np.array(self.idxs_perm_t).squeeze() ################################################### def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : ''' generate patch/token positions for global grid ''' - - token_overlap = torch.tensor( token_overlap).to(torch.int64) + token_overlap = np.array( token_overlap).astype(np.int64) # assumed that sanity checking that field data is consistent has been done ifield = 0 field = self.fields[ifield] res = self.res - side_len = torch.tensor( [field[3][1] * field[4][1] * res, field[3][2] * field[4][2] * res] ) - overlap = torch.tensor( [token_overlap[0]*field[4][1]*res, token_overlap[1]*field[4][2]*res] ) + side_len = np.array( [field[3][1] * field[4][1]*res[0], field[3][2] * field[4][2]*res[1]] ) + overlap = np.array([token_overlap[0]*field[4][1]*res[0],token_overlap[1]*field[4][2]*res[1]]) side_len_2 = side_len / 2. assert all( overlap <= side_len_2), 'token_overlap too large for #tokens, reduce if possible' @@ -422,72 +314,22 @@ def set_global( self, times, batch_size = None, token_overlap = [0, 0]) : lat -= side_len[0] - overlap[0] if lat - side_len_2[0] < 180. : num_tiles_lat += 1 - lat = 180. - side_len_2[0].item() + res + lat = 180. - side_len_2[0].item() + res[0] lon = side_len_2[1].item() - overlap[1].item()/2. while (lon - side_len_2[1]) < 360. : times_pos += [[*ctime, -lat + 90., np.mod(lon,360.) 
]] lon += side_len[1].item() - overlap[1].item() # adjust batch size if necessary so that the evaluations split up across batches of equal size - batch_size = num_tiles_lon - + batch_size = len(times_pos) #num_tiles_lon + print( 'Number of batches per global forecast: {}'.format( num_tiles_lat) ) self.set_data( times_pos, batch_size) - ################################################### - def set_location( self, pos, years, months, num_t_samples_per_month, batch_size = None) : - ''' random time sampling for fixed location ''' - - times_pos = [] - for i_ym, ym in enumerate(itertools.product( years, months )) : - - # ensure a constant size of work load of data loader independent of the month length - # factor of 128 is a fudge parameter to ensure that mod-ing leads to sufficiently - # random wrap-around (with 1 instead of 128 there is clustering on the first days) - hours_in_day = int( 24 / self.time_sampling) - d_i_m = days_in_month( ym[0], ym[1]) - perms = self.rng.permutation( num_t_samples_per_month * d_i_m) - # ensure that days start at 1 - perms = np.mod( perms[ : num_t_samples_per_month], (d_i_m-1) ) + 1 - rhs = self.rng.integers(low=0, high=hours_in_day, size=num_t_samples_per_month ) - - for rh, perm in zip( rhs, perms) : - times_pos += [[ ym[0], ym[1], perm, rh, pos[0], pos[1]] ] - - # adjust batch size if necessary so that the evaluations split up across batches of equal size - while 0 != (len(times_pos) % batch_size) : - batch_size -= 1 - assert batch_size >= 1 - - self.set_data( times_pos, batch_size) - - ################################################### - def __iter__(self): - - iter_start, iter_end = self.worker_workset() - - for bidx in range( iter_start, iter_end) : - - sources = [] - for ds_field in self.datasets : - sources.append( [ds_level[bidx] for ds_level in ds_field]) - # perform batch pre-processing, e.g. BERT-type masking - if self.pre_batch : - sources = self.pre_batch( sources) - - targets = [] - for ds_field in self.datasets_targets : - targets.append( [ds_level[bidx] for ds_level in ds_field]) - # perform batch pre-processing, e.g. 
BERT-type masking - if self.pre_batch_targets : - targets = self.pre_batch_targets( targets) - - yield (sources,targets) - ################################################### def __len__(self): - return len(self.datasets[0][0]) + return self.num_samples // self.batch_size ################################################### def worker_workset( self) : @@ -496,17 +338,16 @@ def worker_workset( self) : if worker_info is None: iter_start = 0 - iter_end = len(self.datasets[0][0]) - + iter_end = self.num_samples + else: # split workload - temp = len(self.datasets[0][0]) - per_worker = int( np.floor( temp / float(worker_info.num_workers) ) ) + per_worker = len(self) // worker_info.num_workers worker_id = worker_info.id iter_start = int(worker_id * per_worker) iter_end = int(iter_start + per_worker) if worker_info.id+1 == worker_info.num_workers : - iter_end = int(temp) + iter_end = len(self) return iter_start, iter_end - + diff --git a/atmorep/datasets/normalizer.py b/atmorep/datasets/normalizer.py new file mode 100644 index 0000000..0487af2 --- /dev/null +++ b/atmorep/datasets/normalizer.py @@ -0,0 +1,82 @@ +#################################################################################################### +# +# Copyright (C) 2022 +# +#################################################################################################### +# +# project : atmorep +# +# author : atmorep collaboration +# +# description : +# +# license : +# +#################################################################################################### + +import code +import numpy as np +import xarray as xr +import atmorep.config.config as config + +###################################################### +# Normalize # +###################################################### + +def normalize( data, norm, dates, year_base = 1979) : + corr_data = np.array([norm[12*(dt.year-year_base) + dt.month-1] for dt in dates]) + mean, var = corr_data[:, 0], corr_data[:, 1] + if (var == 0.).all() : + print( f'Warning: var == 0') + assert False + if len(norm.shape) > 2 : #global norm + return normalize_local(data, mean, var) + else: + return normalize_global( data, mean, var) + +###################################################### +def normalize_local( data, mean, var) : + data = (data - mean) / var + return data + +###################################################### +def normalize_global( data, mean, var) : + for i in range( data.shape[0]) : + data[i] = (data[i] - mean[i]) / var[i] + return data + + +###################################################### +# Denormalize # +###################################################### +def denormalize(data, norm, dates, year_base = 1979) : + corr_data = np.array([norm[12*(dt.year-year_base) + dt.month-1] for dt in dates]) + mean, var = corr_data[:, 0], corr_data[:, 1] + if len(norm.shape) > 2 : + return denormalize_local(data, mean, var) + else: + return denormalize_global(data, mean, var) + +###################################################### + +def denormalize_local(data, mean, var) : + if len(data.shape) > 3: #ensemble + for i in range( data.shape[0]) : + data[i] = (data[i] * var) + mean + else: + data = (data * var) + mean + return data + +###################################################### + +def denormalize_global(data, mean, var) : + if len(data.shape) > 3: #ensemble + data = data.swapaxes(0,1) + for i in range( data.shape[0]) : + data[i] = ((data[i] * var[i]) + mean[i]) + data = data.swapaxes(0,1) + else: + for i in range( data.shape[0]) : + data[i] = (data[i] * var[i]) + 
mean[i] + + return data \ No newline at end of file diff --git a/atmorep/datasets/normalizer_global.py b/atmorep/datasets/normalizer_global.py deleted file mode 100644 index 2a844ce..0000000 --- a/atmorep/datasets/normalizer_global.py +++ /dev/null @@ -1,41 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import numpy as np - -import atmorep.config.config as config - -class NormalizerGlobal() : - - def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : - - # TODO: use path from config and pathlib.Path() - fname_base = '{}/normalization/{}/global_normalization_mean_var_{}_{}{}.bin' - - fn = field_info[0] - corr_fname = fname_base.format( str(config.path_data), fn, fn, level_type, vlevel) - self.corr_data = np.fromfile(corr_fname, dtype=np.float32).reshape( (-1, 4)) - - def normalize( self, year, month, data, coords = None) : - corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), - self.corr_data[:,1] == float(month))) , 2:].flatten() - return (data - corr_data_ym[0]) / corr_data_ym[1] - - def denormalize( self, year, month, data, coords = None) : - corr_data_ym = self.corr_data[ np.where(np.logical_and(self.corr_data[:,0] == float(year), - self.corr_data[:,1] == float(month))) , 2:].flatten() - return (data * corr_data_ym[1]) + corr_data_ym[0] - \ No newline at end of file diff --git a/atmorep/datasets/normalizer_local.py b/atmorep/datasets/normalizer_local.py deleted file mode 100644 index ffc8c37..0000000 --- a/atmorep/datasets/normalizer_local.py +++ /dev/null @@ -1,68 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import code -import numpy as np -import xarray as xr - -import atmorep.config.config as config - -class NormalizerLocal() : - - def __init__(self, field_info, vlevel, file_shape, data_type = 'era5', level_type = 'ml') : - - fname_base = '{}/normalization/{}/normalization_mean_var_{}_y{}_m{:02d}_{}{}.bin' - - self.corr_data = [ ] - for year in range( config.year_base, config.year_last+1) : - for month in range( 1, 12+1) : - corr_fname = fname_base.format( str(config.path_data), field_info[0], field_info[0], - year, month, level_type, vlevel) - x = np.fromfile( corr_fname, dtype=np.float32).reshape( (file_shape[1], file_shape[2], 2)) - x = xr.DataArray( x, [ ('lat', np.linspace( 0., 180., num=180*4+1, endpoint=True)), - ('lon', np.linspace( 0., 360., num=360*4, endpoint=False)), - ('data', ['mean', 'var']) ]) - self.corr_data.append( x) - - def normalize( self, year, month, data, coords) : - - corr_data_ym = self.corr_data[ (year - config.year_base) * 12 + month ] - mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values - var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values - - if 
len(data.shape) > 2 : - for i in range( data.shape[0]) : - data[i] = (data[i] - mean) / var - else : - data = (data - mean) / var - - return data - - def denormalize( self, year, month, data, coords) : - - corr_data_ym = self.corr_data[ (year - config.year_base) * 12 + month ] - mean = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='mean').values - var = corr_data_ym.sel( lat=coords[0], lon=coords[1], data='var').values - - if len(data.shape) > 2 : - for i in range( data.shape[0]) : - data[i] = (data[i] * var) + mean - else : - data = (data * var) + mean - - return data - - \ No newline at end of file diff --git a/atmorep/datasets/static_field.py b/atmorep/datasets/static_field.py deleted file mode 100644 index f636757..0000000 --- a/atmorep/datasets/static_field.py +++ /dev/null @@ -1,223 +0,0 @@ -#################################################################################################### -# -# Copyright (C) 2022 -# -#################################################################################################### -# -# project : atmorep -# -# author : atmorep collaboration -# -# description : -# -# license : -# -#################################################################################################### - -import torch -import pathlib -import numpy as np -import math -import os, sys -import time -import itertools - -import atmorep.utils.utils as utils -from atmorep.utils.utils import shape_to_str -from atmorep.utils.utils import days_until_month_in_year -from atmorep.utils.utils import days_in_month -from atmorep.utils.utils import tokenize - -from atmorep.datasets.data_loader import DataLoader - - -class StaticField() : - - ################################################### - def __init__( self, file_path, field_info, batch_size, data_type = 'reanalysis', - file_shape = (-1, 720, 1440), file_geo_range = [[90.,-90.], [0.,360.]], - num_tokens = [3, 9, 9], token_size = [1, 9, 9], - smoothing = 0, file_format = 'grib', corr_type = 'global') : - ''' - Data set for single dynamic field at a single vertical level - ''' - - self.field_info = field_info - self.file_path = file_path - self.file_shape = file_shape - self.file_format = file_format - self.smoothing = smoothing - self.corr_type = corr_type - - # # work internally with mathematical latitude coordinates in [0,180] - # self.is_global = np.abs(file_geo_range[0][0])==90. and file_geo_range[1][0]==0. \ - # and np.abs(file_geo_range[0][0])==90. and file_geo_range[1][1]==360. - # self.file_geo_range = [ -np.array(file_geo_range[0]) + 90. , file_geo_range[1] ] - # self.file_geo_range[0] = np.flip( self.file_geo_range[0]) \ - # if self.file_geo_range[0][0] > self.file_geo_range[0][1] else self.file_geo_range[0] - - # work internally with mathematical latitude coordinates in [0,180] - self.file_geo_range = [ -np.array(file_geo_range[0]) + 90. , np.array(file_geo_range[1]) ] - # enforce that georange is North to South - self.geo_range_flipped = False - if self.file_geo_range[0][0] > self.file_geo_range[0][1] : - self.file_geo_range[0] = np.flip( self.file_geo_range[0]) - self.geo_range_flipped = True - print( 'Flipped georange') - print( '{} :: geo_range : {}'.format( field_info[0], self.file_geo_range) ) - self.is_global = 0. == self.file_geo_range[0][0] and 0. == self.file_geo_range[1][0] \ - and 180. == self.file_geo_range[0][1] and 360. 
== self.file_geo_range[1][1] - print( '{} :: is_global : {}'.format( field_info[0], self.is_global) ) - - self.batch_size = batch_size - self.num_tokens = torch.tensor( num_tokens, dtype=torch.int) - rem1 = (num_tokens[1]*token_size[1]) % 2 - rem2 = (num_tokens[2]*token_size[2]) % 2 - t1 = num_tokens[1]*token_size[1] - t2 = num_tokens[2]*token_size[2] - self.grid_delta = [ [int((t1+rem1)/2), int(t1/2)], [int((t2+rem2)/2), int(t2/2)] ] - assert( num_tokens[1] < file_shape[1]) - assert( num_tokens[2] < file_shape[2]) - self.tok_size = token_size - #assert( file_shape[1] % token_size[1] == 0) - #assert( file_shape[2] % token_size[2] == 0) - - # resolution - # TODO: non-uniform resolution in latitude and longitude - self.res = (file_geo_range[1][1] - file_geo_range[1][0]) - self.res /= file_shape[2] if self.is_global else (file_shape[2]-1) - - self.data_field = None - - self.loader = DataLoader( self.file_path, self.file_shape, data_type, - file_format = self.file_format, - smoothing = self.smoothing ) - - ################################################### - def load_data( self, years_months, idxs_perm, batch_size = None) : - - self.idxs_perm = idxs_perm - loader = self.loader - - if batch_size : - self.batch_size = batch_size - - # load data - self.data_field = loader.get_static_field( self.field_info[0], [-1, -1]) - - # # corrections: - self.correction_field = loader.get_correction_static_field( self.field_info[0], self.corr_type ) - - mean = self.correction_field[0] - std = self.correction_field[1] - - self.data_field = (self.data_field - mean) / std - - if self.geo_range_flipped : - self.data_field = torch.flip( self.data_field, [0]) - - # # basics statistics - # print( 'INFO:: data stats {} : {} / {}'.format( self.field_info[0], - # self.data_field.mean(), - # self.data_field.std()) ) - - ################################################### - def set_data( self, date_pos ) : - ''' - date_pos = np.array( [ [year, month, day, hour, lat, lon], ...] ) - - lat \in [-90,90] = [90N, 90S] - - (year,month) pairs should be a limited number since all data for these is loaded - ''' - - # extract required years and months - years_months_all = np.array( [ [it[0], it[1]] for it in date_pos ], dtype=np.int64) - self.years_months = list( zip( np.unique(years_months_all[:,0]), - np.unique( years_months_all[:,1] ))) - - # load data and corrections - self.load_data() - - # generate all the data - self.idxs_perm = np.zeros( (date_pos.shape[0], 4), dtype=np.int64) - for idx, item in enumerate( date_pos) : - - assert item[2] >= 1 and item[2] <= 31 - assert item[3] >= 0 and item[3] < int(24 / self.time_sampling) - assert item[4] >= -90. and item[4] <= 90. - - # find year - for i_ym, ym in enumerate( self.years_months) : - if ym[0] == item[0] and ym[1] == item[1] : - break - - it = (item[2] - 1.) * 24. + item[3] + self.tok_size[0] - idx_lat = int( (item[4] + 90.) * 720. / 180.) - idx_lon = int( (item[5] % 360) * 1440. / 360.) 
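# Illustrative sketch (not part of the patch): the removed StaticField.set_data above maps
# geographic coordinates onto a 0.25-degree, 720 x 1440 grid with integer index arithmetic.
# A small worked check with example coordinates (48.0 N, 352.5 E):
lat, lon = 48.0, 352.5
idx_lat = int((lat + 90.) * 720. / 180.)      # (48 + 90) * 4 = 552
idx_lon = int((lon % 360) * 1440. / 360.)     # 352.5 * 4     = 1410
assert (idx_lat, idx_lon) == (552, 1410)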
- - self.idxs_perm[idx] = np.array( [i_ym, it, idx_lat, idx_lon], dtype=np.int64) - - ############################################### - def __getitem__( self, bidx) : - - tn = self.grid_delta - num_tokens = self.num_tokens - tok_size = self.tok_size - geor = self.file_geo_range - - idx = bidx * self.batch_size - - # physical fields - patch_s = [nt*ts for nt,ts in zip(self.num_tokens,self.tok_size)] - x = torch.zeros( self.batch_size, 1, patch_s[1], patch_s[2] ) - cids = torch.zeros( self.batch_size, num_tokens.prod(), 8) - - # 721 etc have grid points at the beginning and end which leads to incorrect results in places - file_shape = np.array(self.file_shape) - file_shape = file_shape-1 if not self.is_global else np.array(self.file_shape)-np.array([0,1,0]) - - # for all items in batch - for jj in range( self.batch_size) : - - # perform a deep copy to not overwrite cid for other fields - cid = np.array( self.idxs_perm[idx][1:]).copy() - - # map to grid coordinates (first map to normalized [0,1] coords and then to grid coords) - cid[2] = np.mod( cid[2], 360.) if self.is_global else cid[2] - assert cid[1] >= geor[0][0] and cid[1] <= geor[0][1], 'invalid latitude for geo_range' - cid[1] = ( (cid[1] - geor[0][0]) / (geor[0][1] - geor[0][0]) ) * file_shape[1] - cid[2] = ( ((cid[2]) - geor[1][0]) / (geor[1][1] - geor[1][0]) ) * file_shape[2] - assert cid[1] >= 0 and cid[1] < self.file_shape[1] - assert cid[2] >= 0 and cid[2] < self.file_shape[2] - - # alignment when parent field has different resolution than this field - cid = np.round( cid).astype( np.int64) - - # periodic boundary conditions around equator - ran_lon = np.array( list( range( cid[2]-tn[1][0], cid[2]+tn[1][1]))) - if self.is_global : - ran_lon = np.mod( ran_lon, self.file_shape[2]) - else : - # sanity check for indices for files with local window - # this should be controlled by georange_sampling for sampling - assert any( ran_lon >= 0) or any( ran_lon < self.file_shape[2]) - - ran_lat = np.array( list( range( cid[1]-tn[0][0], cid[1]+tn[0][1]))) - assert any( ran_lat >= 0) or any( ran_lat < self.file_shape[1]) - - # current data - x[jj,0] = np.take( np.take( self.data_field, ran_lat, 0), ran_lon, 1) - - # set per token information - lats = ran_lat[int(tok_size[1]/2)::tok_size[1]] * self.res + self.file_geo_range[0][0] - lons = ran_lon[int(tok_size[2]/2)::tok_size[2]] * self.res + self.file_geo_range[1][0] - stencil = torch.tensor(list(itertools.product(lats,lons))) - cids[jj,:,4:6] = stencil - cids[jj,:,7] = self.res - - idx += 1 - - return (x, cids) - - ################################################### - def __len__(self): - return int(self.idxs_perm.shape[0] / self.batch_size) diff --git a/atmorep/tests/__init__.py b/atmorep/tests/__init__.py new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/atmorep/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/atmorep/tests/conftest.py b/atmorep/tests/conftest.py new file mode 100644 index 0000000..62daba8 --- /dev/null +++ b/atmorep/tests/conftest.py @@ -0,0 +1,8 @@ +def pytest_addoption(parser): + parser.addoption("--field", action="store", help="field to run the test on") + parser.addoption("--model_id", action="store", help="wandb ID of the atmorep model") + parser.addoption("--epoch", action="store", help="field to run the test on", default = "0") + parser.addoption("--strategy", action="store", help="BERT or forecast") + + + diff --git a/atmorep/tests/test_utils.py b/atmorep/tests/test_utils.py new file mode 100644 index 0000000..60176ee --- /dev/null +++ 
b/atmorep/tests/test_utils.py
@@ -0,0 +1,77 @@
+import numpy as np
+import pandas as pd
+
+def era5_fname():
+    return "/gpfs/scratch/ehpc03/data/{}/ml{}/era5_{}_y{}_m{}_ml{}.grib"
+
+def atmorep_pred():
+    return "./results/id{}/results_id{}_epoch{}_pred.zarr"
+
+def atmorep_target():
+    return "./results/id{}/results_id{}_epoch{}_target.zarr"
+
+def grib_index(field):
+    grib_idxs = {"velocity_u": "u",
+                 "temperature": "t",
+                 "total_precip": "tp",
+                 "velocity_v": "v",
+                 "velocity_z": "z",
+                 "vorticity" : "vo",
+                 "divergence" : "d",
+                 "specific_humidity": "q"}
+
+    return grib_idxs[field]
+
+##################################################################
+
+def get_BERT(atmorep, field, sample, level):
+    atmorep_sample = atmorep[f"{field}/sample={sample:05d}/ml={level:05d}"]
+    data = atmorep_sample.data[0,0]
+    datetime = pd.Timestamp(atmorep_sample.datetime[0,0])
+    lats = atmorep_sample.lat[0]
+    lons = atmorep_sample.lon[0]
+    return data, datetime, lats, lons
+
+def get_forecast(atmorep, field, sample, level_idx):
+    atmorep_sample = atmorep[f"{field}/sample={sample:05d}"]
+    data = atmorep_sample.data[level_idx, 0]
+    datetime = pd.Timestamp(atmorep_sample.datetime[0])
+    lats = atmorep_sample.lat
+    lons = atmorep_sample.lon
+    return data, datetime, lats, lons
+
+######################################
+
+def check_lats(lats_pred, lats_target):
+    assert (lats_pred[:] == lats_target[:]).all(), "Mismatch between latitudes"
+    assert (lats_pred[:] <= 90.).all(), f"latitudes are between {np.amin(lats_pred)} - {np.amax(lats_pred)}"
+    assert (lats_pred[:] >= -90.).all(), f"latitudes are between {np.amin(lats_pred)} - {np.amax(lats_pred)}"
+
+def check_lons(lons_pred, lons_target):
+    assert (lons_pred[:] == lons_target[:]).all(), "Mismatch between longitudes"
+    assert (lons_pred[:] >= 0.).all(), f"longitudes are between {np.amin(lons_pred)} - {np.amax(lons_pred)}"
+    assert (lons_pred[:] <= 360.).all(), f"longitudes are between {np.amin(lons_pred)} - {np.amax(lons_pred)}"
+
+def check_datetimes(datetimes_pred, datetimes_target):
+    assert (datetimes_pred == datetimes_target), "Mismatch between datetimes"
+
+######################################
+
+# calculate RMSE
+def compute_RMSE(pred, target):
+    return np.sqrt(np.mean((pred-target)**2))
+
+
+def get_max_RMSE(field):
+    #TODO: optimize thresholds
+    values = {"temperature" : 3,
+              "velocity_u" : 0.2, #????
+              "velocity_v": 0.2, #????
+              "velocity_z": 0.2, #????
+              "vorticity" : 0.2, #????
+              "divergence": 0.2, #????
+              "specific_humidity": 0.2, #????
+              "total_precip": 1, #?????
+             }
+
+    return values[field]
\ No newline at end of file
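# Illustrative sketch (not part of the patch): a minimal usage of the helpers above. The model id
# "abcd1234", epoch 0, field "temperature", sample 0 and model level 137 are purely hypothetical
# example values; any results store written in the layout expected by atmorep_target()/get_BERT()
# should work the same way.
import zarr
from atmorep.tests.test_utils import atmorep_target, atmorep_pred, get_BERT, compute_RMSE

model_id, epoch = "abcd1234", 0
target = zarr.group(zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))))
pred = zarr.group(zarr.ZipStore(atmorep_pred().format(model_id, model_id, str(epoch).zfill(5))))

data_t, when, lats, lons = get_BERT(target, "temperature", sample=0, level=137)
data_p, _, _, _ = get_BERT(pred, "temperature", sample=0, level=137)
print(when, lats.shape, lons.shape, compute_RMSE(data_p, data_t))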
diff --git a/atmorep/tests/validation_test.py b/atmorep/tests/validation_test.py
new file mode 100644
index 0000000..4dbea9a
--- /dev/null
+++ b/atmorep/tests/validation_test.py
@@ -0,0 +1,126 @@
+import pytest
+import zarr
+import cfgrib
+import xarray as xr
+import numpy as np
+import random as rnd
+import warnings
+import os
+
+from atmorep.tests.test_utils import *
+
+# run it with e.g. pytest -s atmorep/tests/validation_test.py --field temperature --model_id ztsut0mr --strategy BERT
+
+@pytest.fixture
+def field(request):
+    return request.config.getoption("field")
+
+@pytest.fixture
+def model_id(request):
+    return request.config.getoption("model_id")
+
+@pytest.fixture
+def epoch(request):
+    return request.config.getoption("epoch")
+
+@pytest.fixture(autouse = True)
+def BERT(request):
+    strategy = request.config.getoption("strategy")
+    return (strategy == 'BERT' or strategy == 'temporal_interpolation')
+
+@pytest.fixture(autouse = True)
+def strategy(request):
+    return request.config.getoption("strategy")
+
+#TODO: add test for global_forecast vs ERA5
+
+def test_datetime(field, model_id, BERT, epoch = 0):
+
+    """
+    Check against ERA5 timestamps.
+    Loop over all levels individually. 50 random samples for each level.
+    """
+
+    store = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5)))
+    atmorep = zarr.group(store)
+
+    nsamples = min(len(atmorep[field]), 50)
+    samples = rnd.sample(range(len(atmorep[field])), nsamples)
+    levels = [int(f.split("=")[1]) for f in atmorep[f"{field}/sample=00000"]] if BERT else atmorep[f"{field}/sample=00000"].ml[:]
+
+    get_data = get_BERT if BERT else get_forecast
+
+    for level in levels:
+        #TODO: make it more elegant
+        level_idx = level if BERT else np.where(levels == level)[0].tolist()[0]
+
+        for s in samples:
+            data, datetime, lats, lons = get_data(atmorep, field, s, level_idx)
+            year, month = datetime.year, str(datetime.month).zfill(2)
+
+            era5_path = era5_fname().format(field, level, field, year, month, level)
+            if not os.path.isfile(era5_path):
+                warnings.warn(UserWarning(f"Timestamp {datetime} not found in ERA5. Skipping"))
+                continue
+            era5 = xr.open_dataset(era5_path, engine = "cfgrib")[grib_index(field)].sel(time = datetime, latitude = lats, longitude = lons)
+
+            #assert (data[0] == era5.values[0]).all(), "Mismatch between ERA5 and AtmoRep Timestamps"
+            assert np.isclose(data[0], era5.values[0], rtol=1e-04, atol=1e-07).all(), "Mismatch between ERA5 and AtmoRep Timestamps"
+
+#############################################################################
+
+def test_coordinates(field, model_id, BERT, epoch = 0):
+    """
+    Check that coordinates match between target and prediction.
+    Check also that latitudes and longitudes are in geographical coordinates.
+    50 random samples.
+    """
+
+    store_t = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5)))
+    target = zarr.group(store_t)
+
+    store_p = zarr.ZipStore(atmorep_pred().format(model_id, model_id, str(epoch).zfill(5)))
+    pred = zarr.group(store_p)
+
+    nsamples = min(len(target[field]), 50)
+    samples = rnd.sample(range(len(target[field])), nsamples)
+    levels = [int(f.split("=")[1]) for f in target[f"{field}/sample=00000"]] if BERT else target[f"{field}/sample=00000"].ml[:]
+
+    get_data = get_BERT if BERT else get_forecast
+
+    for level in levels:
+        level_idx = level if BERT else np.where(levels == level)[0].tolist()[0]
+        for s in samples:
+            _, datetime_target, lats_target, lons_target = get_data(target, field, s, level_idx)
+            _, datetime_pred, lats_pred, lons_pred = get_data(pred, field, s, level_idx)
+
+            check_lats(lats_pred, lats_target)
+            check_lons(lons_pred, lons_target)
+            check_datetimes(datetime_pred, datetime_target)
+
+#########################################################################
+
+def test_rmse(field, model_id, BERT, epoch = 0):
+    """
+    Test that for each field the RMSE does not exceed a certain value.
+ 50 random samples. + """ + store_t = zarr.ZipStore(atmorep_target().format(model_id, model_id, str(epoch).zfill(5))) + target = zarr.group(store_t) + + store_p = zarr.ZipStore(atmorep_pred().format(model_id, model_id, str(epoch).zfill(5))) + pred = zarr.group(store_p) + + nsamples = min(len(target[field]), 50) + samples = rnd.sample(range(len(target[field])), nsamples) + levels = [int(f.split("=")[1]) for f in target[f"{field}/sample=00000"]] if BERT else target[f"{field}/sample=00000"].ml[:] + + get_data = get_BERT if BERT else get_forecast + + for level in levels: + level_idx = level if BERT else np.where(levels == level)[0].tolist()[0] + for s in samples: + sample_target, _, _, _ = get_data(target,field, s, level_idx) + sample_pred, _, _, _ = get_data(pred,field, s, level_idx) + + assert compute_RMSE(sample_target, sample_pred).mean() < get_max_RMSE(field) diff --git a/atmorep/training/bert.py b/atmorep/training/bert.py index 51ccdd2..a4d1ea1 100644 --- a/atmorep/training/bert.py +++ b/atmorep/training/bert.py @@ -19,12 +19,9 @@ from functools import partial import code -from atmorep.utils.utils import tokenize - #################################################################################################### -def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data) : +def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data, fields_infos) : - fields_tokens_masked_idx = [[] for _ in fields_data] fields_tokens_masked_idx_list = [[] for _ in fields_data] fields_targets = [[] for _ in fields_data] sources = [[] for _ in fields_data] @@ -35,66 +32,27 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data) if BERT_strategy == 'BERT' : bert_f = prepare_batch_BERT_field + elif BERT_strategy == 'global_forecast' : + bert_f = prepare_batch_BERT_forecast_field elif BERT_strategy == 'forecast' : bert_f = prepare_batch_BERT_forecast_field elif BERT_strategy == 'temporal_interpolation' : bert_f = prepare_batch_BERT_temporal_field - elif BERT_strategy == 'forecast_1shot' : - bert_f = prepare_batch_BERT_forecast_field_1shot - elif BERT_strategy == 'identity' : - bert_f = prepare_batch_BERT_identity_field - elif BERT_strategy == 'totalmask' : - bert_f = prepare_batch_BERT_totalmask_field else : assert False - # # advance randomly to avoid issues with parallel data loaders that naively duplicate rngs - # delta = torch.randint( 0, 1000, (1,)).item() - # [rng.bit_generator.advance( delta) for rng in rngs] - - if cf.BERT_window : - # window size has to be multiple of two due to the variable token sizes (the size is - # however currently restricted to differ by exactly a factor of two only) - size_t = int(rngs[0].integers( 2, fields[0][3][0]+1, 1)[0] / 2.) * 2 - size_lat = int(rngs[0].integers( 2, fields[0][3][1]+1, 1)[0] / 2.) * 2 - size_lon = int(rngs[0].integers( 2, fields[0][3][2]+1, 1)[0] / 2.) 
* 2 - rng_idx = 1 - for ifield, data_field in enumerate(fields_data) : - for ilevel, (field_data, token_info) in enumerate(data_field) : - - tok_size = fields[ifield][4] - field_data = tokenize( field_data, tok_size ) - field_data_shape = field_data.shape - - # cut neighborhood for current batch - if cf.BERT_window : - # adjust size based on token size so that one has a fixed size window in physical space - cur_size_t = int(size_t * fields[ifield][3][0] / fields[0][3][0]) - cur_size_lat = int(size_lat * fields[ifield][3][1] / fields[0][3][1]) - cur_size_lon = int(size_lon * fields[ifield][3][2] / fields[0][3][2]) - # define indices - idx_t_s = field_data.shape[1] - cur_size_t - idx_lat_s = field_data.shape[2] - cur_size_lat - idx_lon_s = field_data.shape[3] - cur_size_lon - # cut - field_data = field_data[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] - field_data = field_data.contiguous() - # for token info first recover space-time shape - token_info = token_info.reshape( list(field_data_shape[0:4]) + [token_info.shape[-1]]) - token_info = token_info[ :, idx_t_s:, idx_lat_s:, idx_lon_s:] - token_info = torch.flatten( token_info, 1, -2) - token_info = token_info.contiguous() - + for ifield, (field, infos) in enumerate(zip(fields_data, fields_infos)) : + for ilevel, (field_data, token_info) in enumerate(zip(field, infos)) : + # no masking for static fields or if masking rate = 0 if fields[ifield][1][0] > 0 and fields[ifield][5][0] > 0. : ret = bert_f( cf, ifield, field_data, token_info, rngs[rng_idx]) - (field_data, token_info, target, tokens_masked_idx, tokens_masked_idx_list) = ret + (field_data, token_info, target, tokens_masked_idx_list) = ret - if target != None : + if target is not None : fields_targets[ifield].append( target) - fields_tokens_masked_idx[ifield].append( tokens_masked_idx) fields_tokens_masked_idx_list[ifield].append( tokens_masked_idx_list) rng_idx += 1 @@ -109,8 +67,7 @@ def prepare_batch_BERT_multifield( cf, rngs, fields, BERT_strategy, fields_data) fields_targets[ifield] = torch.cat( fields_targets[ifield],0) \ if len(fields_targets[ifield]) > 0 else fields_targets[ifield] - return (sources, token_infos, fields_targets, fields_tokens_masked_idx, - fields_tokens_masked_idx_list) + return (sources, token_infos, fields_targets, fields_tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : @@ -126,7 +83,7 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) - + # select random token in the selected space-time cube to be masked/deleted BERT_frac = cf.fields[ifield][5][0] BERT_frac_mask = cf.fields[ifield][5][1] @@ -136,15 +93,15 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : token_size = cf.fields[ifield][4] batch_dim = source.shape[0] num_tokens = source.shape[1] - # + masking_ratios = rng.random( batch_dim) * BERT_frac # number of tokens masked per batch entry nums_masked = np.ceil( num_tokens * masking_ratios).astype(np.int64) tokens_masked_idx_list = [ torch.tensor(rng.permutation(num_tokens)[:nms]) for nms in nums_masked] # linear indices for masking - idx = torch.cat( [tokens_masked_idx_list[i] + num_tokens * i for i in range(batch_dim)] ) - tokens_masked_idx = idx + tokens_masked_idx_list = [tokens_masked_idx_list[i] + num_tokens * i for i in range(batch_dim)] + idx = torch.cat( 
tokens_masked_idx_list) # flatten along first two dimension to simplify linear indexing (which then requires an # easily computable row offset) @@ -153,7 +110,7 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # keep masked tokens for loss computation target = source[idx].clone() - + # climatological mean of normalized data global_mean = 0. * torch.mean(source, 0) global_std = torch.std(source, 0) @@ -192,16 +149,11 @@ def prepare_batch_BERT_field( cf, ifield, source, token_info, rng) : # unsqueeze(usq()) is required since channel dimension is expected temp = mr( mr( usq( source[ idx[idx_mr_cond] ].reshape( (-1,ts[0],ts[1],ts[2])), 1), mrs), ts) source[ idx[idx_mr_cond] ] = sq( fl( temp, -3, -1)) - # adjust resolution parameter in token_info - token_info_shape = token_info.shape - token_info = token_info.flatten( 0, 1) - token_info[ idx[idx_mr_cond] ][-1] *= (mrs[1] + mrs[2]) / 2. #TODO: anisotropic resolution - token_info = token_info.reshape( token_info_shape) # recover batch dimension which was flattend for easier indexing and also token dimensions source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx, tokens_masked_idx_list) + return (source, token_info, target, tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : @@ -210,15 +162,15 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : num_tokens = source.shape[-6:-3] num_tokens_space = num_tokens[1] * num_tokens[2] idxs = (num_tokens[0]-nt) * num_tokens_space + torch.arange(nt * num_tokens_space) - + # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) # linear indices for masking num_tokens = source.shape[1] - idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) - tokens_masked_idx = idx + tokens_masked_idx_list = [idxs + num_tokens * i for i in range( source.shape[0] )] + idx = torch.cat( tokens_masked_idx_list) source_shape = source.shape # flatten along first two dimension to simplify linear indexing (which then requires an @@ -235,84 +187,21 @@ def prepare_batch_BERT_forecast_field( cf, ifield, source, token_info, rng) : # recover batch dimension which was flattend for easier indexing source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx, idxs) + return (source, token_info, target, tokens_masked_idx_list) #################################################################################################### def prepare_batch_BERT_temporal_field( cf, ifield, source, token_info, rng) : num_tokens = source.shape[-6:-3] - num_tokens_space = num_tokens[1] * num_tokens[2] - idx_time_mask = int( np.floor(num_tokens[0] / 2.)) # TODO: masking of multiple time steps - idxs = idx_time_mask * num_tokens_space + torch.arange(num_tokens_space) - - # collapse token dimensions - source_shape0 = source.shape - source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) - - # linear indices for masking - num_tokens = source.shape[1] - idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) - tokens_masked_idx = idx - - source_shape = source.shape - # flatten along first two dimension to simplify linear indexing (which then requires an - # easily computable row offset) - source = 
torch.flatten( source, 0, 1) - - # keep masked tokens for loss computation - target = source[idx].clone() - - # masking - global_mean = 0. * torch.mean(source, 0) - source[ idx ] = global_mean - - # recover batch dimension which was flattend for easier indexing - source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - - return (source, token_info, target, tokens_masked_idx, idxs) - -#################################################################################################### -def prepare_batch_BERT_forecast_field_1shot( cf, ifield, source, token_info, rng) : - - nt = 1 # TODO: specify this in config - num_tokens = source.shape[-6:-3] - num_tokens_space = num_tokens[1] * num_tokens[2] - idxs = (num_tokens[0]-nt) * num_tokens_space + torch.arange(num_tokens_space) - - # collapse token dimensions - source_shape0 = source.shape - source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) - - # linear indices for masking - num_tokens = source.shape[1] - # mask only every second neighborhood: 1 shot setting - idx = torch.cat( [idxs + num_tokens * i for i in range( 1, source.shape[0], 2 )] ) - tokens_masked_idx = idx - - source_shape = source.shape - # flatten along first two dimension to simplify linear indexing (which then requires an - # easily computable row offset) - source = torch.flatten( source, 0, 1) - - # keep masked tokens for loss computation - target = source[idx].clone() - - # masking - global_mean = 0. * torch.mean(source, 0) - source[ idx ] = global_mean - - # recover batch dimension which was flattend for easier indexing - source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - - return (source, token_info, target, tokens_masked_idx, idxs) - -#################################################################################################### -def prepare_batch_BERT_totalmask_field( cf, ifield, source, token_info, rng) : - - num_tokens = source.shape[-6:-3] - num_tokens_space = num_tokens[1] * num_tokens[2] - idxs = torch.arange(num_tokens[0] * num_tokens_space) - + num_tokens_space = num_tokens[1] * num_tokens[2] + + #backward compatibility: mask only middle token + if not hasattr( cf, 'idx_time_mask'): + idx_time_mask = int( np.floor(num_tokens[0] / 2.)) + idxs = idx_time_mask * num_tokens_space + torch.arange(num_tokens_space) + else: #list of idx_time_mask + idxs = torch.concat([i*num_tokens_space + torch.arange(num_tokens_space) for i in cf.idx_time_mask]) + # collapse token dimensions source_shape0 = source.shape source = torch.flatten( torch.flatten( source, 1, 3), 2, 4) @@ -320,8 +209,7 @@ def prepare_batch_BERT_totalmask_field( cf, ifield, source, token_info, rng) : # linear indices for masking num_tokens = source.shape[1] idx = torch.cat( [idxs + num_tokens * i for i in range( source.shape[0] )] ) - tokens_masked_idx = idx - + tokens_masked_idx_list = [idxs + num_tokens * i for i in range( source.shape[0] )] source_shape = source.shape # flatten along first two dimension to simplify linear indexing (which then requires an # easily computable row offset) @@ -337,9 +225,4 @@ def prepare_batch_BERT_totalmask_field( cf, ifield, source, token_info, rng) : # recover batch dimension which was flattend for easier indexing source = torch.reshape( torch.reshape( source, source_shape), source_shape0) - return (source, token_info, target, tokens_masked_idx) - -#################################################################################################### -def prepare_batch_BERT_identity_field( cf, ifield, source, token_info, 
rng) : - - return (source, token_info, None, None, None) + return (source, token_info, target, tokens_masked_idx_list) diff --git a/atmorep/transformer/mlp.py b/atmorep/transformer/mlp.py new file mode 100644 index 0000000..5908650 --- /dev/null +++ b/atmorep/transformer/mlp.py @@ -0,0 +1,67 @@ + +import torch + +from atmorep.utils.utils import identity +from atmorep.transformer.transformer_base import checkpoint_wrapper + +#################################################################################################### +class MLP(torch.nn.Module): + + def __init__(self, dim_embed, num_layers = 2, with_lnorm = True, dim_embed_out = None, + nonlin = torch.nn.GELU(), dim_internal_factor = 2, dropout_rate = 0., + grad_checkpointing = False, with_residual = True) : + """ + Multi-layer perceptron + + dim_embed : embedding dimension + num_layers : number of layers + nonlin : nonlinearity + dim_internal_factor : factor for number of hidden dimension relative to input / output + """ + super(MLP, self).__init__() + + if not dim_embed_out : + dim_embed_out = dim_embed + + self.with_residual = with_residual + + dim_internal = int( dim_embed * dim_internal_factor) + if with_lnorm : + self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) + else : + self.lnorm = torch.nn.Identity() + + self.blocks = torch.nn.ModuleList() + self.blocks.append( torch.nn.Linear( dim_embed, dim_internal)) + self.blocks.append( nonlin) + self.blocks.append( torch.nn.Dropout( p = dropout_rate)) + + for _ in range( num_layers-2) : + self.blocks.append( torch.nn.Linear( dim_internal, dim_internal)) + self.blocks.append( nonlin) + self.blocks.append( torch.nn.Dropout( p = dropout_rate)) + + self.blocks.append( torch.nn.Linear( dim_internal, dim_embed_out)) + self.blocks.append( nonlin) + + if dim_embed == dim_embed_out : + self.proj_residual = torch.nn.Identity() + else : + self.proj_residual = torch.nn.Linear( dim_embed, dim_embed_out) + + self.checkpoint = identity + if grad_checkpointing : + self.checkpoint = checkpoint_wrapper + + def forward( self, x, y = None) : + + x_in = x + x = self.lnorm( x) + + for block in self.blocks: + x = self.checkpoint( block, x) + + if self.with_residual : + x += x_in + + return x \ No newline at end of file diff --git a/atmorep/transformer/transformer.py b/atmorep/transformer/transformer.py index 5793ae7..c3f5ee9 100644 --- a/atmorep/transformer/transformer.py +++ b/atmorep/transformer/transformer.py @@ -15,12 +15,12 @@ #################################################################################################### import torch -import numpy as np -import math -from atmorep.transformer.transformer_base import MLP, prepare_token +from atmorep.transformer.mlp import MLP +from atmorep.transformer.transformer_base import prepare_token from atmorep.transformer.transformer_attention import MultiSelfAttentionHead + class Transformer(torch.nn.Module) : def __init__(self, num_layers, dim_input, dim_embed = 2048, num_heads = 8, num_mlp_layers = 2, diff --git a/atmorep/transformer/transformer_attention.py b/atmorep/transformer/transformer_attention.py index 2143365..7aa99a3 100644 --- a/atmorep/transformer/transformer_attention.py +++ b/atmorep/transformer/transformer_attention.py @@ -15,12 +15,8 @@ #################################################################################################### import torch -import numpy as np -import math from enum import Enum -import code -from atmorep.transformer.axial_attention import AxialAttention from atmorep.utils.utils import identity 
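# Illustrative sketch (not part of the patch): a minimal, self-contained check of the new
# atmorep/transformer/mlp.py module introduced above. The embedding dimension, token count and
# batch size are arbitrary example values, not values used by AtmoRep itself.
import torch
from atmorep.transformer.mlp import MLP

mlp = MLP(dim_embed=128, num_layers=2, dim_internal_factor=2, dropout_rate=0.1)
x = torch.randn(4, 10, 128)    # (batch, tokens, embedding)
y = mlp(x)                     # residual connection preserves the shape
assert y.shape == x.shape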
from atmorep.transformer.transformer_base import checkpoint_wrapper @@ -31,130 +27,68 @@ class CouplingAttentionMode( Enum) : kv_coupling = 2 -#################################################################################################### - -class AttentionHead(torch.nn.Module): - - def __init__(self, proj_dims, proj_dims_qs = -1, with_qk_lnorm = False, with_attention=False) : - '''Attention head''' - - super(AttentionHead, self).__init__() - - if proj_dims_qs == -1 : - proj_dims_qs = proj_dims[0] - - self.proj_qs = torch.nn.Linear( proj_dims_qs, proj_dims[1], bias = False) - self.proj_ks = torch.nn.Linear( proj_dims[0], proj_dims[1], bias = False) - self.proj_vs = torch.nn.Linear( proj_dims[0], proj_dims[1], bias = False) - - self.softmax = torch.nn.Softmax(dim=-1) - - if with_qk_lnorm : - self.lnorm_qs = torch.nn.LayerNorm( proj_dims[1], elementwise_affine=False) - self.lnorm_ks = torch.nn.LayerNorm( proj_dims[1], elementwise_affine=False) - else : - self.lnorm_qs = torch.nn.Identity() - self.lnorm_ks = torch.nn.Identity() - - self.forward = self.forward_attention if with_attention else self.forward_evaluate - - def attention( self, qs, ks) : - '''Compute attention''' - return torch.matmul( qs, torch.transpose( ks, -2, -1)) - - def forward_evaluate( self, xs_q, xs_k_v = None) : - '''Evaluate attention head''' - - xs_k_v = xs_q if None == xs_k_v else xs_k_v - - out_shape = xs_q.shape - qs = self.lnorm_qs( self.proj_qs( torch.flatten( xs_q, 1, -2) )) - ks = self.lnorm_ks( self.proj_ks( torch.flatten( xs_k_v, 1, -2) )) - vs = self.proj_vs( torch.flatten( xs_k_v, 1, -2) ) - # normalization increases interpretability since otherwise the scaling of the values - # interacts with the attention values - # torch.nn.functional.normalize( vs, dim=-1) - - scaling = 1. / torch.sqrt( torch.tensor(qs.shape[2])) - vsp = torch.matmul( self.softmax( scaling * self.attention( qs, ks)), vs) - return (vsp.reshape( [-1] + list(out_shape[1:-1]) + [vsp.shape[-1]]), None) - - def forward_attention( self, xs_q, xs_k_v = None) : - '''Evaluate attention head and also return attention''' - - xs_k_v = xs_q if None == xs_k_v else xs_k_v - - out_shape = xs_q.shape - kv_shape = xs_k_v.shape - qs = self.lnorm_qs( self.proj_qs( torch.flatten( xs_q, 1, -2) )) - ks = self.lnorm_ks( self.proj_ks( torch.flatten( xs_k_v, 1, -2) )) - vs = self.proj_vs( torch.flatten( xs_k_v, 1, -2) ) - # normalization increases interpretability since otherwise the scaling of the values - # interacts with the attention values - # torch.nn.functional.normalize( vs, dim=-1) - - scaling = 1. 
/ torch.sqrt( torch.tensor(qs.shape[2])) - att = self.attention( qs, ks) - vsp = torch.matmul( self.softmax( scaling * att), vs) - return ( vsp.reshape( [-1] + list(out_shape[1:-1]) + [vsp.shape[-1]]), - att.reshape( [-1] + list(out_shape[1:-1]) + list(kv_shape[1:-1])).detach().cpu() ) - -#################################################################################################### - class MultiSelfAttentionHead(torch.nn.Module): - def __init__(self, dim_embed, num_heads, dropout_rate = 0., att_type = 'dense', - with_qk_lnorm = False, grad_checkpointing = False, with_attention = False ) : + ######################################### + def __init__(self, dim_embed, num_heads, dropout_rate=0., with_qk_lnorm=True, with_flash=True) : super(MultiSelfAttentionHead, self).__init__() + self.num_heads = num_heads + self.with_flash = with_flash + assert 0 == dim_embed % num_heads self.dim_head_proj = int(dim_embed / num_heads) - self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) + self.proj_heads = torch.nn.Linear( dim_embed, num_heads*3*self.dim_head_proj, bias = False) + self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) + self.dropout = torch.nn.Dropout( p=dropout_rate) if dropout_rate > 0. else torch.nn.Identity() - self.heads = torch.nn.ModuleList() - if 'dense' == att_type : - for n in range( num_heads) : - self.heads.append( AttentionHead( [dim_embed, self.dim_head_proj], - with_qk_lnorm= with_qk_lnorm, with_attention=with_attention)) - elif 'axial' in att_type : - self.heads.append( AxialAttention( dim = dim_embed, dim_index = -1, heads = num_heads, - num_dimensions = 3) ) + lnorm = torch.nn.LayerNorm if with_qk_lnorm else torch.nn.Identity + self.ln_q = lnorm( self.dim_head_proj, elementwise_affine=False) + self.ln_k = lnorm( self.dim_head_proj, elementwise_affine=False) + + # with_flash = False + if with_flash : + self.att = torch.nn.functional.scaled_dot_product_attention else : - assert False, 'Unsuppored attention type.' - - # proj_out is done is axial attention head so do not repeat it - self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) \ - if att_type == 'dense' else torch.nn.Identity() - self.dropout = torch.nn.Dropout( p=dropout_rate) - - self.checkpoint = identity - if grad_checkpointing : - self.checkpoint = checkpoint_wrapper + self.att = self.attention + self.softmax = torch.nn.Softmax(dim=-1) - def forward( self, x, y = None) : + ######################################### + def forward( self, x) : + split, tr = torch.tensor_split, torch.transpose + x_in = x x = self.lnorm( x) - outs, atts = [], [] - for head in self.heads : - y, att = self.checkpoint( head, x) - outs.append( y) - atts.append( y) - outs = torch.cat( outs, -1) - - outs = self.dropout( self.checkpoint( self.proj_out, outs) ) + # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention + s = [ *x.shape[:-1], self.num_heads, -1] + qs, ks, vs = split( self.proj_heads( x).reshape(s).transpose( 2, 1), 3, dim=-1) + qs, ks = self.ln_q( qs), self.ln_k( ks) + + # correct ordering of tensors with seq dimension second but last is critical + with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.FLASH_ATTENTION) : + outs = self.att( qs, ks, vs).transpose( 2, 1) + + return x_in + self.dropout( self.proj_out( outs.flatten( -2, -1)) ) - return x_in + outs, atts + ######################################### + def attention( self, q, k, v) : + scaling = 1. 
/ torch.sqrt( torch.tensor(q.shape[-1])) + return torch.matmul( self.softmax( scaling * self.score( q, k)), v) + + ######################################### + def score( self, q, k) : + return torch.matmul( q, torch.transpose( k, -2, -1)) #################################################################################################### class MultiCrossAttentionHead(torch.nn.Module): def __init__(self, dim_embed, num_heads, num_heads_other, dropout_rate = 0., with_qk_lnorm =False, - grad_checkpointing = False, with_attention=False): + grad_checkpointing = False, with_attention=False, with_flash=True): super(MultiCrossAttentionHead, self).__init__() self.num_heads = num_heads @@ -169,22 +103,24 @@ def __init__(self, dim_embed, num_heads, num_heads_other, dropout_rate = 0., wit else : self.lnorm_other = torch.nn.Identity() - # self attention heads - self.heads = torch.nn.ModuleList() - for n in range( num_heads) : - self.heads.append( AttentionHead( [dim_embed, self.dim_head_proj], - with_qk_lnorm = with_qk_lnorm, with_attention=with_attention)) + self.proj_heads = torch.nn.Linear( dim_embed, num_heads*3*self.dim_head_proj, bias = False) + + self.proj_heads_o_q = torch.nn.Linear(dim_embed, num_heads_other*self.dim_head_proj, bias=False) + self.proj_heads_o_kv= torch.nn.Linear(dim_embed,num_heads_other*2*self.dim_head_proj,bias=False) - # cross attention heads - self.heads_other = torch.nn.ModuleList() - for n in range( num_heads_other) : - self.heads_other.append( AttentionHead( [dim_embed, self.dim_head_proj], - with_qk_lnorm = with_qk_lnorm, with_attention=with_attention)) + self.ln_q = torch.nn.LayerNorm( self.dim_head_proj) + self.ln_k = torch.nn.LayerNorm( self.dim_head_proj) # proj_out is done is axial attention head so do not repeat it self.proj_out = torch.nn.Linear( dim_embed, dim_embed, bias = False) self.dropout = torch.nn.Dropout( p=dropout_rate) + if with_flash : + self.att = torch.nn.functional.scaled_dot_product_attention + else : + self.att = self.attention + self.softmax = torch.nn.Softmax(dim=-1) + self.checkpoint = identity if grad_checkpointing : self.checkpoint = checkpoint_wrapper @@ -192,27 +128,31 @@ def __init__(self, dim_embed, num_heads, num_heads_other, dropout_rate = 0., wit def forward( self, x, x_other) : x_in = x - x = self.lnorm( x) - x_other = self.lnorm_other( x_other) - - # output tensor where output of heads is linearly concatenated - outs, atts = [], [] - - # self attention - for head in self.heads : - y, att = self.checkpoint( head, x) - outs.append( y) - atts.append( att) - - # cross attention - for head in self.heads_other : - y, att = self.checkpoint( head, x, x_other) - outs.append( y) - atts.append( att) + x, x_other = self.lnorm( x), self.lnorm_other( x_other) - outs = torch.cat( outs, -1) + # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention + x = x.flatten( 1, -2) + s = [ *x.shape[:-1], self.num_heads, -1] + qs, ks, vs = torch.tensor_split( self.proj_heads( x).reshape(s).transpose( 2, 1), 3, dim=-1) + qs, ks = self.ln_q( qs), self.ln_k( ks) + + s = [ *x.shape[:-1], self.num_heads_other, -1] + qs_o = self.proj_heads_o_q( x).reshape(s).transpose( 2, 1) + x_o = x_other.flatten( 1, -2) + s = [ *x_o.shape[:-1], self.num_heads_other, -1] + ks_o, vs_o = torch.tensor_split( self.proj_heads_o_kv(x_o).reshape(s).transpose( 2, 1),2,dim=-1) + qs_o, ks_o = self.ln_q( qs_o), self.ln_k( ks_o) + + # correct ordering of tensors with seq dimension second but last is critical + with torch.nn.attention.sdpa_kernel( 
torch.nn.attention.SDPBackend.FLASH_ATTENTION) : + s = list(x_in.shape) + s[-1] = -1 + outs_self = self.att( qs, ks, vs).transpose( 2, 1).flatten( -2, -1).reshape(s) + outs_other = self.att( qs_o, ks_o, vs_o).transpose( 2, 1).flatten( -2, -1).reshape(s) + outs = torch.cat( [outs_self, outs_other], -1) + outs = self.dropout( self.checkpoint( self.proj_out, outs) ) - + atts = [] return x_in + outs, atts #################################################################################################### @@ -220,44 +160,57 @@ def forward( self, x, x_other) : class MultiInterAttentionHead(torch.nn.Module): ##################################### - def __init__( self, num_heads_self, num_heads_coupling, dims_embed, with_lnorm = True, - dropout_rate = 0., with_qk_lnorm = False, grad_checkpointing = False, - with_attention=False) : + def __init__( self, num_heads_self, num_fields_other, num_heads_coupling_per_field, dims_embed, + with_lnorm = True, dropout_rate = 0., with_qk_lnorm = False, + grad_checkpointing = False, with_attention=False, with_flash=True) : '''Multi-head attention with multiple interacting fields coupled through attention.''' super(MultiInterAttentionHead, self).__init__() + self.num_heads_self = num_heads_self + self.num_heads_coupling_per_field = num_heads_coupling_per_field self.num_fields = len(dims_embed) - # self.coupling_mode = coupling_mode - # assert 0 == (dims_embed[0] % (num_heads_self + num_heads_coupling)) - # self.dim_head_proj = int(dims_embed[0] / (num_heads_self + num_heads_coupling)) self.dim_head_proj = int(dims_embed[0] / num_heads_self) # layer norms for all fields self.lnorms = torch.nn.ModuleList() + ln = torch.nn.LayerNorm if with_lnorm else torch.nn.Identity for ifield in range( self.num_fields) : - if with_lnorm : - self.lnorms.append( torch.nn.LayerNorm( dims_embed[ifield], elementwise_affine=False)) - else : - self.lnorms.append( torch.nn.Identity()) - - # self attention heads - self.heads_self = torch.nn.ModuleList() - for n in range( num_heads_self) : - self.heads_self.append( AttentionHead( [dims_embed[0], self.dim_head_proj], - dims_embed[0], with_qk_lnorm, with_attention=with_attention )) + self.lnorms.append( ln( dims_embed[ifield], elementwise_affine=False)) + + # self-attention + + nnc = num_fields_other * num_heads_coupling_per_field + self.proj_out = torch.nn.Linear( self.dim_head_proj * (num_heads_self + nnc), + dims_embed[0], bias = False) + self.dropout = torch.nn.Dropout( p=dropout_rate) if dropout_rate > 0. 
else torch.nn.Identity() + + nhs = num_heads_self + self.proj_heads = torch.nn.Linear( dims_embed[0], nhs*3*self.dim_head_proj, bias = False) - # coupling attention heads - self.heads_coupling = torch.nn.ModuleList() - for ifield in range( num_heads_coupling) : - arg1 = [dims_embed[ifield+1], self.dim_head_proj] - self.heads_coupling.append( AttentionHead( arg1, dims_embed[0], with_qk_lnorm, - with_attention=with_attention )) - - # self.proj_out = torch.nn.Linear( dims_embed[0], dims_embed[0], bias = False) - self.proj_out = torch.nn.Linear( self.dim_head_proj * (num_heads_self + num_heads_coupling), dims_embed[0], bias = False) - self.dropout = torch.nn.Dropout( p=dropout_rate) + # cross-attention + + nhc_dim = num_heads_coupling_per_field * self.dim_head_proj + self.proj_heads_other = torch.nn.ModuleList() + # queries from primary source/target field + self.proj_heads_other.append( torch.nn.Linear( dims_embed[0], nhc_dim*num_fields_other, + bias=False)) + # keys, values for other fields + for i in range(num_fields_other) : + self.proj_heads_other.append( torch.nn.Linear( dims_embed[i+1], 2*nhc_dim, bias=False)) + + ln = torch.nn.LayerNorm if with_qk_lnorm else torch.nn.Identity + self.ln_qk = (ln( self.dim_head_proj, elementwise_affine=False), + ln( self.dim_head_proj, elementwise_affine=False)) + nfo = num_fields_other + self.ln_k_other = [ln(self.dim_head_proj,elementwise_affine=False) for _ in range(nfo)] + + if with_flash : + self.att = torch.nn.functional.scaled_dot_product_attention + else : + self.att = self.attention + self.softmax = torch.nn.Softmax(dim=-1) self.checkpoint = identity if grad_checkpointing : @@ -266,32 +219,60 @@ def __init__( self, num_heads_self, num_heads_coupling, dims_embed, with_lnorm = ##################################### def forward( self, *args) : '''Evaluate block''' - - x_in = args[0] + + x_in, atts = args[0], [] # layer norm for each field fields_lnormed = [] for ifield, field in enumerate( args) : fields_lnormed.append( self.lnorms[ifield](field) ) + + # project onto heads and q,k,v and ensure these are 4D tensors as required for flash attention + # collapse three space and time dimensions for dense space-time attention + #proj_heads: torch.Size([16, 3, 128, 2048]) + field_proj = self.proj_heads( fields_lnormed[0].flatten(1,-2)) + s = [ *field_proj.shape[:-1], self.num_heads_self, -1 ] + qs, ks, vs = torch.tensor_split( field_proj.reshape(s).transpose(-3,-2), 3, dim=-1) + #breakpoint() + qs, ks = self.ln_qk[0]( qs), self.ln_qk[1]( ks) + if len(fields_lnormed) > 1 : + + field_proj = self.proj_heads_other[0]( fields_lnormed[0].flatten(1,-2)) + s = [ *field_proj.shape[:-1], len(fields_lnormed)-1, self.num_heads_coupling_per_field, -1 ] + qs_other = field_proj.reshape(s).permute( [-3, 0, -2, 1, -1]) - # output tensor where output of heads is linearly concatenated - outs, atts = [], [] - - # self attention - for head in self.heads_self : - y, att = self.checkpoint( head, fields_lnormed[0], fields_lnormed[0]) - outs.append( y) - atts.append( att) + ofields_projs = [] + for i,f in enumerate(fields_lnormed[1:]) : + f_proj = self.proj_heads_other[i+1](f.flatten(1,-2)) + s = [ *f_proj.shape[:-1], self.num_heads_coupling_per_field, -1 ] + ks_o, vs_o = torch.tensor_split( f_proj.reshape(s).transpose(-3,-2), 2, dim=-1) + ofields_projs += [ (self.ln_k_other[i]( ks_o), vs_o) ] + + # correct ordering of tensors with seq dimension second but last is critical + with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.FLASH_ATTENTION) : + + # self-attention 
+ s = list(fields_lnormed[0].shape) + outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s) - # inter attention - for ifield, head in enumerate(self.heads_coupling) : - y, att = self.checkpoint( head, fields_lnormed[0], fields_lnormed[ifield+1]) - outs.append( y) - atts.append( att) + # cross-attention + if len(fields_lnormed) > 1 : + s[-1] = -1 + outs_other = [self.att( q, k, v).transpose( -3, -2).flatten( -2, -1).reshape(s) + for (q,(k,v)) in zip(qs_other,ofields_projs)] + outs = torch.cat( [outs, *outs_other], -1) - outs = torch.cat( outs, -1) - outs = self.dropout( self.checkpoint( self.proj_out, outs) ) + # code.interact( local=locals()) + outs = self.dropout( self.proj_out( outs)) return x_in + outs, atts - + ######################################### + def attention( self, q, k, v) : + scaling = 1. / torch.sqrt( torch.tensor(q.shape[-1])) + return torch.matmul( self.softmax( scaling * self.score( q, k)), v) + + ######################################### + def score( self, q, k) : + # code.interact( local=locals()) + return torch.matmul( q, torch.transpose( k, -2, -1)) diff --git a/atmorep/transformer/transformer_base.py b/atmorep/transformer/transformer_base.py index 42f9b30..71c59d0 100644 --- a/atmorep/transformer/transformer_base.py +++ b/atmorep/transformer/transformer_base.py @@ -16,10 +16,6 @@ import torch import numpy as np -import math -import code - -from atmorep.utils.utils import identity #################################################################################################### @@ -95,68 +91,6 @@ def positional_encoding_harmonic( x, num_levels, num_tokens, with_cls = False) : return x -#################################################################################################### -class MLP(torch.nn.Module): - - def __init__(self, dim_embed, num_layers = 2, with_lnorm = True, dim_embed_out = None, - nonlin = torch.nn.GELU(), dim_internal_factor = 2, dropout_rate = 0., - grad_checkpointing = False, with_residual = True) : - """ - Multi-layer perceptron - - dim_embed : embedding dimension - num_layers : number of layers - nonlin : nonlinearity - dim_internal_factor : factor for number of hidden dimension relative to input / output - """ - super(MLP, self).__init__() - - if not dim_embed_out : - dim_embed_out = dim_embed - - self.with_residual = with_residual - - dim_internal = int( dim_embed * dim_internal_factor) - if with_lnorm : - self.lnorm = torch.nn.LayerNorm( dim_embed, elementwise_affine=False) - else : - self.lnorm = torch.nn.Identity() - - self.blocks = torch.nn.ModuleList() - self.blocks.append( torch.nn.Linear( dim_embed, dim_internal)) - self.blocks.append( nonlin) - self.blocks.append( torch.nn.Dropout( p = dropout_rate)) - - for _ in range( num_layers-2) : - self.blocks.append( torch.nn.Linear( dim_internal, dim_internal)) - self.blocks.append( nonlin) - self.blocks.append( torch.nn.Dropout( p = dropout_rate)) - - self.blocks.append( torch.nn.Linear( dim_internal, dim_embed_out)) - self.blocks.append( nonlin) - - if dim_embed == dim_embed_out : - self.proj_residual = torch.nn.Identity() - else : - self.proj_residual = torch.nn.Linear( dim_embed, dim_embed_out) - - self.checkpoint = identity - if grad_checkpointing : - self.checkpoint = checkpoint_wrapper - - def forward( self, x, y = None) : - - x_in = x - x = self.lnorm( x) - - for block in self.blocks: - x = self.checkpoint( block, x) - - if self.with_residual : - x += x_in - - return x - 
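Note on the attention refactor in transformer_attention.py above: MultiCrossAttentionHead and MultiInterAttentionHead now feed fused q,k,v projections to torch.nn.functional.scaled_dot_product_attention (pinned to the flash backend via torch.nn.attention.sdpa_kernel), which requires 4-D (batch, heads, seq, head_dim) inputs with the sequence dimension second to last. The following sketch is not part of the patch and uses made-up shapes; it only illustrates that layout and checks the fused call against the manual attention()/score() fallback kept for with_flash=False.

import torch

# assumed toy shapes: (batch, heads, seq, head_dim)
batch, heads, seq, dim_head = 2, 4, 8, 16
q, k, v = ( torch.randn( batch, heads, seq, dim_head) for _ in range(3))

# fused kernel (called here without sdpa_kernel so it also runs on CPU)
out_fused = torch.nn.functional.scaled_dot_product_attention( q, k, v)

# explicit reference computation, equivalent to the attention()/score() fallback
scaling = 1. / torch.sqrt( torch.tensor( float(dim_head)))
scores = scaling * torch.matmul( q, torch.transpose( k, -2, -1))
out_manual = torch.matmul( torch.softmax( scores, dim=-1), v)

assert torch.allclose( out_fused, out_manual, atol=1e-5)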
#################################################################################################### def prepare_token_info( cf, token_info) : @@ -173,7 +107,7 @@ def prepare_token_info( cf, token_info) : return token_info #################################################################################################### -def prepare_token( xin, embed, embed_token_info, with_cls = True) : +def prepare_token( xin, embed, embed_token_info) : (token_seq, token_info) = xin num_tokens = token_seq.shape[-6:-3] @@ -187,18 +121,8 @@ def prepare_token( xin, embed, embed_token_info, with_cls = True) : # token_info = prepare_token_info( cf, token_info) token_info = token_info.reshape([-1] + list(token_seq_embed.shape[1:-1])+[token_info.shape[-1]]) token_seq_embed = torch.cat( [token_seq_embed, token_info], -1) - - # class token - if with_cls : - # initialize to zero (mean of data) - tts = token_seq_embed.shape - cls_token = torch.zeros( (tts[0], 1, tts[2]), device=token_seq_embed.device) # add positional encoding token_seq_embed = positional_encoding_harmonic( token_seq_embed, num_levels, num_tokens) - # add class token after positional encoding - if with_cls : - token_seq_embed = torch.cat( [ cls_token, token_seq_embed ], 1) - return token_seq_embed diff --git a/atmorep/transformer/transformer_decoder.py b/atmorep/transformer/transformer_decoder.py index 11c622b..31720ba 100644 --- a/atmorep/transformer/transformer_decoder.py +++ b/atmorep/transformer/transformer_decoder.py @@ -15,11 +15,8 @@ #################################################################################################### import torch -import torch.utils.checkpoint as checkpoint -import numpy as np -import math -from atmorep.transformer.transformer_base import MLP, prepare_token +from atmorep.transformer.mlp import MLP from atmorep.transformer.transformer_attention import MultiSelfAttentionHead, MultiCrossAttentionHead from atmorep.transformer.axial_attention import MultiFieldAxialAttention from atmorep.utils.utils import identity diff --git a/atmorep/transformer/transformer_encoder.py b/atmorep/transformer/transformer_encoder.py index 36581f9..f88c7f6 100644 --- a/atmorep/transformer/transformer_encoder.py +++ b/atmorep/transformer/transformer_encoder.py @@ -16,9 +16,8 @@ import torch import numpy as np -import math -from atmorep.transformer.transformer_base import MLP +from atmorep.transformer.mlp import MLP from atmorep.transformer.transformer_attention import MultiInterAttentionHead from atmorep.transformer.axial_attention import MultiFieldAxialAttention @@ -57,7 +56,7 @@ def create( self) : self.mlps = torch.nn.ModuleList() for il in range( cf.encoder_num_layers) : - nhc = cf.coupling_num_heads_per_field * len( field_info[1][2]) + nhc = cf.coupling_num_heads_per_field # * len( field_info[1][2]) # nhs = cf.encoder_num_heads - nhc nhs = cf.encoder_num_heads # number of tokens @@ -78,8 +77,9 @@ def create( self) : # attention heads if 'dense' == cf.encoder_att_type : - head = MultiInterAttentionHead( nhs, nhc, dims_embed, with_ln, dor, cf.with_qk_lnorm, - cf.grad_checkpointing, with_attention=cf.attention ) + head = MultiInterAttentionHead( nhs, len(field_info[1][2]), nhc, dims_embed, with_ln, dor, + cf.with_qk_lnorm, cf.grad_checkpointing, + with_attention=cf.attention ) elif 'axial' in cf.encoder_att_type : par = True if 'parallel' in cf.encoder_att_type else False head = MultiFieldAxialAttention( [3,2,1], dims_embed, nhs, nhc, par, dor) diff --git a/atmorep/utils/logger.py b/atmorep/utils/logger.py new file mode 100644 
index 0000000..2ab1811 --- /dev/null +++ b/atmorep/utils/logger.py @@ -0,0 +1,22 @@ + +import logging +import pathlib +import os + +class RelPathFormatter(logging.Formatter): + def __init__(self, fmt, datefmt=None): + super().__init__(fmt, datefmt) + self.root_path = pathlib.Path(__file__).parent.parent.parent.resolve() + + def format(self, record): + # Replace the full pathname with the relative path + record.pathname = os.path.relpath(record.pathname, self.root_path) + return super().format(record) + +logger = logging.getLogger('atmorep') +logger.setLevel(logging.DEBUG) +ch = logging.StreamHandler() +formatter = RelPathFormatter('%(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s') +ch.setFormatter(formatter) +logger.handlers.clear() +logger.addHandler(ch) diff --git a/atmorep/utils/utils.py b/atmorep/utils/utils.py index a91ee0a..866e2d3 100644 --- a/atmorep/utils/utils.py +++ b/atmorep/utils/utils.py @@ -22,7 +22,7 @@ import wandb import code from calendar import monthrange - +#import properscoring as ps import numpy as np import torch.distributed as dist @@ -98,20 +98,22 @@ def write_json( self, wandb) : f.write( json_str) def load_json( self, wandb_id) : - if '/' in wandb_id : # assumed to be full path instead of just id fname = wandb_id else : fname = Path( config.path_models, 'id{}/model_id{}.json'.format( wandb_id, wandb_id)) - try : with open(fname, 'r') as f : json_str = f.readlines() - except IOError : + except (OSError, IOError) as e: # try path used for logging training results and checkpoints - fname = Path( config.path_results, '/models/id{}/model_id{}.json'.format( wandb_id, wandb_id)) - with open(fname, 'r') as f : - json_str = f.readlines() + try : + fname = Path( config.path_results, 'models/id{}/model_id{}.json'.format(wandb_id,wandb_id)) + with open(fname, 'r') as f : + json_str = f.readlines() + except (OSError, IOError) as e: + print( f'Could not find {fname} due to {e}.
Aborting.') + quit() self.__dict__ = json.loads( json_str[0]) @@ -161,7 +163,9 @@ def setup_ddp( with_ddp = True) : rank = 0 size = 1 - if with_ddp : + master_node = os.environ.get('MASTER_ADDR', '-1') + + if with_ddp and (master_node != '-1'): local_rank = int(os.environ.get("SLURM_LOCALID")) ranks_per_node = int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] ) @@ -289,14 +293,19 @@ def tokenize( data, token_size = [-1,-1,-1]) : tok_tot_x = int( data_shape[-2] / token_size[1]) tok_tot_y = int( data_shape[-1] / token_size[2]) - if 4 == len(data_shape) : - t2 = torch.reshape( data, (-1, tok_tot_t, token_size[0], tok_tot_x, token_size[1], tok_tot_y, token_size[2])) - data_tokenized = torch.transpose(torch.transpose( torch.transpose( t2, 5, 4), 3, 2), 4, 3) + if 5 == len(data_shape) : + t2 = torch.reshape( data, (data.shape[0], data.shape[1], tok_tot_t, token_size[0], + tok_tot_x, token_size[1], tok_tot_y, token_size[2])) + data_tokenized = t2.permute( [0, 1, 2, 4, 6, 3, 5, 7]) + elif 4 == len(data_shape) : + t2 = torch.reshape( data, (-1, tok_tot_t, token_size[0], + tok_tot_x, token_size[1], tok_tot_y, token_size[2])) + data_tokenized = t2.permute( [0, 1, 3, 5, 2, 4, 6]) elif 3 == len(data_shape) : t2 = torch.reshape( data, (tok_tot_t, token_size[0], tok_tot_x, token_size[1], tok_tot_y, token_size[2])) data_tokenized = torch.transpose(torch.transpose( torch.transpose( t2, 4, 3), 2, 1), 3, 2) elif 2 == len(data_shape) : - t2 = torch.reshape( t1, (tok_tot_x, token_size[0], tok_tot_y, token_size[1])) + t2 = torch.reshape( data, (tok_tot_x, token_size[0], tok_tot_y, token_size[1])) data_tokenized = torch.transpose( t2, 1, 2) else : assert False @@ -345,10 +354,58 @@ def erf( x, mu=0., std_dev=1.) : val = c1 * ( 1./c2 - std_dev * torch.special.erf( (mu - x) / (c3 * std_dev) ) ) return val +######################################## +# def CRPS_ps( y, mu, std_dev) : +# val = ps.crps_gaussian(y.cpu().detach().numpy(), mu=mu.cpu().detach().numpy(), sig=std_dev.cpu().detach().numpy()) +# return torch.tensor(val) + def CRPS( y, mu, std_dev) : + # see Eq. A2 in S. Rasp and S. Lerch. Neural networks for postprocessing ensemble weather forecasts. Monthly Weather Review, 146(11):3885 – 3900, 2018. c1 = np.sqrt(1./np.pi) t1 = 2. * erf( (y-mu) / std_dev) - 1. t2 = 2.
* Gaussian( (y-mu) / std_dev) val = std_dev * ( (y-mu)/std_dev * t1 + t2 - c1 ) return val + +######################################## +# def kernel_crps_ps( target, ens) : +# #breakpoint() +# val = ps.crps_ensemble(target.cpu().detach().numpy(), ens.permute([1,2,0]).cpu().detach().numpy()) +# return torch.tensor(val) + +def kernel_crps( target, ens, fair = True) : + ens_size = ens.shape[0] + mae = torch.cat( [(target - mem).abs().mean().unsqueeze(0) for mem in ens], 0).mean() + + if ens_size == 1: + return mae + + coef = -1.0 / (2.0 * ens_size * (ens_size - 1)) if fair else -1.0 / (2.0 * ens_size**2) + ens_var = coef * torch.tensor( [(p1 - p2).abs().sum() for p1 in ens for p2 in ens]).sum() + ens_var /= (ens.shape[1]*ens.shape[2]) + + return mae + ens_var + +######################################## + +def get_weights(lats_idx, lat_min = -90., lat_max = 90., reso = 0.25): + lat_range = lat_max - lat_min + bins = lat_range/reso+1 + + theta_weight = np.array([np.cos(w) for w in np.arange( lat_max * np.pi/lat_range , lat_min * np.pi/lat_range, -np.pi/bins)], dtype = np.float32) + + return theta_weight[lats_idx] + +######################################## + +def weighted_mse(x, target, weights): + return torch.sum(weights * (x - target) **2 )/torch.sum(weights) + +######################################## + +def check_num_samples(num_samples_validate, batch_size): + assert num_samples_validate // batch_size > 0, f"Num samples validate: {num_samples_validate} is smaller than batch size: {batch_size}. Please increase it." diff --git a/setup.py b/setup.py index 012c5b1..dd43a96 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ description='AtmoRep', packages=find_packages(), # if packages are available in a native form fo the host system then these should be used - install_requires=['torch', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo'], + install_requires=['torch>=2.3', 'numpy', 'matplotlib', 'zarr', 'pandas', 'typing_extensions', 'pathlib', 'wandb', 'cloudpickle', 'ecmwflibs', 'cfgrib', 'netcdf4', 'xarray', 'pytz', 'torchinfo', 'pytest'], data_files=[('./output', []), ('./logs', []), ('./results',[])], ) diff --git a/slurm_atmorep.sh b/slurm_atmorep.sh new file mode 100755 index 0000000..2f2cbfc --- /dev/null +++ b/slurm_atmorep.sh @@ -0,0 +1,48 @@ +#!/bin/bash -x +#SBATCH --account=ehpc03 +#SBATCH --time=0-24:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=20 +#SBATCH --gres=gpu:4 +#SBATCH --chdir=. +#SBATCH --qos=acc_ehpc +#SBATCH --output=logs/atmorep-%x.%j.out +#SBATCH --error=logs/atmorep-%x.%j.err + +# import modules +source pyenv/bin/activate + +export UCX_TLS="^cma" +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# so processes know who to talk to +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "MASTER_ADDR: $MASTER_ADDR" + +export NCCL_DEBUG=TRACE +echo "nccl_debug: $NCCL_DEBUG" + +# work-around for flipping links issue on JUWELS-BOOSTER +export NCCL_IB_TIMEOUT=250 +export UCX_RC_TIMEOUT=16s +export NCCL_IB_RETRY_CNT=50 + +echo "Starting job."
+echo "Number of Nodes: $SLURM_JOB_NUM_NODES" +echo "Number of Tasks: $SLURM_NTASKS" +date + +export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK} + +CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_train_${SLURM_JOBID} +mkdir ${CONFIG_DIR} +cp ${SLURM_SUBMIT_DIR}/atmorep/core/train.py ${CONFIG_DIR} +echo "${CONFIG_DIR}/train.py" +srun --label --cpu-bind=v --accel-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/train.py > output/output_${SLURM_JOBID}.txt + +echo "Finished job." +date diff --git a/slurm_atmorep_evaluate.sh b/slurm_atmorep_evaluate.sh new file mode 100755 index 0000000..32e160a --- /dev/null +++ b/slurm_atmorep_evaluate.sh @@ -0,0 +1,47 @@ +#!/bin/bash -x +#SBATCH --account=ehpc03 +#SBATCH --time=0-3:30:00 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=40 +#SBATCH --gres=gpu:2 +#SBATCH --chdir=. +#SBATCH --qos=acc_ehpc +#SBATCH --output=logs/atmorep-%x.%j.out +#SBATCH --error=logs/atmorep-%x.%j.err + +# import modules +source pyenv/bin/activate + +export UCX_TLS="^cma" +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +# so processes know who to talk to +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "MASTER_ADDR: $MASTER_ADDR" + +export NCCL_DEBUG=TRACE +echo "nccl_debug: $NCCL_DEBUG" + +# work-around for flipping links issue on JUWELS-BOOSTER +export NCCL_IB_TIMEOUT=250 +export UCX_RC_TIMEOUT=16s +export NCCL_IB_RETRY_CNT=50 + +echo "Starting job." +echo "Number of Nodes: $SLURM_JOB_NUM_NODES" +echo "Number of Tasks: $SLURM_NTASKS" +date + +export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK} + +CONFIG_DIR=${SLURM_SUBMIT_DIR}/atmorep_eval_${SLURM_JOBID} +mkdir ${CONFIG_DIR} +cp ${SLURM_SUBMIT_DIR}/atmorep/core/evaluate.py ${CONFIG_DIR} +echo "${CONFIG_DIR}/evaluate.py" +srun --label --cpu-bind=v ${SLURM_SUBMIT_DIR}/pyenv/bin/python -u ${CONFIG_DIR}/evaluate.py > output/output_${SLURM_JOBID}.txt + +echo "Finished job." +date