diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index 77cc94a78..64e5d7b96 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -554,8 +554,6 @@ def main(args, yaml_path, config): print("fallback to CPU", e) strategy = tf.distribute.OneDeviceStrategy("cpu") num_gpus = 0 - - actual_lr = global_batch_size*float(config['setup']['lr']) Xs = [] ygens = [] @@ -580,15 +578,10 @@ def main(args, yaml_path, config): ygen_val = np.concatenate(ygens) ycand_val = np.concatenate(ycands) + lr = global_batch_size*float(config['setup']['lr']) with strategy.scope(): - lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( - actual_lr, - decay_steps=10000, - decay_rate=0.99, - staircase=True - ) total_steps = n_epochs * n_train // global_batch_size - lr_schedule, optim_callbacks = get_lr_schedule(config, actual_lr, steps=total_steps) + lr_schedule, optim_callbacks = get_lr_schedule(config, lr, steps=total_steps) opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule) if config['setup']['dtype'] == 'float16': model_dtype = tf.dtypes.float16 diff --git a/parameters/cms-gnn-dense-big.yaml b/parameters/cms-gnn-dense-big.yaml index 3a03e8f20..6c4c62059 100644 --- a/parameters/cms-gnn-dense-big.yaml +++ b/parameters/cms-gnn-dense-big.yaml @@ -51,6 +51,7 @@ setup: sample_weights: inverse_sqrt trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense diff --git a/parameters/cms-gnn-dense-focal.yaml b/parameters/cms-gnn-dense-focal.yaml index 2f715fd2c..15ce45126 100644 --- a/parameters/cms-gnn-dense-focal.yaml +++ b/parameters/cms-gnn-dense-focal.yaml @@ -52,6 +52,7 @@ setup: sample_weights: none trainable: all classification_loss_type: sigmoid_focal_crossentropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense diff --git a/parameters/cms-gnn-dense-transfer.yaml b/parameters/cms-gnn-dense-transfer.yaml index e55cc9407..688922a84 100644 --- a/parameters/cms-gnn-dense-transfer.yaml +++ b/parameters/cms-gnn-dense-transfer.yaml @@ -51,6 +51,7 @@ setup: sample_weights: inverse_sqrt trainable: transfer classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense diff --git a/parameters/cms-gnn-dense.yaml b/parameters/cms-gnn-dense.yaml index 6089456a9..ab0087373 100644 --- a/parameters/cms-gnn-dense.yaml +++ b/parameters/cms-gnn-dense.yaml @@ -53,6 +53,7 @@ setup: sample_weights: inverse_sqrt trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense diff --git a/parameters/cms-gnn-skipconn-v2.yaml b/parameters/cms-gnn-skipconn-v2.yaml index c13f7d854..0bb9c9220 100644 --- a/parameters/cms-gnn-skipconn-v2.yaml +++ b/parameters/cms-gnn-skipconn-v2.yaml @@ -51,6 +51,7 @@ setup: sample_weights: inverse_sqrt trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn diff --git a/parameters/cms-gnn-skipconn.yaml b/parameters/cms-gnn-skipconn.yaml index f0c9aa51e..1f23797b7 100644 --- a/parameters/cms-gnn-skipconn.yaml +++ b/parameters/cms-gnn-skipconn.yaml @@ -51,6 +51,7 @@ setup: sample_weights: none trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn diff --git a/parameters/cms-transformer-skipconn-gun.yaml b/parameters/cms-transformer-skipconn-gun.yaml index d079d71f2..180cb513d 100644 --- a/parameters/cms-transformer-skipconn-gun.yaml +++ b/parameters/cms-transformer-skipconn-gun.yaml @@ -52,6 +52,7 @@ setup: sample_weights: inverse_sqrt trainable: all multi_output: yes + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: transformer diff --git a/parameters/cms-transformer-skipconn.yaml b/parameters/cms-transformer-skipconn.yaml index 767f34416..f8ea796b9 100644 --- a/parameters/cms-transformer-skipconn.yaml +++ b/parameters/cms-transformer-skipconn.yaml @@ -50,6 +50,7 @@ setup: sample_weights: none trainable: cls multi_output: yes + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: transformer diff --git a/parameters/delphes-gnn-skipconn.yaml b/parameters/delphes-gnn-skipconn.yaml index 88fd5f189..d73066042 100644 --- a/parameters/delphes-gnn-skipconn.yaml +++ b/parameters/delphes-gnn-skipconn.yaml @@ -41,6 +41,7 @@ setup: trainable: all multi_output: no classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn diff --git a/parameters/delphes-transformer-skipconn.yaml b/parameters/delphes-transformer-skipconn.yaml index f687fd63e..9f3113c34 100644 --- a/parameters/delphes-transformer-skipconn.yaml +++ b/parameters/delphes-transformer-skipconn.yaml @@ -39,6 +39,7 @@ setup: sample_weights: none trainable: all multi_output: no + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: transformer diff --git a/parameters/test-cms-v2.yaml b/parameters/test-cms-v2.yaml index 25f9a0a5e..763b5e0ca 100644 --- a/parameters/test-cms-v2.yaml +++ b/parameters/test-cms-v2.yaml @@ -39,6 +39,7 @@ setup: sample_weights: none trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense diff --git a/parameters/test-cms.yaml b/parameters/test-cms.yaml index 939e37fc8..175c84686 100644 --- a/parameters/test-cms.yaml +++ b/parameters/test-cms.yaml @@ -39,6 +39,7 @@ setup: sample_weights: none trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn diff --git a/parameters/test-delphes.yaml b/parameters/test-delphes.yaml index 058836b7e..1c36387a5 100644 --- a/parameters/test-delphes.yaml +++ b/parameters/test-delphes.yaml @@ -38,6 +38,7 @@ setup: sample_weights: none trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn