merge all new developments to main #27

Merged 77 commits into main on Aug 12, 2024

Commits (77)
ed65f39
- Added support for separate data directories for separate datasets.
Mar 8, 2024
38ab84b
- Switched to training data in zarr
Mar 11, 2024
29dd3e6
normalization and tokenization
iluise Mar 19, 2024
dbe1bbc
working bert.py
iluise Mar 20, 2024
967013d
first working version with sfc data + source info
iluise Mar 20, 2024
490df64
simplify log_validate
iluise Mar 21, 2024
2f7ecc3
restructure data_writer
iluise Mar 22, 2024
36cf3e9
add one comment
iluise Mar 22, 2024
16e160a
Removed duplicate fields_tokens_masked_idx.
Mar 25, 2024
27e725b
modify log_attention
iluise Mar 25, 2024
5886b3c
Fixed tokenize for 2D input (but not tested or used).
Mar 25, 2024
1c65c7a
Cleaned up some details in log_validate_BERT().
Mar 25, 2024
fde838b
merge
iluise Mar 25, 2024
67e9a7e
fix loss
iluise Mar 28, 2024
092478c
Re-enabled standard data loaders.
Apr 3, 2024
5b6641e
Removed stale code.
Apr 3, 2024
6a2c25c
Adapted forecast to only having tokens_mask_idx_list
Apr 3, 2024
b7f380a
Various fixes and changing, in particular enabling again global fore…
Apr 3, 2024
9be35b3
commit temporary code
iluise Apr 4, 2024
2d86650
running global forecast evaluation
iluise Apr 5, 2024
4c1f8ca
Changed set_global to numpy arrays for consistency.
Apr 8, 2024
1b9e20d
Fixed bug in returned idx_masked for BERT.
Apr 8, 2024
7c6413f
- Fixed issues in log_BERT due to bug fixing of masked_idx.
Apr 8, 2024
e0f72d1
validated version against main
iluise Apr 9, 2024
6ba6521
- Fixed handling of shuffle()
Apr 9, 2024
296595a
fix bert_strategy
iluise Apr 9, 2024
f28e1c9
fix temporal interpol, fix overlap, remove evaluate
iluise Apr 11, 2024
f2b0646
Fixed hard coded path.
clessig Apr 13, 2024
008031b
Removed load_data for training.
Apr 13, 2024
6d523be
Merge branch 'iluise-sfc-fields' of https://github.com/clessig/atmore…
Apr 13, 2024
f9b609a
new normalization from zarr
iluise Apr 16, 2024
ecc6621
validated status
iluise Apr 16, 2024
45d4fdc
validated new normalization
iluise Apr 16, 2024
e56952f
temporal interpolation for multiple time steps
iluise Apr 19, 2024
afbdf04
Adding sample/sec to console output.
Apr 22, 2024
5f5efa5
Merge branch 'iluise-sfc-fields' of https://github.com/clessig/atmore…
Apr 22, 2024
77a2aac
Removed variable length batch size.
Apr 22, 2024
ffe7e18
Fixes for new/old ordering of fields.
Apr 24, 2024
2c73df2
Implemented efficient fused flash-attention (i.e. flash attention wit…
May 21, 2024
92b097b
Fixed bug with multi-year training data ranges.
May 23, 2024
6831ffc
delete unused files
iluise May 27, 2024
e3a7faa
first implem of weight_translate + time bug fix
iluise Jun 6, 2024
ecee5b5
validate MultiCrossAttentionHead
iluise Jun 7, 2024
705f03e
prepare full config example for Epicure
Jun 10, 2024
f211001
comment out requirements for wheel installation
Jun 10, 2024
4fcff35
increase n workers evaluate
Jun 12, 2024
88ee5c0
fix path in evaluate
Jun 12, 2024
e286c7e
new wip weighted area
Jul 3, 2024
9c49ded
final weighted average LOSS
Jul 3, 2024
a9a85c5
clean code version. time still off.
Jul 23, 2024
46b22b8
add test
Jul 24, 2024
fe05e28
add test
Jul 24, 2024
cf3a165
include validation tests within evaluate.py
Jul 25, 2024
1f9489e
tests now working in BERT, forecast and global_forecast mode
Jul 25, 2024
3868ab4
restore values in evaluate
Jul 25, 2024
7863acc
Reenabled packages.
clessig Jul 26, 2024
104a23d
Merge branch 'atmorep-dev' of github.com:clessig/atmorep into atmorep…
clessig Jul 26, 2024
4c6111c
Added version requirement for torch.
clessig Jul 26, 2024
56a6a39
Factored MLP into a separate file.
clessig Jul 29, 2024
7d8e24c
Adding logging code but still needs to be used.
clessig Jul 29, 2024
9f69a5d
Removing used file and functionality (which has moved to multifield_d…
clessig Jul 29, 2024
c8b180d
- Added optimization suggested by BSC (although no performance improv…
clessig Jul 29, 2024
0dd99a7
Cleaned up unused parameters and dependencies.
clessig Jul 29, 2024
3dd417c
Fixed partial merge
Jul 29, 2024
179b95f
Cleaned up code.
Jul 29, 2024
679adb3
validated code
Jul 30, 2024
224bf91
Fixed bug when number of samples is larger than available samples.
clessig Aug 5, 2024
5c77311
restore settings for batch size 96
Aug 7, 2024
debb1e8
Merge branch 'develop' of https://github.com/clessig/atmorep into dev…
Aug 7, 2024
11f5d7e
restore paths
Aug 7, 2024
2afc62b
add BSC Epicure team suggestions
Aug 7, 2024
df7b23c
fix local norm
Aug 7, 2024
a73be45
fix evaluate.py for years_val
Aug 8, 2024
ba60b71
answer review comments
Aug 12, 2024
db4d77f
Merge pull request #24 from clessig/iluise/head
iluise Aug 12, 2024
f38ddc5
resolve conflicts with develop
Aug 12, 2024
e3412a6
Merge pull request #28 from clessig/iluise/head
clessig Aug 12, 2024
6 changes: 1 addition & 5 deletions atmorep/config/config.py
@@ -3,12 +3,8 @@

fpath = os.path.dirname(os.path.realpath(__file__))

- year_base = 1979
- year_last = 2022
-
path_models = Path( fpath, '../../models/')
- path_results = Path( fpath, '../../results/')
- path_data = Path( fpath, '../../data/')
+ path_results = Path( fpath, '../../results')
path_plots = Path( fpath, '../results/plots/')

grib_index = { 'vorticity' : 'vo', 'divergence' : 'd', 'geopotential' : 'z',
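
Note on the removed constants: year_base and year_last are no longer hard-coded in the static config. As the atmorep_model.py changes further down show, the base year is now read from the data sampler. A minimal sketch of the new access pattern (the attribute name is taken from this diff, the surrounding line is illustrative only):

# sketch: the base year now travels with the data rather than with the config
year_base = self.dataset_train.year_base
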
194 changes: 103 additions & 91 deletions atmorep/core/atmorep_model.py
@@ -35,6 +35,7 @@
from atmorep.transformer.transformer_decoder import TransformerDecoder
from atmorep.transformer.tail_ensemble import TailEnsemble


####################################################################################################
class AtmoRepData( torch.nn.Module) :

@@ -53,37 +54,6 @@ def __init__( self, net) :
self.rng_seed = net.cf.rng_seed
if not self.rng_seed :
self.rng_seed = int(torch.randint( 100000000, (1,)))

###################################################
def load_data( self, mode : NetMode, batch_size = -1, num_loader_workers = -1) :
'''Load data'''

cf = self.net.cf

if batch_size < 0 :
batch_size = cf.batch_size_max
if num_loader_workers < 0 :
num_loader_workers = cf.num_loader_workers

if mode == NetMode.train :
self.data_loader_train = self._load_data( self.dataset_train, batch_size, num_loader_workers)
elif mode == NetMode.test :
batch_size = cf.batch_size_test
self.data_loader_test = self._load_data( self.dataset_test, batch_size, num_loader_workers)
else :
assert False

###################################################
def _load_data( self, dataset, batch_size, num_loader_workers) :
'''Private implementation for load'''

dataset.load_data( batch_size)

loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False,
'num_workers': num_loader_workers, 'pin_memory': True}
data_loader = torch.utils.data.DataLoader( dataset, **loader_params, sampler = None)

return data_loader

###################################################
def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_workers = -1) :
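
The removed load_data()/_load_data() helpers are not replaced one-to-one; as the AtmoRepData.create() hunk further down shows, the two DataLoader objects are now built once at creation time. Their batch_size=None / batch_sampler=None settings mean the sampler itself emits ready-made batches and PyTorch must not re-collate them. A minimal, self-contained sketch of that pattern (the toy dataset is an assumption; only the loader parameters mirror this diff):

import torch

class PreBatchedDataset( torch.utils.data.IterableDataset) :
  '''Toy stand-in for MultifieldDataSampler: yields whole batches, not single samples.'''
  def __iter__( self) :
    for _ in range( 4) :
      yield torch.randn( 8, 3)   # one pre-assembled batch per iteration

loader = torch.utils.data.DataLoader( PreBatchedDataset(), batch_size=None, batch_sampler=None,
                                      shuffle=False, num_workers=0, pin_memory=True)
for batch in loader :
  assert batch.shape == (8, 3)   # batches pass through without re-collation
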
@@ -94,7 +64,7 @@ def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_worke

dataset = self.dataset_train if mode == NetMode.train else self.dataset_test
dataset.set_data( times_pos, batch_size)

self._set_data( dataset, mode, batch_size, num_loader_workers)

###################################################
@@ -103,7 +73,6 @@ def set_global( self, mode : NetMode, times, batch_size = -1, num_loader_workers
cf = self.net.cf
if batch_size < 0 :
batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test

dataset = self.dataset_train if mode == NetMode.train else self.dataset_test
dataset.set_global( times, batch_size, cf.token_overlap)

@@ -143,7 +112,7 @@ def _set_data( self, dataset, mode : NetMode, batch_size = -1, loader_workers =
assert False

###################################################
def normalizer( self, field, vl_idx) :
def normalizer( self, field, vl_idx, lats_idx, lons_idx ) :

if isinstance( field, str) :
for fidx, field_info in enumerate(self.cf.fields) :
@@ -153,12 +122,15 @@ def normalizer( self, field, vl_idx) :
normalizer = self.dataset_train.datasets[fidx].normalizer

elif isinstance( field, int) :
normalizer = self.dataset_train.datasets[field][vl_idx].normalizer

normalizer = self.dataset_train.normalizers[field][vl_idx]
if len(normalizer.shape) > 2:
normalizer = np.take( np.take( normalizer, lats_idx, -2), lons_idx, -1)
else :
assert False, 'invalid argument type (has to be index to cf.fields or field name)'

year_base = self.dataset_train.year_base

return normalizer
return normalizer, year_base

###################################################
def mode( self, mode : NetMode) :
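
For reference, the new normalizer() contract as it reads from the hunk above: callers now pass the latitude/longitude index arrays of the sampled patch and receive the (possibly spatially sliced) normalizer together with the dataset's base year. A caller-side sketch (variable names are illustrative; only the signature and return value come from this diff):

# sketch: lats_idx / lons_idx are the grid indices of the sampled patch
normalizer, year_base = model.normalizer( field_idx, vl_idx, lats_idx, lons_idx)
# a per-grid-point ("local") normalizer arrives already restricted to the patch, i.e. its
# last two axes have lengths len(lats_idx) and len(lons_idx)
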
@@ -193,8 +165,8 @@ def forward( self, xin) :
return pred

###################################################
def get_attention( self, xin): #, field_idx) :
attn = self.net.get_attention( xin) #, field_idx)
def get_attention( self, xin) :
attn = self.net.get_attention( xin)
return attn

###################################################
@@ -208,40 +180,26 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non
self.pre_batch_targets = pre_batch_targets

cf = self.net.cf
self.dataset_train = MultifieldDataSampler( cf.data_dir, cf.years_train, cf.fields,
batch_size = cf.batch_size_start,
num_t_samples = cf.num_t_samples,
num_patches_per_t = cf.num_patches_per_t_train,
num_load = cf.num_files_train,
pre_batch = self.pre_batch,
rng_seed = self.rng_seed,
file_shape = cf.file_shape,
smoothing = cf.data_smoothing,
level_type = cf.level_type,
file_format = cf.file_format,
month = cf.month,
time_sampling = cf.time_sampling,
geo_range = cf.geo_range_sampling,
fields_targets = cf.fields_targets,
pre_batch_targets = self.pre_batch_targets )

self.dataset_test = MultifieldDataSampler( cf.data_dir, cf.years_test, cf.fields,
batch_size = cf.batch_size_test,
num_t_samples = cf.num_t_samples,
num_patches_per_t = cf.num_patches_per_t_test,
num_load = cf.num_files_test,
pre_batch = self.pre_batch,
rng_seed = self.rng_seed,
file_shape = cf.file_shape,
smoothing = cf.data_smoothing,
level_type = cf.level_type,
file_format = cf.file_format,
month = cf.month,
time_sampling = cf.time_sampling,
geo_range = cf.geo_range_sampling,
lat_sampling_weighted = cf.lat_sampling_weighted,
fields_targets = cf.fields_targets,
pre_batch_targets = self.pre_batch_targets )
loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False,
'num_workers': cf.num_loader_workers, 'pin_memory': True}

self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train,
cf.batch_size,
pre_batch, cf.n_size, cf.num_samples_per_epoch,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params,
sampler = None)

self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val,
cf.batch_size_validation,
pre_batch, cf.n_size, cf.num_samples_validate,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params,
sampler = None)

return self
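
The compute_weights flag above is switched on only when a 'weighted_mse' loss is configured (cf. the "final weighted average LOSS" commit in the list). The loss itself lives outside this file; what follows is a sketch of a latitude-weighted MSE of that general kind, with cos-latitude weighting as an assumption rather than the project's exact formula:

import numpy as np
import torch

def weighted_mse( pred, target, lats_deg) :
  '''Area-weighted MSE: grid cells near the poles cover less area and get smaller weights.'''
  w = torch.tensor( np.cos( np.deg2rad( lats_deg)), dtype=pred.dtype)   # (nlat,)
  w = w / w.mean()
  return torch.mean( w[:, None] * (pred - target)**2 )   # pred, target : (nlat, nlon)
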

@@ -261,7 +219,6 @@ def create( self, devices, load_pretrained=True) :

cf = self.cf
self.devices = devices
size_token_info = 6
self.fields_coupling_idx = []

self.fields_index = {}
@@ -294,17 +251,9 @@ def create( self, devices, load_pretrained=True) :

self.embeds = torch.nn.ModuleList()
self.encoders = torch.nn.ModuleList()
self.masks = torch.nn.ParameterList()

for field_idx, field_info in enumerate(cf.fields) :

# learnabl class token
if cf.learnable_mask :
mask = torch.nn.Parameter( 0.1 * torch.randn( np.prod( field_info[4]), requires_grad=True))
self.masks.append( mask.to(devices[0]))
else :
self.masks.append( None)

# encoder
self.encoders.append( TransformerEncoder( cf, field_idx, True).create())
# load pre-trained model if specified
@@ -356,11 +305,10 @@ def create( self, devices, load_pretrained=True) :
device = self.devices[0]
if len(field_info[1]) > 3 :
assert field_info[1][3] < 4, 'Only single node model parallelism supported'
print(devices, field_info[1][3])
assert field_info[1][3] < len(devices), 'Per field device id larger than max devices'
device = self.devices[ field_info[1][3] ]
# set device
if self.masks[field_idx] != None :
self.masks[field_idx].to(device)
self.embeds[field_idx].to(device)
self.encoders[field_idx].to(device)

@@ -418,6 +366,68 @@ def load_block( self, field_info, block_name, block ) :
print( 'Loaded {} for {} from id = {} (ignoring/missing {} elements).'.format( block_name,
field_info[0], field_info[1][4][0], len(mkeys) ) )

###################################################
def translate_weights(self, mloaded, mkeys, ukeys):
'''
Function used for backward compatibility
'''
cf = self.cf

#encoder:
for layer in range(cf.encoder_num_layers) :

#shape([16, 3, 128, 2048])
mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]])
mloaded[f'encoders.0.heads.{layer}.proj_heads.weight'] = mw

for head in range(cf.encoder_num_heads):
del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight']
del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight']
del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight']

#cross attention
if f'encoders.0.heads.{layer}.heads_other.0.proj_qs.weight' in ukeys:
mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]])

for i in range(cf.encoder_num_heads):
del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_qs.weight']
del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_ks.weight']
del mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_vs.weight']

else:
dim_mw = self.encoders[0].heads[0].proj_heads_other[0].weight.shape
mw = torch.tensor(np.zeros(dim_mw))

mloaded[f'encoders.0.heads.{layer}.proj_heads_other.0.weight'] = mw

#decoder
for iblock in range(0, 19, 2) :
mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads.{head}.proj_{k}.weight'] for head in range(8) for k in ["qs", "ks", "vs"]])
mloaded[f'decoders.0.blocks.{iblock}.proj_heads.weight'] = mw

qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_qs.weight'] for head in range(8)]
mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_{k}.weight'] for head in range(8) for k in ["ks", "vs"]])

mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_q.weight'] = torch.cat([*qs])
mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_kv.weight'] = mw

#self.num_samples_validate
decoder_dim = self.decoders[0].blocks[iblock].ln_q.weight.shape #128
mloaded[f'decoders.0.blocks.{iblock}.ln_q.weight'] = torch.tensor(np.ones(decoder_dim))
mloaded[f'decoders.0.blocks.{iblock}.ln_k.weight'] = torch.tensor(np.ones(decoder_dim))
mloaded[f'decoders.0.blocks.{iblock}.ln_q.bias'] = torch.tensor(np.ones(decoder_dim))
mloaded[f'decoders.0.blocks.{iblock}.ln_k.bias'] = torch.tensor(np.ones(decoder_dim))

for i in range(8):
del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight']

return mloaded

###################################################
@staticmethod
def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) :
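
translate_weights() above exists for backward compatibility with checkpoints written before the fused flash-attention change (see the "Implemented efficient fused flash-attention" commit in the list): the per-head proj_qs / proj_ks / proj_vs matrices are concatenated into a single proj_heads matrix per layer, so one matmul replaces 3 * num_heads small ones. A self-contained sketch of why the concatenation is equivalent (dimensions are illustrative, not the model's actual sizes):

import torch

num_heads, dim_embed, dim_head = 4, 32, 8
# one (dim_head, dim_embed) projection per head and per role (q, k, v), as in old checkpoints
per_head = [ torch.randn( dim_head, dim_embed) for _ in range( 3 * num_heads) ]
fused = torch.cat( per_head)   # (3*num_heads*dim_head, dim_embed), as built by translate_weights

x = torch.randn( 5, dim_embed)   # a few token embeddings
separate = torch.cat( [ x @ w.t() for w in per_head], dim=-1)
assert torch.allclose( x @ fused.t(), separate, atol=1e-5)   # fused matmul == per-head matmuls

The decoder branch follows the same idea but keeps the query projection separate from the key/value projections (proj_heads_o_q vs. proj_heads_o_kv) and initializes the newly introduced ln_q / ln_k layer norms with ones so that old checkpoints remain loadable.
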
@@ -429,15 +439,18 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) :

model = AtmoRep( cf).create( devices, load_pretrained=False)
mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) )
mkeys, _ = model.load_state_dict( mloaded, False )
mkeys, ukeys = model.load_state_dict( mloaded, False )
if (f'encoders.0.heads.0.proj_heads.weight') in mkeys:
mloaded = model.translate_weights(mloaded, mkeys, ukeys)
mkeys, ukeys = model.load_state_dict( mloaded, False )

if len(mkeys) > 0 :
print( f'Loaded AtmoRep: ignoring {len(mkeys)} elements: {mkeys}')

# TODO: remove, only for backward
if model.embeds_token_info[0].weight.abs().max() == 0. :
model.embeds_token_info = torch.nn.ModuleList()

return model

###################################################
@@ -474,8 +487,9 @@ def forward( self, xin) :

# embedding
cf = self.cf

fields_embed = self.get_fields_embed(xin)

# attention maps (if requested)
atts = [ [] for _ in cf.fields ]

@@ -528,16 +542,14 @@ def forward_encoder_block( self, iblock, fields_embed) :
return fields_embed_cur, atts

###################################################

def get_fields_embed( self, xin ) :
cf = self.cf
if 0 == len(self.embeds_token_info) : # TODO: only for backward compatibility, remove
emb_net_ti = self.embed_token_info
return [prepare_token( field_data, emb_net, emb_net_ti, cf.with_cls )
return [prepare_token( field_data, emb_net, emb_net_ti )
for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))]
else :
embs_net_ti = self.embeds_token_info
return [prepare_token( field_data, emb_net, embs_net_ti[fidx], cf.with_cls )
return [prepare_token( field_data, emb_net, embs_net_ti[fidx] )
for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))]

###################################################