From 4bff2d7ec6d500a1d61d473b85d626b02ebcd3e1 Mon Sep 17 00:00:00 2001
From: Joosep Pata
Date: Tue, 3 May 2022 14:02:05 +0300
Subject: [PATCH 1/2] better corr plotting

---
 mlpf/tfmodel/model_setup.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py
index 4a11af1b3..704171c74 100644
--- a/mlpf/tfmodel/model_setup.py
+++ b/mlpf/tfmodel/model_setup.py
@@ -111,11 +111,11 @@ def __init__(self, outpath, dataset, dataset_info, plot_freq=1, comet_experiment
         }
 
         self.reg_bins = {
-            "pt": np.linspace(-100, 1000, 100),
+            "pt": np.linspace(-100, 200, 100),
             "eta": np.linspace(-6, 6, 100),
             "sin_phi": np.linspace(-1,1,100),
             "cos_phi": np.linspace(-1,1,100),
-            "energy": np.linspace(-100, 5000, 100),
+            "energy": np.linspace(-100, 1000, 100),
         }
 
     def plot_cm(self, epoch, outpath, ypred_id, msk):
@@ -268,9 +268,21 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable):
 
         #save scatterplot of raw values
         plt.figure(figsize=(6,5))
         bins = self.reg_bins[reg_variable]
+        if bins is None: bins = 100
+
+        if reg_variable == "pt" or reg_variable == "energy":
+            bins = np.logspace(-2,3,100)
+            vals_true = np.log10(vals_true)
+            vals_pred = np.log10(vals_pred)
+            vals_true[np.isnan(vals_true)] = 0.0
+            vals_pred[np.isnan(vals_pred)] = 0.0
+
         plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1, cmap="Blues", norm=matplotlib.colors.LogNorm())
+        if reg_variable == "pt" or reg_variable == "energy":
+            plt.xscale("log")
+            plt.yscale("log")
         plt.colorbar()
 
         if len(vals_true) > 0:
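
Note on the hunk above: the pt and energy correlation plots move to log-spaced bins, with the true and predicted values log10-transformed in place and NaNs (from non-positive inputs) zeroed before histogramming. The hunk both log10-transforms the values and uses np.logspace bins with log-scaled axes; a self-consistent alternative is to bin the raw values directly on log axes, as in the minimal sketch below (the helper name plot_log_corr is hypothetical; vals_true/vals_pred are the arrays used in plot_corr):

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    def plot_log_corr(vals_true, vals_pred):
        #log-spaced bins for quantities spanning several orders of magnitude
        bins = np.logspace(-2, 3, 100)
        #keep only strictly positive pairs so the log axes are well defined
        msk = (vals_true > 0) & (vals_pred > 0)
        plt.figure(figsize=(6, 5))
        plt.hist2d(vals_true[msk], vals_pred[msk], bins=(bins, bins),
                   cmin=1, cmap="Blues", norm=matplotlib.colors.LogNorm())
        plt.xscale("log")
        plt.yscale("log")
        plt.colorbar()
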
From b3dff062cd4e622e21a0329fec94c8e716a39b9f Mon Sep 17 00:00:00 2001
From: Joosep Pata
Date: Tue, 3 May 2022 14:02:16 +0300
Subject: [PATCH 2/2] rename layers, fix output decoding

---
 mlpf/tfmodel/model.py | 122 +++++++++++++++++++++---------
 1 file changed, 61 insertions(+), 61 deletions(-)

diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py
index 99f1550d8..1176caeff 100644
--- a/mlpf/tfmodel/model.py
+++ b/mlpf/tfmodel/model.py
@@ -135,15 +135,11 @@ def __init__(self, num_input_classes):
     """
     X: [Nbatch, Nelem, Nfeat] array of all the input detector element feature data
     """
-    @tf.function
     def call(self, X):
-
-        log_energy = tf.expand_dims(tf.math.log(X[:, :, 4]+1.0), axis=-1)
-
         #X[:, :, 0] - categorical index of the element type
         Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=X.dtype)
-        #Xpt = tf.expand_dims(tf.math.log1p(X[:, :, 1]), axis=-1)
         Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
+        Xe = tf.expand_dims(tf.math.log(X[:, :, 4] + 1.0), axis=-1)
 
         Xpt_0p5 = tf.math.sqrt(Xpt)
         Xpt_2 = tf.math.pow(Xpt, 2)
@@ -154,15 +150,8 @@ def call(self, X):
         Xphi1 = tf.expand_dims(tf.sin(X[:, :, 3]), axis=-1)
         Xphi2 = tf.expand_dims(tf.cos(X[:, :, 3]), axis=-1)
 
-        #Xe = tf.expand_dims(tf.math.log1p(X[:, :, 4]), axis=-1)
-        Xe = log_energy
-        Xe_0p5 = tf.math.sqrt(log_energy)
-        Xe_2 = tf.math.pow(log_energy, 2)
-
-        Xe_transverse = log_energy - tf.math.log(Xeta2)
-
-        Xlayer = tf.expand_dims(X[:, :, 5]*10.0, axis=-1)
-        Xdepth = tf.expand_dims(X[:, :, 6]*10.0, axis=-1)
+        Xe_0p5 = tf.math.sqrt(Xe)
+        Xe_2 = tf.math.pow(Xe, 2)
 
         Xphi_ecal1 = tf.expand_dims(tf.sin(X[:, :, 10]), axis=-1)
         Xphi_ecal2 = tf.expand_dims(tf.cos(X[:, :, 10]), axis=-1)
@@ -176,8 +165,6 @@ def call(self, X):
 
             Xabs_eta, Xphi1, Xphi2,
             Xe, Xe_0p5, Xe_2,
-            Xe_transverse,
-            Xlayer, Xdepth,
 
             Xphi_ecal1, Xphi_ecal2,
             Xphi_hcal1, Xphi_hcal2, X], axis=-1
@@ -199,7 +186,11 @@ def build(self, input_shape):
         self.W_h = self.add_weight(shape=(self.hidden_dim, self.output_dim), name="w_h", initializer="random_normal", trainable=True, regularizer=tf.keras.regularizers.L1(regularizer_weight))
         self.theta = self.add_weight(shape=(self.hidden_dim, self.output_dim), name="theta", initializer="random_normal", trainable=True, regularizer=tf.keras.regularizers.L1(regularizer_weight))
 
-    #@tf.function
+    """
+    x: [batches, bins, elements, features]
+    adj: [batches, bins, elements, elements]
+    msk: [batches, bins, elements]
+    """
     def call(self, inputs):
         x, adj, msk = inputs
 
@@ -307,7 +298,8 @@ def __init__(self, clip_value_low=0.0, dist_mult=0.1, **kwargs):
     returns: (n_batch, n_bins, n_points, n_points, 1) message matrix
     """
     def call(self, x_msg_binned, msk, training=False):
-        dm = tf.expand_dims(pairwise_gaussian_dist(x_msg_binned*msk, x_msg_binned*msk), axis=-1)
+        x = x_msg_binned*msk
+        dm = tf.expand_dims(pairwise_gaussian_dist(x, x), axis=-1)
         dm = tf.exp(-self.dist_mult*dm)
         dm = tf.clip_by_value(dm, self.clip_value_low, 1)
         return dm
@@ -349,7 +341,7 @@ def call(self, x_msg_binned, msk, training=False):
         dm = pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training)
         dm = self.activation(dm)
-        return dm
+        return tf.reduce_max(dm, axis=-1, keepdims=True)
 
 
 def build_kernel_from_conf(kernel_dict, name):
     kernel_dict = kernel_dict.copy()
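
Note on the kernel hunks above: ExpDist now masks its input once before computing the pairwise distance, and the learnable kernel collapses its per-channel outputs into a single adjacency channel with tf.reduce_max(..., keepdims=True). A minimal sketch of an ExpDist-style Gaussian adjacency under the shapes documented above ([batch, bins, elems, features] input, [batch, bins, elems, elems, 1] output); the helper name gaussian_adjacency is hypothetical and the distance is computed inline rather than via pairwise_gaussian_dist:

    import tensorflow as tf

    def gaussian_adjacency(x, msk, dist_mult=0.1, clip_value_low=0.0):
        #x: [batch, bins, elems, feat], msk: [batch, bins, elems, 1]
        x = x * msk
        #squared euclidean pairwise distances: [batch, bins, elems, elems]
        d2 = tf.reduce_sum(
            (tf.expand_dims(x, -2) - tf.expand_dims(x, -3))**2, axis=-1)
        #exponential falloff, clipped from below to sparsify the graph
        dm = tf.exp(-dist_mult*tf.expand_dims(d2, axis=-1))
        return tf.clip_by_value(dm, clip_value_low, 1.0)
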
@@ -566,16 +558,17 @@ def call(self, args, training=False):
 
         out_id_logits = self.ffn_id(X_encoded, training=training)
         out_id_logits = out_id_logits * tf.cast(msk_input, out_id_logits.dtype)
+        msk_input_outtype = tf.cast(msk_input, out_id_logits.dtype)
+
         out_id_softmax = tf.nn.softmax(out_id_logits, axis=-1)
         out_id_hard_softmax = tf.stop_gradient(tf.nn.softmax(100*out_id_logits, axis=-1))
 
         out_charge = self.ffn_charge(X_encoded, training=training)
-        out_charge = out_charge * tf.cast(msk_input, out_charge.dtype)
+        out_charge = out_charge * msk_input_outtype
 
         orig_eta = tf.cast(X_input[:, :, 2:3], out_id_logits.dtype)
 
-        #FIXME: better schema propagation
+        #FIXME: better schema propagation between hep_tfds
         #skip connection from raw input values
         if self.schema == "cms":
             orig_sin_phi = tf.cast(tf.math.sin(X_input[:, :, 3:4])*msk_input, out_id_logits.dtype)
@@ -590,9 +583,9 @@ def call(self, args, training=False):
         X_encoded = tf.concat([X_encoded, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)
 
         pred_eta_corr = self.ffn_eta(X_encoded, training=training)
-        pred_eta_corr = pred_eta_corr*tf.cast(msk_input, pred_eta_corr.dtype)
+        pred_eta_corr = pred_eta_corr*msk_input_outtype
         pred_phi_corr = self.ffn_phi(X_encoded, training=training)
-        pred_phi_corr = pred_phi_corr*tf.cast(msk_input, pred_phi_corr.dtype)
+        pred_phi_corr = pred_phi_corr*msk_input_outtype
 
         if self.eta_skip_gate:
             eta_gate = tf.keras.activations.sigmoid(pred_eta_corr[:, :, 0:1])
@@ -614,7 +607,7 @@ def call(self, args, training=False):
         X_encoded_energy = tf.concat([X_encoded_energy, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)
 
         pred_energy_corr = self.ffn_energy(X_encoded_energy, training=training)
-        pred_energy_corr = pred_energy_corr*tf.cast(msk_input, pred_energy_corr.dtype)
+        pred_energy_corr = pred_energy_corr*msk_input_outtype
 
         #In case of a multimodal prediction, weight the per-class energy predictions by the approximately one-hot vector
         if self.energy_multimodal:
@@ -626,7 +619,7 @@ def call(self, args, training=False):
         orig_pt = tf.stop_gradient(pred_energy/tf.math.cosh(tf.clip_by_value(pred_eta, -8, 8)))
 
         pred_pt_corr = self.ffn_pt(X_encoded_energy, training=training)
-        pred_pt_corr = pred_pt_corr*tf.cast(msk_input, pred_pt_corr.dtype)
+        pred_pt_corr = pred_pt_corr*msk_input_outtype
 
         if self.pt_skip_gate:
             pt_gate = tf.keras.activations.sigmoid(pred_pt_corr[:, :, 0:1])
@@ -647,16 +640,16 @@ def call(self, args, training=False):
 
         ret = {
             "cls": out_id_softmax,
-            "charge": out_charge*msk_input,
-            "pt": pred_pt*msk_input,
-            "eta": pred_eta*msk_input,
-            "sin_phi": pred_sin_phi*msk_input,
-            "cos_phi": pred_cos_phi*msk_input,
-            "energy": pred_energy*msk_input,
+            "charge": out_charge*msk_input_outtype,
+            "pt": pred_pt*msk_input_outtype,
+            "eta": pred_eta*msk_input_outtype,
+            "sin_phi": pred_sin_phi*msk_input_outtype,
+            "cos_phi": pred_cos_phi*msk_input_outtype,
+            "energy": pred_energy*msk_input_outtype,
             #per-event sum of energy and pt
-            "sum_energy": tf.reduce_sum(pred_energy*msk_input*msk_output, axis=-2),
-            "sum_pt": tf.reduce_sum(pred_pt*msk_input*msk_output, axis=-2),
+            "sum_energy": tf.reduce_sum(pred_energy*msk_input_outtype*msk_output, axis=-2),
+            "sum_pt": tf.reduce_sum(pred_pt*msk_input_outtype*msk_output, axis=-2),
         }
 
         return ret
@@ -695,7 +688,8 @@ def __init__(self, *args, **kwargs):
         self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation", "linear"))
 
         if self.do_layernorm:
-            self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm")
+            self.layernorm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm1")
+            self.layernorm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm2")
 
         #self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
         self.ffn_dist = point_wise_feed_forward_network(
@@ -730,10 +724,12 @@ def __init__(self, *args, **kwargs):
     def call(self, x, msk, training=False):
 
         if self.do_layernorm:
-            x = self.layernorm(x, training=training)
+            x = self.layernorm1(x, training=training)
 
         #compute node features for graph building
         x_dist = self.dist_activation(self.ffn_dist(x, training=training))
+        if self.do_layernorm:
+            x_dist = self.layernorm2(x_dist, training=training)
 
         #compute the element-to-element messages / distance matrix / graph structure
         if self.do_lsh:
@@ -752,9 +748,7 @@ def call(self, x, msk, training=False):
 
         #run the node update with message passing
         for msg in self.message_passing_layers:
-
             x = msg((x, dm, msk_f))
-
             if self.dropout_layer:
                 x = self.dropout_layer(x, training=training)
@@ -814,8 +808,8 @@ def __init__(self,
         elif input_encoding == "default":
             self.enc = InputEncoding(num_input_classes)
 
-        self.cg = [CombinedGraphLayer(name="cg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
-        self.cg_energy = [CombinedGraphLayer(name="cg_energy_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]
+        self.cg_id = [CombinedGraphLayer(name="cg_id_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
+        self.cg_reg = [CombinedGraphLayer(name="cg_reg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]
 
         output_decoding["schema"] = schema
         output_decoding["num_output_classes"] = num_output_classes
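
Note on the decoding hunks above: the repeated tf.cast(msk_input, ...) per output head is consolidated into a single msk_input_outtype, which then multiplies every per-element output and both per-event sums, and the common/energy graph-layer stacks are renamed to cg_id/cg_reg to reflect the classification/regression split. A toy sketch of the single-cast masking pattern (standalone shapes, not taken from the patch):

    import tensorflow as tf

    #[batch, elems, 1] padding mask; the third element is padding
    msk_input = tf.constant([[[1.0], [1.0], [0.0]]])
    out_id_logits = tf.random.normal((1, 3, 8))

    #cast once, reuse for every output head
    msk_input_outtype = tf.cast(msk_input, out_id_logits.dtype)
    pred_pt = tf.random.normal((1, 3, 1)) * msk_input_outtype
    pred_energy = tf.random.normal((1, 3, 1)) * msk_input_outtype
    #padded elements contribute exactly zero to the per-event sums
    sum_pt = tf.reduce_sum(pred_pt, axis=-2)
    sum_energy = tf.reduce_sum(pred_energy, axis=-2)
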
@@ -832,29 +826,29 @@ def call(self, inputs, training=False):
         msk = X[:, :, 0] != 0
         msk_input = tf.expand_dims(tf.cast(msk, X_enc.dtype), -1)
 
-        encs = []
+        encs_id = []
         if self.skip_connection:
-            encs.append(X_enc)
+            encs_id.append(X_enc)
 
         X_enc_cg = X_enc
         if self.do_node_encoding:
             X_enc_ffn = self.activation(self.node_encoding(X_enc_cg, training=training))
             X_enc_cg = X_enc_ffn
 
-        for cg in self.cg:
+        for cg in self.cg_id:
             enc_all = cg(X_enc_cg, msk, training=training)
             if self.node_update_mode == "additive":
                 X_enc_cg += enc_all["enc"]
             elif self.node_update_mode == "concat":
                 X_enc_cg = enc_all["enc"]
-            encs.append(X_enc_cg)
+            encs_id.append(X_enc_cg)
             if self.debug:
                 debugging_data[cg.name] = enc_all
 
         if self.node_update_mode == "concat":
-            dec_output = tf.concat(encs, axis=-1)*msk_input
+            dec_output = tf.concat(encs_id, axis=-1)*msk_input
         elif self.node_update_mode == "additive":
             dec_output = X_enc_cg
@@ -862,21 +856,21 @@ def call(self, inputs, training=False):
         if self.do_node_encoding:
             X_enc_cg = X_enc_ffn
 
-        encs_energy = []
-        for cg in self.cg_energy:
+        encs_reg = []
+        for cg in self.cg_reg:
             enc_all = cg(X_enc_cg, msk, training=training)
             if self.node_update_mode == "additive":
                 X_enc_cg += enc_all["enc"]
             elif self.node_update_mode == "concat":
                 X_enc_cg = enc_all["enc"]
-            encs_energy.append(X_enc_cg)
+            encs_reg.append(X_enc_cg)
             if self.debug:
                 debugging_data[cg.name] = enc_all
-        encs_energy.append(X_enc_cg)
+        encs_reg.append(X_enc_cg)
 
         if self.node_update_mode == "concat":
-            dec_output_energy = tf.concat(encs_energy, axis=-1)*msk_input
+            dec_output_energy = tf.concat(encs_reg, axis=-1)*msk_input
         elif self.node_update_mode == "additive":
             dec_output_energy = X_enc_cg
@@ -884,7 +878,7 @@ def call(self, inputs, training=False):
         if self.debug:
             debugging_data["dec_output"] = dec_output
             debugging_data["dec_output_energy"] = dec_output_energy
 
-        ret = self.output_dec([X_enc, dec_output, dec_output_energy, msk_input], training=training)
+        ret = self.output_dec([X, dec_output, dec_output_energy, msk_input], training=training)
 
         if self.debug:
             for k in debugging_data.keys():
@@ -955,10 +949,14 @@ def __init__(self, *args, **kwargs):
     def call(self, args, training=False):
         X, mask = args
 
-        attn_output = self.attn(query=X, value=X, key=X, training=training)
+        msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)
+
+        attn_output = self.attn(query=X, value=X, key=X, training=training, attention_mask=mask)*msk_input
         out1 = self.norm1(X + attn_output)
+
         ffn_output = self.ffn(out1)
         out2 = self.norm2(out1 + ffn_output)
+
         return out2
 
 class KernelDecoder(tf.keras.layers.Layer):
@@ -979,11 +977,12 @@ def __init__(self, *args, **kwargs):
     def call(self, args, training=False):
         X, enc_output, mask = args
+        msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)
 
-        attn1 = self.attn1(query=X, value=X, key=X, training=training)
+        attn1 = self.attn1(query=X, value=X, key=X, training=training, attention_mask=mask)*msk_input
         out1 = self.norm1(attn1 + X, training=training)
 
-        attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training)
+        attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training, attention_mask=mask)*msk_input
         out2 = self.norm2(attn2 + out1)
 
         ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
@@ -1007,14 +1006,15 @@ def __init__(self, *args, **kwargs):
         super(Transformer, self).__init__(*args, **kwargs)
 
     def call(self, inputs, training=False):
-        X, msk_input = inputs
+        X, mask = inputs
+        msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)
 
         for enc in self.encoders:
-            X = enc([X, msk_input], training=training)*msk_input
+            X = enc([X, mask], training=training)*msk_input
 
         X_dec = X
         for dec in self.decoders:
-            X_dec = dec([X_dec, X, msk_input], training=training)*msk_input
+            X_dec = dec([X_dec, X, mask], training=training)*msk_input
 
         return X_dec
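
Note on the transformer hunks above: the encoder/decoder layers now receive the 2D padding mask as attention_mask for tf.keras.layers.MultiHeadAttention, and their outputs are additionally multiplied by the expanded 3D mask so padded rows stay exactly zero. A minimal sketch of masked self-attention over padded elements (toy shapes, not from the patch; the pairwise attn_mask construction is one common way to build a mask broadcastable to [batch, query, key]):

    import tensorflow as tf

    B, N, F = 2, 6, 16
    X = tf.random.normal((B, N, F))
    #True for real elements, False for padding
    mask = tf.concat([tf.ones((B, 4), tf.bool), tf.zeros((B, 2), tf.bool)], axis=1)

    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
    #boolean mask broadcastable to [batch, query, key]
    attn_mask = mask[:, tf.newaxis, :] & mask[:, :, tf.newaxis]
    out = mha(query=X, value=X, key=X, attention_mask=attn_mask)
    #zero out padded rows so they cannot leak into later layers
    out = out * tf.cast(mask, tf.float32)[:, :, tf.newaxis]
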
key_dim, "ffn", num_layers=1, activation="elu") - self.tf1 = Transformer(key_dim=key_dim, num_layers=2, name="tf1") - self.tf2 = Transformer(key_dim=key_dim, num_layers=2, name="tf2") + self.tf1 = Transformer(key_dim=key_dim, num_layers=5, name="tf1") + self.tf2 = Transformer(key_dim=key_dim, num_layers=5, name="tf2") output_decoding["schema"] = schema output_decoding["num_output_classes"] = num_output_classes @@ -1051,14 +1051,14 @@ def call(self, inputs, training=False): debugging_data = {} #mask padded elements - msk = X[:, :, 0] != 0 + msk = tf.cast(X[:, :, 0] != 0, tf.float32) msk_input = tf.expand_dims(tf.cast(msk, tf.float32), -1) X_enc = self.enc(X) X_enc = self.ffn(X_enc) - X_enc_1 = self.tf1([X_enc, msk_input], training=training) - X_enc_2 = self.tf2([X_enc, msk_input], training=training) + X_enc_1 = self.tf1([X_enc, msk], training=training) + X_enc_2 = self.tf2([X_enc, msk], training=training) ret = self.output_dec([X, X_enc_1, X_enc_2, msk_input], training=training)