fix optimizer save/load, add mpnn config, get f16 training to work #108

Merged 7 commits on May 6, 2022
15 changes: 8 additions & 7 deletions mlpf/pipeline.py
@@ -104,6 +104,7 @@ def main():
 @click.option("--customize", help="customization function", type=str, default=None)
 def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq, customize):
 
+    #tf.debugging.enable_check_numerics()
 
     """Train a model defined by config"""
    config_file_path = config
@@ -117,10 +118,7 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
     if customize:
         config = customization_functions[customize](config)
 
-    if recreate or (weights is None):
-        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node())
-    else:
-        outdir = str(Path(weights).parent.parent)
+    outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node())
 
     try:
         from comet_ml import Experiment
@@ -193,7 +191,10 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
         # We need to load the weights in the same trainable configuration as the model was set up
         configure_model_weights(model, config["setup"].get("weights_config", "all"))
         model.load_weights(weights, by_name=True)
-        loaded_opt = pickle.load(open(weights.replace("hdf5", "pkl").replace("/weights-", "/opt-"), "rb"))
+        opt_weight_file = weights.replace("hdf5", "pkl").replace("/weights-", "/opt-")
+        if os.path.isfile(opt_weight_file):
+            loaded_opt = pickle.load(open(opt_weight_file, "rb"))
+
         initial_epoch = int(weights.split("/")[-1].split("-")[1])
     model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]))
@@ -228,8 +229,8 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
 
     model.summary()
 
-    #Load the optimizer weights
-    if weights:
+    #Set the optimizer weights
+    if loaded_opt:
         def model_weight_setting():
             grad_vars = model.trainable_weights
             zero_grads = [tf.zeros_like(w) for w in grad_vars]
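A note on the save/load pattern this file relies on: Keras optimizer slot variables are created lazily on the first apply_gradients call, so pickled optimizer state can only be restored into an optimizer whose slots already exist. A minimal standalone sketch of the round trip, assuming the TF 2.x optimizer API used in this repo (file names here are illustrative, not project paths):

    import pickle
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    opt = tf.keras.optimizers.Adam()
    model.compile(optimizer=opt, loss="mse")
    model.fit(tf.zeros((16, 8)), tf.zeros((16, 4)), verbose=0)  # creates the slot variables

    # save: the optimizer state is a plain list of arrays, safe to pickle
    with open("opt-01.pkl", "wb") as fi:
        pickle.dump(opt.get_weights(), fi)

    # load into a fresh optimizer: apply a zero gradient once so the slot
    # variables exist, then overwrite them (the model_weight_setting trick above)
    opt2 = tf.keras.optimizers.Adam()
    grad_vars = model.trainable_weights
    opt2.apply_gradients(zip([tf.zeros_like(w) for w in grad_vars], grad_vars))
    with open("opt-01.pkl", "rb") as fi:
        opt2.set_weights(pickle.load(fi))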
5 changes: 4 additions & 1 deletion mlpf/tallinn/cms-gen.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 #SBATCH -p gpu
-#SBATCH --gpus 8
+#SBATCH --gpus 4
 #SBATCH --mem-per-gpu=8G
 
 IMG=/home/software/singularity/tf-2.8.0.simg
@@ -11,3 +11,6 @@ singularity exec -B /scratch-persistent --nv \
     --env PYTHONPATH=hep_tfds \
     --env TFDS_DATA_DIR=/scratch-persistent/joosep/tensorflow_datasets \
     $IMG python mlpf/pipeline.py train -c parameters/cms-gen.yaml --plot-freq 100
+
+# -c experiments/cms-gen_20220503_145445_570900.gpu0.local/config.yaml \
+# -w experiments/cms-gen_20220503_145445_570900.gpu0.local/weights/weights-100-2.682420.hdf5
4 changes: 4 additions & 0 deletions mlpf/tfmodel/callbacks.py
@@ -23,6 +23,7 @@ def __init__(self, *args, **kwargs):
     def _collect_learning_rate(self, logs):
         logs = logs or {}
         if hasattr(self.model.optimizer, "lr"):
+
             lr_schedule = getattr(self.model.optimizer, "lr", None)
             if isinstance(lr_schedule, tf.keras.optimizers.schedules.LearningRateSchedule):
                 logs["learning_rate"] = np.float64(tf.keras.backend.get_value(lr_schedule(self.model.optimizer.iterations)))
@@ -39,6 +40,9 @@ def _collect_learning_rate(self, logs):
         if isinstance(self.model.optimizer, tf.keras.optimizers.Adam):
             logs.update({"adam_beta_1": np.float64(tf.keras.backend.eval(self.model.optimizer.beta_1))})
 
+        if hasattr(self.model.optimizer, "loss_scale"):
+            logs.update({"loss_scale": np.float64(self.model.optimizer.loss_scale.numpy())})
+
         return logs
 
     def on_epoch_end(self, epoch, logs):
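For reference, the loss_scale attribute probed here is what tf.keras.mixed_precision.LossScaleOptimizer exposes when float16 training is enabled; logging it makes a silently collapsing loss scale visible in the training history. A standalone sketch of the objects involved (values illustrative, not project code):

    import numpy as np
    import tensorflow as tf

    # Wrapping the base optimizer is what enables dynamic loss scaling for f16 training
    opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-3))

    # The current scale is a tensor; this mirrors what the callback stores in the logs
    print(np.float64(opt.loss_scale.numpy()))  # 32768.0 with default dynamic scaling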
104 changes: 57 additions & 47 deletions mlpf/tfmodel/model.py
@@ -4,9 +4,6 @@
 
 import tensorflow as tf
 
-import numpy as np
-from numpy.lib.recfunctions import append_fields
-
 regularizer_weight = 0.0
 
 def split_indices_to_bins(cmul, nbins, bin_size):
@@ -20,7 +17,7 @@ def split_indices_to_bins_batch(cmul, nbins, bin_size, msk):
     return bins_split
 
 
-def pairwise_gaussian_dist(A, B):
+def pairwise_l2_dist(A, B):
     na = tf.reduce_sum(tf.square(A), -1)
     nb = tf.reduce_sum(tf.square(B), -1)
 
@@ -33,6 +30,12 @@
     D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 1e-6))
     return D
 
+def pairwise_l1_dist(A, B):
+    na = tf.expand_dims(A, -2)
+    nb = tf.expand_dims(B, -3)
+    D = tf.abs(tf.reduce_sum(na-nb, axis=-1))
+    return D
+
 def pairwise_learnable_dist(A, B, ffn, training=False):
     shp = tf.shape(A)
 
@@ -45,7 +48,7 @@ def pairwise_learnable_dist(A, B, ffn, training=False):
         tf.gather_nd(B, inds2)], axis=-1
     ) #(batch, bin, elem, elem, feat)
 
-    #run a feedforward net on (src, dst) -> 1
+    #run a feedforward net on ffn(src, dst) -> output_dim
     res_transformed = ffn(res, training=training)
 
     return res_transformed
@@ -198,7 +201,8 @@ def call(self, inputs):
 
         #compute the normalization of the adjacency matrix
         if self.normalize_degrees:
-            in_degrees = tf.clip_by_value(tf.reduce_sum(tf.abs(adj), axis=-1), 0, 1000)
+            #in_degrees = tf.clip_by_value(tf.reduce_sum(tf.abs(adj), axis=-1), 0, 1000)
+            in_degrees = tf.reduce_sum(tf.abs(adj), axis=-1)
 
             #add epsilon to prevent numerical issues from 1/sqrt(x)
             norm = tf.expand_dims(tf.pow(in_degrees + 1e-6, -0.5), -1)*msk
@@ -217,7 +221,6 @@
 
 class NodeMessageLearnable(tf.keras.layers.Layer):
     def __init__(self, *args, **kwargs):
-
         self.output_dim = kwargs.pop("output_dim")
         self.hidden_dim = kwargs.pop("hidden_dim")
         self.num_layers = kwargs.pop("num_layers")
@@ -235,19 +238,15 @@ def __init__(self, *args, **kwargs):
     def call(self, inputs):
         x, adj, msk = inputs
 
-        #collect incoming messages
-        avg_message_dst = tf.reduce_mean(adj, axis=-2)
+        #collect incoming messages (batch, bins, elems, elems, msg_dim) -> (batch, bins, elems, msg_dim)
         max_message_dst = tf.reduce_max(adj, axis=-2)
-        min_message_dst = tf.reduce_min(adj, axis=-2)
 
-        #collect outgoing messages
-        avg_message_src = tf.reduce_mean(adj, axis=-3)
+        #collect outgoing messages (batch, bins, elems, elems, msg_dim) -> (batch, bins, elems, msg_dim)
         max_message_src = tf.reduce_max(adj, axis=-3)
-        min_message_src = tf.reduce_min(adj, axis=-3)
 
-        #node update
-        x2 = tf.concat([x, avg_message_dst, max_message_dst, min_message_dst, avg_message_src, max_message_src, min_message_src], axis=-1)
-        return self.activation(self.ffn(x2))*msk
+        #node update (batch, bins, elems, elems, elem_dim + msg_dim + msg_dim)
+        x2 = tf.concat([x, max_message_dst, max_message_src], axis=-1)
+        return tf.cast(self.activation(self.ffn(x2)), x.dtype)
 
 def point_wise_feed_forward_network(d_model, dff, name, num_layers=1, activation='elu', dtype=tf.dtypes.float32, dim_decrease=False, dropout=0.0):
 
@@ -287,9 +286,17 @@ def get_message_layer(config_dict, name):
     return conv_cls(name=name, **config_dict)
 
 class NodePairGaussianKernel(tf.keras.layers.Layer):
-    def __init__(self, clip_value_low=0.0, dist_mult=0.1, **kwargs):
-        self.clip_value_low = clip_value_low
-        self.dist_mult = dist_mult
+    def __init__(self, **kwargs):
+        self.clip_value_low = kwargs.pop("clip_value_low", 0.0)
+        self.dist_mult = kwargs.pop("dist_mult", 0.1)
+
+        dist_norm = kwargs.pop("dist_norm", "l2")
+        if dist_norm == "l1":
+            self.dist_norm = pairwise_l1_dist
+        elif dist_norm == "l2":
+            self.dist_norm = pairwise_l2_dist
+        else:
+            raise Exception("Unkown dist_norm: {}".format(dist_norm))
         super(NodePairGaussianKernel, self).__init__(**kwargs)
 
     """
@@ -299,13 +306,13 @@ def __init__(self, clip_value_low=0.0, dist_mult=0.1, **kwargs):
     """
     def call(self, x_msg_binned, msk, training=False):
         x = x_msg_binned*msk
-        dm = tf.expand_dims(pairwise_gaussian_dist(x, x), axis=-1)
+        dm = tf.expand_dims(self.dist_norm(x, x), axis=-1)
         dm = tf.exp(-self.dist_mult*dm)
         dm = tf.clip_by_value(dm, self.clip_value_low, 1)
         return dm
 
 class NodePairTrainableKernel(tf.keras.layers.Layer):
-    def __init__(self, output_dim=32, hidden_dim_node=256, hidden_dim_pair=32, num_layers=2, activation="elu", **kwargs):
+    def __init__(self, output_dim=4, hidden_dim_node=128, hidden_dim_pair=32, num_layers=1, activation="elu", **kwargs):
         self.output_dim = output_dim
         self.hidden_dim_node = hidden_dim_node
         self.hidden_dim_pair = hidden_dim_pair
@@ -337,11 +344,10 @@ def __init__(self, output_dim=32, hidden_dim_node=256, hidden_dim_pair=32, num_l
     """
     def call(self, x_msg_binned, msk, training=False):
 
-        node_proj = self.activation(self.ffn_node(x_msg_binned))*msk
+        node_proj = self.activation(self.ffn_node(x_msg_binned))
 
-        dm = pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training)
-        dm = self.activation(dm)
-        return tf.reduce_max(dm, axis=-1, keepdims=True)
+        dm = tf.cast(pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training), x_msg_binned.dtype)
+        return dm
 
 def build_kernel_from_conf(kernel_dict, name):
     kernel_dict = kernel_dict.copy()
@@ -399,14 +405,14 @@ def call(self, x_msg, x_node, msk, training=False):
 
         #Run the node-to-node kernel (distance computation / graph building / attention)
         dm = self.kernel(x_msg_binned, msk_f_binned, training=training)
 
         #remove the masked points row-wise and column-wise
         msk_f_binned_squeeze = tf.squeeze(msk_f_binned, axis=-1)
         shp_dm = tf.shape(dm)
         rshp_row = [shp_dm[0], shp_dm[1], shp_dm[2], 1, 1]
         rshp_col = [shp_dm[0], shp_dm[1], 1, shp_dm[3], 1]
-        msk_row = tf.reshape(msk_f_binned_squeeze, rshp_row)
-        msk_col = tf.reshape(msk_f_binned_squeeze, rshp_col)
+        msk_row = tf.cast(tf.reshape(msk_f_binned_squeeze, rshp_row), dm.dtype)
+        msk_col = tf.cast(tf.reshape(msk_f_binned_squeeze, rshp_col), dm.dtype)
         dm = tf.math.multiply(dm, msk_row)
         dm = tf.math.multiply(dm, msk_col)
 
@@ -611,9 +617,9 @@ def call(self, args, training=False):
 
         #In case of a multimodal prediction, weight the per-class energy predictions by the approximately one-hot vector
         if self.energy_multimodal:
-            pred_energy = tf.reduce_sum(out_id_hard_softmax*pred_energy_corr, axis=-1, keepdims=True)
+            pred_energy = orig_energy+tf.reduce_sum(out_id_hard_softmax*pred_energy_corr, axis=-1, keepdims=True)
         else:
-            pred_energy = pred_energy_corr
+            pred_energy = orig_energy+pred_energy_corr
 
         #compute pt=E/cosh(eta)
         orig_pt = tf.stop_gradient(pred_energy/tf.math.cosh(tf.clip_by_value(pred_eta, -8, 8)))
@@ -689,7 +695,6 @@ def __init__(self, *args, **kwargs):
 
         if self.do_layernorm:
             self.layernorm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm1")
-            self.layernorm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm2")
 
         #self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
         self.ffn_dist = point_wise_feed_forward_network(
@@ -728,18 +733,14 @@ def call(self, x, msk, training=False):
 
         #compute node features for graph building
         x_dist = self.dist_activation(self.ffn_dist(x, training=training))
-        if self.do_layernorm:
-            x_dist = self.layernorm2(x_dist, training=training)
 
         #compute the element-to-element messages / distance matrix / graph structure
         if self.do_lsh:
             bins_split, x, dm, msk_f = self.message_building_layer(x_dist, x, msk)
 
-            #bins_split: (FIXME)
-            #x: (batch, bin, elem, node_feature)
-            #dm: (batch, bin, elem, elem, pair_feature)
-            #msk_f: (batch, bin, elem, elem, 1)
-
         else:
             dm = self.message_building_layer(x_dist, msk)
             msk_f = tf.expand_dims(tf.cast(msk, x.dtype), axis=-1)
@@ -767,8 +768,8 @@ def __init__(self,
                  multi_output=False,
                  num_input_classes=8,
                  num_output_classes=3,
-                 num_graph_layers_common=1,
-                 num_graph_layers_energy=1,
+                 num_graph_layers_id=1,
+                 num_graph_layers_reg=1,
                  input_encoding="cms",
                  skip_connection=True,
                  graph_kernel={},
@@ -808,8 +809,8 @@ def __init__(self,
         elif input_encoding == "default":
             self.enc = InputEncoding(num_input_classes)
 
-        self.cg_id = [CombinedGraphLayer(name="cg_id_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
-        self.cg_reg = [CombinedGraphLayer(name="cg_reg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]
+        self.cg_id = [CombinedGraphLayer(name="cg_id_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_id)]
+        self.cg_reg = [CombinedGraphLayer(name="cg_reg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_reg)]
 
         output_decoding["schema"] = schema
         output_decoding["num_output_classes"] = num_output_classes
@@ -848,15 +849,18 @@ def call(self, inputs, training=False):
                 debugging_data[cg.name] = enc_all
 
         if self.node_update_mode == "concat":
-            dec_output = tf.concat(encs_id, axis=-1)*msk_input
+            dec_output_id = tf.concat(encs_id, axis=-1)*msk_input
         elif self.node_update_mode == "additive":
-            dec_output = X_enc_cg
+            dec_output_id = X_enc_cg
 
         X_enc_cg = X_enc
         if self.do_node_encoding:
             X_enc_cg = X_enc_ffn
 
+        encs_reg = []
+        if self.skip_connection:
+            encs_reg.append(X_enc)
 
         for cg in self.cg_reg:
             enc_all = cg(X_enc_cg, msk, training=training)
             if self.node_update_mode == "additive":
@@ -870,15 +874,15 @@ def call(self, inputs, training=False):
             encs_reg.append(X_enc_cg)
 
         if self.node_update_mode == "concat":
-            dec_output_energy = tf.concat(encs_reg, axis=-1)*msk_input
+            dec_output_reg = tf.concat(encs_reg, axis=-1)*msk_input
         elif self.node_update_mode == "additive":
-            dec_output_energy = X_enc_cg
+            dec_output_reg = X_enc_cg
 
         if self.debug:
-            debugging_data["dec_output"] = dec_output
-            debugging_data["dec_output_energy"] = dec_output_energy
+            debugging_data["dec_output_id"] = dec_output_id
+            debugging_data["dec_output_reg"] = dec_output_reg
 
-        ret = self.output_dec([X, dec_output, dec_output_energy, msk_input], training=training)
+        ret = self.output_dec([X, dec_output_id, dec_output_reg, msk_input], training=training)
 
         if self.debug:
             for k in debugging_data.keys():
@@ -899,17 +903,23 @@ def set_trainable_named(self, layer_names):
 
     # Uncomment these if you want to explicitly debug the training loop
     # def train_step(self, data):
+    #     import numpy as np
     #     x, y, sample_weights = data
     #     if not hasattr(self, "step"):
     #         self.step = 0
 
     #     with tf.GradientTape() as tape:
     #         y_pred = self(x, training=True)  # Forward pass
     #         loss = self.compiled_loss(y, y_pred, sample_weights, regularization_losses=self.losses)
-    #         import pdb;pdb.set_trace()
 
     #     trainable_vars = self.trainable_variables
     #     gradients = tape.gradient(loss, trainable_vars)
+    #     for tv, g in zip(trainable_vars, gradients):
+    #         g = g.numpy()
+    #         num_nan = np.sum(np.isnan(g))
+    #         if num_nan>0:
+    #             print(tv.name, num_nan, g.shape)
+
 
     #     self.optimizer.apply_gradients(zip(gradients, trainable_vars))
     #     self.compiled_metrics.update_state(y, y_pred)