fix: do not silently scale learning rate with batch size
Also add lr_schedule parameter to configuration files
erwulff committed Jun 28, 2021
1 parent 62a6ba3 commit b896392
Showing 14 changed files with 15 additions and 9 deletions.
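In short: before this commit, the learning rate read from the config was silently multiplied by the global batch size before reaching the optimizer; the fix is meant to use the configured value directly, with the schedule type now chosen explicitly through the new lr_schedule key (see the author's inline comment in the diff below for the remaining step). A minimal illustrative sketch, not code from the repository:

```python
# Illustrative sketch only; the dict stands in for a parsed parameters/*.yaml file.
config = {"setup": {"lr": 1e-4, "lr_schedule": "exponentialdecay"}}
global_batch_size = 32

# Old behaviour: the configured learning rate was silently scaled by the batch size.
actual_lr = global_batch_size * float(config["setup"]["lr"])  # 3.2e-3, not what the config says

# Intended behaviour (per the author's inline comment below): use the config value as-is
# and pick the schedule explicitly via the new lr_schedule key.
lr = float(config["setup"]["lr"])                 # 1e-4
schedule_name = config["setup"]["lr_schedule"]    # "exponentialdecay" or "onecycle"
print(actual_lr, lr, schedule_name)
```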
11 changes: 2 additions & 9 deletions mlpf/tfmodel/model_setup.py
@@ -554,8 +554,6 @@ def main(args, yaml_path, config):
             print("fallback to CPU", e)
             strategy = tf.distribute.OneDeviceStrategy("cpu")
             num_gpus = 0
 
-    actual_lr = global_batch_size*float(config['setup']['lr'])
-
[Inline comment by @erwulff (Author, Owner), Jun 28, 2021]
global_batch_size should be removed so that we have:
lr = float(config['setup']['lr'])

     Xs = []
     ygens = []
@@ -580,15 +578,10 @@ def main(args, yaml_path, config):
     ygen_val = np.concatenate(ygens)
     ycand_val = np.concatenate(ycands)
 
+    lr = global_batch_size*float(config['setup']['lr'])
     with strategy.scope():
-        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
-            actual_lr,
-            decay_steps=10000,
-            decay_rate=0.99,
-            staircase=True
-        )
         total_steps = n_epochs * n_train // global_batch_size
-        lr_schedule, optim_callbacks = get_lr_schedule(config, actual_lr, steps=total_steps)
+        lr_schedule, optim_callbacks = get_lr_schedule(config, lr, steps=total_steps)
         opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
         if config['setup']['dtype'] == 'float16':
             model_dtype = tf.dtypes.float16
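The hunk above swaps the hard-coded ExponentialDecay for a call to get_lr_schedule, driven by the new lr_schedule config key. The helper itself is defined elsewhere in mlpf/tfmodel and is not part of this diff; the sketch below only illustrates one plausible shape for it, under the assumption that it dispatches on config['setup']['lr_schedule'] and reuses the decay settings that were previously hard-coded:

```python
import tensorflow as tf

def get_lr_schedule(config, lr, steps):
    """Hypothetical sketch of the helper called in the diff above.

    Returns a (schedule, callbacks) pair, matching how model_setup.py
    unpacks the result into lr_schedule and optim_callbacks.
    """
    schedule = config["setup"]["lr_schedule"]
    callbacks = []
    if schedule == "exponentialdecay":
        # Same settings that were previously hard-coded in model_setup.py.
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            lr,
            decay_steps=10000,
            decay_rate=0.99,
            staircase=True,
        )
    elif schedule == "onecycle":
        # Keras has no built-in one-cycle schedule; a real implementation would
        # supply its own schedule or callback here, using `steps` as the cycle length.
        raise NotImplementedError("onecycle is not sketched here")
    else:
        raise ValueError(f"Unknown lr_schedule: {schedule}")
    return lr_schedule, callbacks

# Usage, mirroring the call in the diff:
# lr_schedule, optim_callbacks = get_lr_schedule(config, lr, steps=total_steps)
# opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
```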
1 change: 1 addition & 0 deletions parameters/cms-gnn-dense-big.yaml
@@ -51,6 +51,7 @@ setup:
   sample_weights: inverse_sqrt
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn_dense
1 change: 1 addition & 0 deletions parameters/cms-gnn-dense-focal.yaml
@@ -52,6 +52,7 @@ setup:
   sample_weights: none
   trainable: all
   classification_loss_type: sigmoid_focal_crossentropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn_dense
1 change: 1 addition & 0 deletions parameters/cms-gnn-dense-transfer.yaml
@@ -51,6 +51,7 @@ setup:
   sample_weights: inverse_sqrt
   trainable: transfer
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn_dense
1 change: 1 addition & 0 deletions parameters/cms-gnn-dense.yaml
@@ -53,6 +53,7 @@ setup:
   sample_weights: inverse_sqrt
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn_dense
1 change: 1 addition & 0 deletions parameters/cms-gnn-skipconn-v2.yaml
@@ -51,6 +51,7 @@ setup:
   sample_weights: inverse_sqrt
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn
1 change: 1 addition & 0 deletions parameters/cms-gnn-skipconn.yaml
@@ -51,6 +51,7 @@ setup:
   sample_weights: none
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn
1 change: 1 addition & 0 deletions parameters/cms-transformer-skipconn-gun.yaml
@@ -52,6 +52,7 @@ setup:
   sample_weights: inverse_sqrt
   trainable: all
   multi_output: yes
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: transformer
1 change: 1 addition & 0 deletions parameters/cms-transformer-skipconn.yaml
@@ -50,6 +50,7 @@ setup:
   sample_weights: none
   trainable: cls
   multi_output: yes
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: transformer
1 change: 1 addition & 0 deletions parameters/delphes-gnn-skipconn.yaml
@@ -41,6 +41,7 @@ setup:
   trainable: all
   multi_output: no
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn
1 change: 1 addition & 0 deletions parameters/delphes-transformer-skipconn.yaml
@@ -39,6 +39,7 @@ setup:
   sample_weights: none
   trainable: all
   multi_output: no
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: transformer
1 change: 1 addition & 0 deletions parameters/test-cms-v2.yaml
@@ -39,6 +39,7 @@ setup:
   sample_weights: none
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn_dense
1 change: 1 addition & 0 deletions parameters/test-cms.yaml
@@ -39,6 +39,7 @@ setup:
   sample_weights: none
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn
1 change: 1 addition & 0 deletions parameters/test-delphes.yaml
@@ -38,6 +38,7 @@ setup:
   sample_weights: none
   trainable: all
   classification_loss_type: categorical_cross_entropy
+  lr_schedule: exponentialdecay # exponentialdecay, onecycle
 
 parameters:
   model: gnn
