Fix output decoding for PFNetDense (#107)
* better corr plotting

* rename layers, fix output decoding

Former-commit-id: 2e3bfb6
jpata authored May 3, 2022
1 parent d473280 commit b57922d
Showing 2 changed files with 75 additions and 63 deletions.
122 changes: 61 additions & 61 deletions mlpf/tfmodel/model.py
@@ -135,15 +135,11 @@ def __init__(self, num_input_classes):
"""
X: [Nbatch, Nelem, Nfeat] array of all the input detector element feature data
"""
@tf.function
def call(self, X):

log_energy = tf.expand_dims(tf.math.log(X[:, :, 4]+1.0), axis=-1)

#X[:, :, 0] - categorical index of the element type
Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=X.dtype)
#Xpt = tf.expand_dims(tf.math.log1p(X[:, :, 1]), axis=-1)
Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
Xe = tf.expand_dims(tf.math.log(X[:, :, 4] + 1.0), axis=-1)

Xpt_0p5 = tf.math.sqrt(Xpt)
Xpt_2 = tf.math.pow(Xpt, 2)
@@ -154,15 +150,8 @@ def call(self, X):
Xphi1 = tf.expand_dims(tf.sin(X[:, :, 3]), axis=-1)
Xphi2 = tf.expand_dims(tf.cos(X[:, :, 3]), axis=-1)

#Xe = tf.expand_dims(tf.math.log1p(X[:, :, 4]), axis=-1)
Xe = log_energy
Xe_0p5 = tf.math.sqrt(log_energy)
Xe_2 = tf.math.pow(log_energy, 2)

Xe_transverse = log_energy - tf.math.log(Xeta2)

Xlayer = tf.expand_dims(X[:, :, 5]*10.0, axis=-1)
Xdepth = tf.expand_dims(X[:, :, 6]*10.0, axis=-1)
Xe_0p5 = tf.math.sqrt(Xe)
Xe_2 = tf.math.pow(Xe, 2)

Xphi_ecal1 = tf.expand_dims(tf.sin(X[:, :, 10]), axis=-1)
Xphi_ecal2 = tf.expand_dims(tf.cos(X[:, :, 10]), axis=-1)
@@ -176,8 +165,6 @@ def call(self, X):
Xabs_eta,
Xphi1, Xphi2,
Xe, Xe_0p5, Xe_2,
Xe_transverse,
Xlayer, Xdepth,
Xphi_ecal1, Xphi_ecal2,
Xphi_hcal1, Xphi_hcal2,
X], axis=-1
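The hunk above removes the separate log_energy tensor together with the derived transverse-energy, layer, and depth features; pt and energy are now encoded directly as log(x + 1) plus sub- and super-linear variants. A minimal standalone sketch of the pattern (hypothetical feature layout, not the repository code):

    import tensorflow as tf

    # Sketch: log-compress wide-range features so their scale matches the
    # angular features; log(x + 1.0) is equivalent to tf.math.log1p(x).
    X = tf.random.uniform((2, 128, 15), maxval=100.0)  # (batch, elements, features)
    Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
    Xe = tf.expand_dims(tf.math.log(X[:, :, 4] + 1.0), axis=-1)
    feats = tf.concat([Xpt, tf.math.sqrt(Xpt), tf.math.pow(Xpt, 2),
                       Xe, tf.math.sqrt(Xe), tf.math.pow(Xe, 2)], axis=-1)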
@@ -199,7 +186,11 @@ def build(self, input_shape):
self.W_h = self.add_weight(shape=(self.hidden_dim, self.output_dim), name="w_h", initializer="random_normal", trainable=True, regularizer=tf.keras.regularizers.L1(regularizer_weight))
self.theta = self.add_weight(shape=(self.hidden_dim, self.output_dim), name="theta", initializer="random_normal", trainable=True, regularizer=tf.keras.regularizers.L1(regularizer_weight))

#@tf.function
"""
x: [batches, bins, elements, features]
adj: [batches, bins, elements, elements]
msk: [batches, bins, elements]
"""
def call(self, inputs):
x, adj, msk = inputs

@@ -307,7 +298,8 @@ def __init__(self, clip_value_low=0.0, dist_mult=0.1, **kwargs):
returns: (n_batch, n_bins, n_points, n_points, 1) message matrix
"""
def call(self, x_msg_binned, msk, training=False):
dm = tf.expand_dims(pairwise_gaussian_dist(x_msg_binned*msk, x_msg_binned*msk), axis=-1)
x = x_msg_binned*msk
dm = tf.expand_dims(pairwise_gaussian_dist(x, x), axis=-1)
dm = tf.exp(-self.dist_mult*dm)
dm = tf.clip_by_value(dm, self.clip_value_low, 1)
return dm
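The masked input is now computed once and reused for both distance arguments. A rough sketch of the full kernel, under the assumption that pairwise_gaussian_dist returns pairwise squared Euclidean distances (the repository helper may differ):

    import tensorflow as tf

    def pairwise_gaussian_dist(a, b):
        # All-pairs squared Euclidean distance over the last axis (sketch).
        na = tf.reduce_sum(tf.square(a), axis=-1, keepdims=True)
        nb = tf.reduce_sum(tf.square(b), axis=-1, keepdims=True)
        return na - 2.0 * tf.matmul(a, b, transpose_b=True) + tf.linalg.matrix_transpose(nb)

    x = tf.random.normal((2, 4, 16, 8))  # (batch, bins, points, features)
    msk = tf.ones((2, 4, 16, 1))         # 1 = real point, 0 = padding
    xm = x * msk
    dm = tf.expand_dims(pairwise_gaussian_dist(xm, xm), axis=-1)
    dm = tf.clip_by_value(tf.exp(-0.1 * dm), 0.0, 1.0)  # (batch, bins, points, points, 1)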
@@ -349,7 +341,7 @@ def call(self, x_msg_binned, msk, training=False):

dm = pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training)
dm = self.activation(dm)
return dm
return tf.reduce_max(dm, axis=-1, keepdims=True)
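The learnable kernel can emit several output channels per element pair; taking the maximum collapses them into the single adjacency channel the message-passing layers consume. Sketch with hypothetical shapes:

    import tensorflow as tf

    # Sketch: reduce a multi-channel learned kernel to one adjacency channel.
    dm = tf.random.uniform((2, 4, 16, 16, 8))        # (batch, bins, points, points, kernels)
    adj = tf.reduce_max(dm, axis=-1, keepdims=True)  # (batch, bins, points, points, 1)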

def build_kernel_from_conf(kernel_dict, name):
kernel_dict = kernel_dict.copy()
@@ -566,16 +558,17 @@ def call(self, args, training=False):

out_id_logits = self.ffn_id(X_encoded, training=training)
out_id_logits = out_id_logits * tf.cast(msk_input, out_id_logits.dtype)
msk_input_outtype = tf.cast(msk_input, out_id_logits.dtype)

out_id_softmax = tf.nn.softmax(out_id_logits, axis=-1)
out_id_hard_softmax = tf.stop_gradient(tf.nn.softmax(100*out_id_logits, axis=-1))

out_charge = self.ffn_charge(X_encoded, training=training)
out_charge = out_charge * tf.cast(msk_input, out_charge.dtype)
out_charge = out_charge * msk_input_outtype

orig_eta = tf.cast(X_input[:, :, 2:3], out_id_logits.dtype)

#FIXME: better schema propagation
#FIXME: better schema propagation between hep_tfds
#skip connection from raw input values
if self.schema == "cms":
orig_sin_phi = tf.cast(tf.math.sin(X_input[:, :, 3:4])*msk_input, out_id_logits.dtype)
@@ -590,9 +583,9 @@ def call(self, args, training=False):
X_encoded = tf.concat([X_encoded, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)

pred_eta_corr = self.ffn_eta(X_encoded, training=training)
pred_eta_corr = pred_eta_corr*tf.cast(msk_input, pred_eta_corr.dtype)
pred_eta_corr = pred_eta_corr*msk_input_outtype
pred_phi_corr = self.ffn_phi(X_encoded, training=training)
pred_phi_corr = pred_phi_corr*tf.cast(msk_input, pred_phi_corr.dtype)
pred_phi_corr = pred_phi_corr*msk_input_outtype

if self.eta_skip_gate:
eta_gate = tf.keras.activations.sigmoid(pred_eta_corr[:, :, 0:1])
@@ -614,7 +607,7 @@ def call(self, args, training=False):
X_encoded_energy = tf.concat([X_encoded_energy, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)

pred_energy_corr = self.ffn_energy(X_encoded_energy, training=training)
pred_energy_corr = pred_energy_corr*tf.cast(msk_input, pred_energy_corr.dtype)
pred_energy_corr = pred_energy_corr*msk_input_outtype

#In case of a multimodal prediction, weight the per-class energy predictions by the approximately one-hot vector
if self.energy_multimodal:
@@ -626,7 +619,7 @@ def call(self, args, training=False):
orig_pt = tf.stop_gradient(pred_energy/tf.math.cosh(tf.clip_by_value(pred_eta, -8, 8)))

pred_pt_corr = self.ffn_pt(X_encoded_energy, training=training)
pred_pt_corr = pred_pt_corr*tf.cast(msk_input, pred_pt_corr.dtype)
pred_pt_corr = pred_pt_corr*msk_input_outtype

if self.pt_skip_gate:
pt_gate = tf.keras.activations.sigmoid(pred_pt_corr[:, :, 0:1])
@@ -647,16 +640,16 @@ def call(self, args, training=False):

ret = {
"cls": out_id_softmax,
"charge": out_charge*msk_input,
"pt": pred_pt*msk_input,
"eta": pred_eta*msk_input,
"sin_phi": pred_sin_phi*msk_input,
"cos_phi": pred_cos_phi*msk_input,
"energy": pred_energy*msk_input,
"charge": out_charge*msk_input_outtype,
"pt": pred_pt*msk_input_outtype,
"eta": pred_eta*msk_input_outtype,
"sin_phi": pred_sin_phi*msk_input_outtype,
"cos_phi": pred_cos_phi*msk_input_outtype,
"energy": pred_energy*msk_input_outtype,

#per-event sum of energy and pt
"sum_energy": tf.reduce_sum(pred_energy*msk_input*msk_output, axis=-2),
"sum_pt": tf.reduce_sum(pred_pt*msk_input*msk_output, axis=-2),
"sum_energy": tf.reduce_sum(pred_energy*msk_input_outtype*msk_output, axis=-2),
"sum_pt": tf.reduce_sum(pred_pt*msk_input_outtype*msk_output, axis=-2),
}

return ret
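Throughout this hunk, the repeated tf.cast(msk_input, ...) calls are replaced by msk_input_outtype, cast once from the ID logits' dtype; under mixed precision the mask and the head outputs need not share a dtype, so a single cast both deduplicates code and keeps every head consistent. Minimal sketch with hypothetical tensors:

    import tensorflow as tf

    # Sketch: cast the padding mask to the output dtype once, reuse per head.
    msk_input = tf.constant([[[1], [1], [0]]], dtype=tf.float16)  # 0 = padded element
    out_id_logits = tf.random.normal((1, 3, 8))                   # float32 head output
    msk_input_outtype = tf.cast(msk_input, out_id_logits.dtype)

    pred_pt = tf.random.normal((1, 3, 1)) * msk_input_outtype
    pred_energy = tf.random.normal((1, 3, 1)) * msk_input_outtype
    sum_pt = tf.reduce_sum(pred_pt, axis=-2)  # per-event sum excludes padding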
@@ -695,7 +688,8 @@ def __init__(self, *args, **kwargs):
self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation", "linear"))

if self.do_layernorm:
self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm")
self.layernorm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm1")
self.layernorm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm2")

#self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
self.ffn_dist = point_wise_feed_forward_network(
@@ -730,10 +724,12 @@ def __init__(self, *args, **kwargs):
def call(self, x, msk, training=False):

if self.do_layernorm:
x = self.layernorm(x, training=training)
x = self.layernorm1(x, training=training)

#compute node features for graph building
x_dist = self.dist_activation(self.ffn_dist(x, training=training))
if self.do_layernorm:
x_dist = self.layernorm2(x_dist, training=training)

#compute the element-to-element messages / distance matrix / graph structure
if self.do_lsh:
@@ -752,9 +748,7 @@ def call(self, x, msk, training=False):

#run the node update with message passing
for msg in self.message_passing_layers:

x = msg((x, dm, msk_f))

if self.dropout_layer:
x = self.dropout_layer(x, training=training)
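With the split into layernorm1 and layernorm2, the node-update path and the graph-building features from ffn_dist are normalized independently instead of sharing one layer. Standalone sketch:

    import tensorflow as tf

    # Sketch: separate LayerNorms for the node path and the graph-building path.
    ln1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6)
    ln2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6)
    ffn_dist = tf.keras.layers.Dense(16, activation="elu")

    x = tf.random.normal((2, 64, 32))  # (batch, elements, features)
    x = ln1(x)                         # normalized node features
    x_dist = ln2(ffn_dist(x))          # normalized distance/graph features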

@@ -814,8 +808,8 @@ def __init__(self,
elif input_encoding == "default":
self.enc = InputEncoding(num_input_classes)

self.cg = [CombinedGraphLayer(name="cg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
self.cg_energy = [CombinedGraphLayer(name="cg_energy_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]
self.cg_id = [CombinedGraphLayer(name="cg_id_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
self.cg_reg = [CombinedGraphLayer(name="cg_reg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]

output_decoding["schema"] = schema
output_decoding["num_output_classes"] = num_output_classes
@@ -832,59 +826,59 @@ def call(self, inputs, training=False):
msk = X[:, :, 0] != 0
msk_input = tf.expand_dims(tf.cast(msk, X_enc.dtype), -1)

encs = []
encs_id = []
if self.skip_connection:
encs.append(X_enc)
encs_id.append(X_enc)

X_enc_cg = X_enc
if self.do_node_encoding:
X_enc_ffn = self.activation(self.node_encoding(X_enc_cg, training=training))
X_enc_cg = X_enc_ffn

for cg in self.cg:
for cg in self.cg_id:
enc_all = cg(X_enc_cg, msk, training=training)

if self.node_update_mode == "additive":
X_enc_cg += enc_all["enc"]
elif self.node_update_mode == "concat":
X_enc_cg = enc_all["enc"]
encs.append(X_enc_cg)
encs_id.append(X_enc_cg)

if self.debug:
debugging_data[cg.name] = enc_all

if self.node_update_mode == "concat":
dec_output = tf.concat(encs, axis=-1)*msk_input
dec_output = tf.concat(encs_id, axis=-1)*msk_input
elif self.node_update_mode == "additive":
dec_output = X_enc_cg

X_enc_cg = X_enc
if self.do_node_encoding:
X_enc_cg = X_enc_ffn

encs_energy = []
for cg in self.cg_energy:
encs_reg = []
for cg in self.cg_reg:
enc_all = cg(X_enc_cg, msk, training=training)
if self.node_update_mode == "additive":
X_enc_cg += enc_all["enc"]
elif self.node_update_mode == "concat":
X_enc_cg = enc_all["enc"]
encs_energy.append(X_enc_cg)
encs_reg.append(X_enc_cg)

if self.debug:
debugging_data[cg.name] = enc_all
encs_energy.append(X_enc_cg)
encs_reg.append(X_enc_cg)

if self.node_update_mode == "concat":
dec_output_energy = tf.concat(encs_energy, axis=-1)*msk_input
dec_output_energy = tf.concat(encs_reg, axis=-1)*msk_input
elif self.node_update_mode == "additive":
dec_output_energy = X_enc_cg

if self.debug:
debugging_data["dec_output"] = dec_output
debugging_data["dec_output_energy"] = dec_output_energy

ret = self.output_dec([X_enc, dec_output, dec_output_energy, msk_input], training=training)
ret = self.output_dec([X, dec_output, dec_output_energy, msk_input], training=training)

if self.debug:
for k in debugging_data.keys():
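The last changed line above appears to be the fix named in the commit title: the output decoder now receives the raw features X instead of the encoded X_enc, because OutputDecoding slices eta and phi out of its first argument by column index for the skip connection from raw input values. Illustrative sketch (hypothetical column layout):

    import tensorflow as tf

    # Sketch: column slicing that assumes the raw input schema.
    def raw_skip(X_raw):
        orig_eta = X_raw[:, :, 2:3]  # eta column in the raw layout
        orig_phi = X_raw[:, :, 3:4]  # phi column in the raw layout
        return orig_eta, orig_phi

    X = tf.random.normal((2, 128, 15))       # raw detector features
    X_enc = tf.random.normal((2, 128, 256))  # encoded features: columns 2:4 are not eta/phi
    eta, phi = raw_skip(X)                   # correct; raw_skip(X_enc) would decode noise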
@@ -955,10 +949,14 @@ def __init__(self, *args, **kwargs):

def call(self, args, training=False):
X, mask = args
attn_output = self.attn(query=X, value=X, key=X, training=training)
msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)

attn_output = self.attn(query=X, value=X, key=X, training=training, attention_mask=mask)*msk_input
out1 = self.norm1(X + attn_output)

ffn_output = self.ffn(out1)
out2 = self.norm2(out1 + ffn_output)

return out2
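KernelEncoder now forwards the element mask into the attention call and zeroes the padded rows of its output. A standalone sketch in which the 2D element mask is expanded to the (query, key) shape Keras attention masks broadcast to (the repository may pass the mask in a different form):

    import tensorflow as tf

    # Sketch: masked self-attention over a padded element sequence.
    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
    x = tf.random.normal((2, 8, 16))
    mask = tf.constant([[True] * 6 + [False] * 2, [True] * 8])    # (batch, elements)
    attn_mask = mask[:, tf.newaxis, :] & mask[:, :, tf.newaxis]   # (batch, query, key)
    out = mha(query=x, value=x, key=x, attention_mask=attn_mask)
    out = out * tf.cast(mask, out.dtype)[..., tf.newaxis]         # zero padded rows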

class KernelDecoder(tf.keras.layers.Layer):
@@ -979,11 +977,12 @@ def __init__(self, *args, **kwargs):

def call(self, args, training=False):
X, enc_output, mask = args
msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)

attn1 = self.attn1(query=X, value=X, key=X, training=training)
attn1 = self.attn1(query=X, value=X, key=X, training=training, attention_mask=mask)*msk_input
out1 = self.norm1(attn1 + X, training=training)

attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training)
attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training, attention_mask=mask)*msk_input
out2 = self.norm2(attn2 + out1)

ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model)
@@ -1007,14 +1006,15 @@ def __init__(self, *args, **kwargs):
super(Transformer, self).__init__(*args, **kwargs)

def call(self, inputs, training=False):
X, msk_input = inputs
X, mask = inputs
msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)

for enc in self.encoders:
X = enc([X, msk_input], training=training)*msk_input
X = enc([X, mask], training=training)*msk_input

X_dec = X
for dec in self.decoders:
X_dec = dec([X_dec, X, msk_input], training=training)*msk_input
X_dec = dec([X_dec, X, mask], training=training)*msk_input

return X_dec
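The Transformer wrapper now accepts the bare mask and rebuilds msk_input itself, re-applying it after every encoder and decoder block; residual additions and biases inside a block would otherwise let nonzero values creep back into padded positions. Sketch with stand-in blocks:

    import tensorflow as tf

    # Sketch: re-mask after every block so padding stays exactly zero.
    blocks = [tf.keras.layers.Dense(16), tf.keras.layers.Dense(16)]  # stand-ins
    X = tf.random.normal((2, 8, 16))
    mask = tf.constant([[1.0] * 6 + [0.0] * 2, [1.0] * 8])
    msk_input = tf.expand_dims(mask, -1)
    for block in blocks:
        X = block(X) * msk_input  # bias terms on padded rows get re-zeroed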

@@ -1039,8 +1039,8 @@ def __init__(self,
key_dim = 64
self.ffn = point_wise_feed_forward_network(key_dim, key_dim, "ffn", num_layers=1, activation="elu")

self.tf1 = Transformer(key_dim=key_dim, num_layers=2, name="tf1")
self.tf2 = Transformer(key_dim=key_dim, num_layers=2, name="tf2")
self.tf1 = Transformer(key_dim=key_dim, num_layers=5, name="tf1")
self.tf2 = Transformer(key_dim=key_dim, num_layers=5, name="tf2")

output_decoding["schema"] = schema
output_decoding["num_output_classes"] = num_output_classes
@@ -1051,14 +1051,14 @@ def call(self, inputs, training=False):
debugging_data = {}

#mask padded elements
msk = X[:, :, 0] != 0
msk = tf.cast(X[:, :, 0] != 0, tf.float32)
msk_input = tf.expand_dims(tf.cast(msk, tf.float32), -1)

X_enc = self.enc(X)
X_enc = self.ffn(X_enc)

X_enc_1 = self.tf1([X_enc, msk_input], training=training)
X_enc_2 = self.tf2([X_enc, msk_input], training=training)
X_enc_1 = self.tf1([X_enc, msk], training=training)
X_enc_2 = self.tf2([X_enc, msk], training=training)

ret = self.output_dec([X, X_enc_1, X_enc_2, msk_input], training=training)

16 changes: 14 additions & 2 deletions mlpf/tfmodel/model_setup.py
@@ -111,11 +111,11 @@ def __init__(self, outpath, dataset, dataset_info, plot_freq=1, comet_experiment
}

self.reg_bins = {
"pt": np.linspace(-100, 1000, 100),
"pt": np.linspace(-100, 200, 100),
"eta": np.linspace(-6, 6, 100),
"sin_phi": np.linspace(-1,1,100),
"cos_phi": np.linspace(-1,1,100),
"energy": np.linspace(-100, 5000, 100),
"energy": np.linspace(-100, 1000, 100),
}
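The pt and energy ranges shrink here because the correlation plots below switch those variables to log-spaced bins; the linear bins remain for the bounded variables. Sketch of the binning choice:

    import numpy as np

    # Sketch: linear bins for bounded variables, log-spaced bins for
    # variables spanning several orders of magnitude.
    bins_eta = np.linspace(-6, 6, 100)
    bins_phi = np.linspace(-1, 1, 100)
    bins_pt = np.logspace(-2, 3, 100)  # 1e-2 .. 1e3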

def plot_cm(self, epoch, outpath, ypred_id, msk):
@@ -268,9 +268,21 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable):
#save scatterplot of raw values
plt.figure(figsize=(6,5))
bins = self.reg_bins[reg_variable]

if bins is None:
bins = 100

if reg_variable == "pt" or reg_variable == "energy":
bins = np.logspace(-2,3,100)
vals_true = np.log10(vals_true)
vals_pred = np.log10(vals_pred)
vals_true[np.isnan(vals_true)] = 0.0
vals_pred[np.isnan(vals_pred)] = 0.0

plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1, cmap="Blues", norm=matplotlib.colors.LogNorm())
if reg_variable == "pt" or reg_variable == "energy":
plt.xscale("log")
plt.yscale("log")
plt.colorbar()

if len(vals_true) > 0:
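A minimal standalone version of the new pt/energy correlation plot: log-spaced bins, a log-normalized color scale, and log axes, on synthetic data (not the repository's plotting pipeline):

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    # Sketch: truth-vs-prediction 2D histogram on log axes.
    rng = np.random.default_rng(0)
    vals_true = rng.lognormal(mean=2.0, sigma=1.0, size=10000)
    vals_pred = vals_true * rng.normal(loc=1.0, scale=0.1, size=10000)

    bins = np.logspace(-2, 3, 100)
    plt.figure(figsize=(6, 5))
    plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1,
               cmap="Blues", norm=matplotlib.colors.LogNorm())
    plt.xscale("log")
    plt.yscale("log")
    plt.colorbar()
    plt.xlabel("true")
    plt.ylabel("predicted")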
