Fix output decoding for PFNetDense #107

Merged: 2 commits, May 3, 2022
mlpf/tfmodel/model.py: 122 changes (61 additions, 61 deletions)
@@ -135,15 +135,11 @@ def __init__(self, num_input_classes):
"""
X: [Nbatch, Nelem, Nfeat] array of all the input detector element feature data
"""
@tf.function
def call(self, X):

log_energy = tf.expand_dims(tf.math.log(X[:, :, 4]+1.0), axis=-1)

#X[:, :, 0] - categorical index of the element type
Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=X.dtype)
#Xpt = tf.expand_dims(tf.math.log1p(X[:, :, 1]), axis=-1)
Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
Xe = tf.expand_dims(tf.math.log(X[:, :, 4] + 1.0), axis=-1)

Xpt_0p5 = tf.math.sqrt(Xpt)
Xpt_2 = tf.math.pow(Xpt, 2)
@@ -154,15 +150,8 @@ def call(self, X):
Xphi1 = tf.expand_dims(tf.sin(X[:, :, 3]), axis=-1)
Xphi2 = tf.expand_dims(tf.cos(X[:, :, 3]), axis=-1)

#Xe = tf.expand_dims(tf.math.log1p(X[:, :, 4]), axis=-1)
Xe = log_energy
Xe_0p5 = tf.math.sqrt(log_energy)
Xe_2 = tf.math.pow(log_energy, 2)

Xe_transverse = log_energy - tf.math.log(Xeta2)

Xlayer = tf.expand_dims(X[:, :, 5]*10.0, axis=-1)
Xdepth = tf.expand_dims(X[:, :, 6]*10.0, axis=-1)
Xe_0p5 = tf.math.sqrt(Xe)
Xe_2 = tf.math.pow(Xe, 2)

Xphi_ecal1 = tf.expand_dims(tf.sin(X[:, :, 10]), axis=-1)
Xphi_ecal2 = tf.expand_dims(tf.cos(X[:, :, 10]), axis=-1)
@@ -176,8 +165,6 @@ def call(self, X):
Xabs_eta,
Xphi1, Xphi2,
Xe, Xe_0p5, Xe_2,
Xe_transverse,
Xlayer, Xdepth,
Xphi_ecal1, Xphi_ecal2,
Xphi_hcal1, Xphi_hcal2,
X], axis=-1
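
The net effect of this hunk is a slimmer CMS input encoding: the `log_energy` intermediate and the `Xe_transverse`, `Xlayer`, and `Xdepth` features are dropped, and pt/energy are log-compressed directly. A minimal standalone sketch of the resulting transform, assuming the column layout used above (0 = element type, 1 = pt, 4 = energy) and an illustrative `num_input_classes`; this is not the repo's API, just the pattern:

```python
import tensorflow as tf

def encode_features_sketch(X, num_input_classes=9):
    """Illustrative reimplementation of the simplified encoding."""
    # categorical element type -> one-hot
    Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), num_input_classes), dtype=X.dtype)
    # log-compress pt and energy, then expose sqrt/square powers of each
    Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
    Xe = tf.expand_dims(tf.math.log(X[:, :, 4] + 1.0), axis=-1)
    feats = [Xid,
             Xpt, tf.math.sqrt(Xpt), tf.math.pow(Xpt, 2),
             Xe, tf.math.sqrt(Xe), tf.math.pow(Xe, 2)]
    return tf.concat(feats, axis=-1)

X = tf.random.uniform((2, 128, 15))  # [Nbatch, Nelem, Nfeat], sizes hypothetical
print(encode_features_sketch(X).shape)
```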
@@ -199,7 +186,11 @@ def build(self, input_shape):
self.W_h = self.add_weight(shape=(self.hidden_dim, self.output_dim), name="w_h", initializer="random_normal", trainable=True, regularizer=tf.keras.regularizers.L1(regularizer_weight))
self.theta = self.add_weight(shape=(self.hidden_dim, self.output_dim), name="theta", initializer="random_normal", trainable=True, regularizer=tf.keras.regularizers.L1(regularizer_weight))

#@tf.function
"""
x: [batches, bins, elements, features]
adj: [batches, bins, elements, elements]
msk: [batches, bins, elements]
"""
def call(self, inputs):
x, adj, msk = inputs
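
Given the shape docstring added above and the `W_t`/`W_h`/`theta` weights from `build`, here is a hypothetical gated graph-convolution update consistent with those shapes. The actual arithmetic of this `call` lies outside the hunk, so treat every line as an assumption, not the repo's implementation:

```python
import tensorflow as tf

def ghconv_call_sketch(x, adj, msk, W_t, W_h, theta):
    # x:   [batches, bins, elements, features]
    # adj: [batches, bins, elements, elements]
    # msk: [batches, bins, elements]
    msk_f = tf.expand_dims(tf.cast(msk, x.dtype), axis=-1)
    f_hom = tf.linalg.matmul(adj, tf.linalg.matmul(x * msk_f, theta))  # neighbor aggregation (assumed)
    f_het = tf.linalg.matmul(x * msk_f, W_h)                           # per-node transform (assumed)
    gate = tf.sigmoid(tf.linalg.matmul(x, W_t))                        # learned mixing gate (assumed)
    return gate * f_hom + (1.0 - gate) * f_het
```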

@@ -307,7 +298,8 @@ def __init__(self, clip_value_low=0.0, dist_mult=0.1, **kwargs):
returns: (n_batch, n_bins, n_points, n_points, 1) message matrix
"""
def call(self, x_msg_binned, msk, training=False):
dm = tf.expand_dims(pairwise_gaussian_dist(x_msg_binned*msk, x_msg_binned*msk), axis=-1)
x = x_msg_binned*msk
dm = tf.expand_dims(pairwise_gaussian_dist(x, x), axis=-1)
dm = tf.exp(-self.dist_mult*dm)
dm = tf.clip_by_value(dm, self.clip_value_low, 1)
return dm
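
The refactor above only hoists the masked product into a local `x`; the distance matrix is unchanged. `pairwise_gaussian_dist` is defined elsewhere in `model.py`; a standard batched-Euclidean sketch of what the call assumes:

```python
import tensorflow as tf

def pairwise_gaussian_dist_sketch(A, B):
    """Assumed behavior: Euclidean distance between all row pairs of A and B."""
    na = tf.reduce_sum(tf.square(A), axis=-1)[..., :, None]  # |a|^2 as a column
    nb = tf.reduce_sum(tf.square(B), axis=-1)[..., None, :]  # |b|^2 as a row
    # |a - b|^2 = |a|^2 - 2 a.b + |b|^2; clip to avoid tiny negatives before sqrt
    d2 = tf.maximum(na - 2.0 * tf.linalg.matmul(A, B, transpose_b=True) + nb, 1e-6)
    return tf.math.sqrt(d2)
```

With `dm = exp(-dist_mult * dist)` clipped at `clip_value_low`, nearby elements exchange messages with weight near 1 and distant pairs are suppressed toward 0.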
@@ -349,7 +341,7 @@ def call(self, x_msg_binned, msk, training=False):

dm = pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training)
dm = self.activation(dm)
return dm
return tf.reduce_max(dm, axis=-1, keepdims=True)
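
The learnable pair kernel emits one value per output channel, so `dm` carries a trailing channel axis; taking an elementwise max collapses it so the learnable kernel matches the single-channel `(n_batch, n_bins, n_points, n_points, 1)` contract of the Gaussian kernel above. An illustrative shape check, with hypothetical sizes:

```python
import tensorflow as tf

dm = tf.random.uniform((2, 4, 16, 16, 8))           # per-channel kernel outputs (hypothetical)
dm_max = tf.reduce_max(dm, axis=-1, keepdims=True)  # (2, 4, 16, 16, 1): one strength per element pair
print(dm_max.shape)
```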

def build_kernel_from_conf(kernel_dict, name):
kernel_dict = kernel_dict.copy()
@@ -566,16 +558,17 @@ def call(self, args, training=False):

out_id_logits = self.ffn_id(X_encoded, training=training)
out_id_logits = out_id_logits * tf.cast(msk_input, out_id_logits.dtype)
msk_input_outtype = tf.cast(msk_input, out_id_logits.dtype)

out_id_softmax = tf.nn.softmax(out_id_logits, axis=-1)
out_id_hard_softmax = tf.stop_gradient(tf.nn.softmax(100*out_id_logits, axis=-1))
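
The temperature-100 softmax approximates a one-hot classification while `stop_gradient` keeps the sharpening out of backprop; a worked example:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5]])
print(tf.nn.softmax(logits).numpy())        # ~[0.63, 0.23, 0.14]: soft class probabilities
print(tf.nn.softmax(100 * logits).numpy())  # ~[1.00, 0.00, 0.00]: approximately one-hot
```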

out_charge = self.ffn_charge(X_encoded, training=training)
out_charge = out_charge * tf.cast(msk_input, out_charge.dtype)
out_charge = out_charge * msk_input_outtype

orig_eta = tf.cast(X_input[:, :, 2:3], out_id_logits.dtype)

#FIXME: better schema propagation
#FIXME: better schema propagation between hep_tfds
#skip connection from raw input values
if self.schema == "cms":
orig_sin_phi = tf.cast(tf.math.sin(X_input[:, :, 3:4])*msk_input, out_id_logits.dtype)
@@ -590,9 +583,9 @@ def call(self, args, training=False):
X_encoded = tf.concat([X_encoded, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)

pred_eta_corr = self.ffn_eta(X_encoded, training=training)
pred_eta_corr = pred_eta_corr*tf.cast(msk_input, pred_eta_corr.dtype)
pred_eta_corr = pred_eta_corr*msk_input_outtype
pred_phi_corr = self.ffn_phi(X_encoded, training=training)
pred_phi_corr = pred_phi_corr*tf.cast(msk_input, pred_phi_corr.dtype)
pred_phi_corr = pred_phi_corr*msk_input_outtype

if self.eta_skip_gate:
eta_gate = tf.keras.activations.sigmoid(pred_eta_corr[:, :, 0:1])
@@ -614,7 +607,7 @@ def call(self, args, training=False):
X_encoded_energy = tf.concat([X_encoded_energy, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)

pred_energy_corr = self.ffn_energy(X_encoded_energy, training=training)
pred_energy_corr = pred_energy_corr*tf.cast(msk_input, pred_energy_corr.dtype)
pred_energy_corr = pred_energy_corr*msk_input_outtype

#In case of a multimodal prediction, weight the per-class energy predictions by the approximately one-hot vector
if self.energy_multimodal:
@@ -626,7 +619,7 @@ def call(self, args, training=False):
orig_pt = tf.stop_gradient(pred_energy/tf.math.cosh(tf.clip_by_value(pred_eta, -8, 8)))

pred_pt_corr = self.ffn_pt(X_encoded_energy, training=training)
pred_pt_corr = pred_pt_corr*tf.cast(msk_input, pred_pt_corr.dtype)
pred_pt_corr = pred_pt_corr*msk_input_outtype

if self.pt_skip_gate:
pt_gate = tf.keras.activations.sigmoid(pred_pt_corr[:, :, 0:1])
@@ -647,16 +640,16 @@ def call(self, args, training=False):

ret = {
"cls": out_id_softmax,
"charge": out_charge*msk_input,
"pt": pred_pt*msk_input,
"eta": pred_eta*msk_input,
"sin_phi": pred_sin_phi*msk_input,
"cos_phi": pred_cos_phi*msk_input,
"energy": pred_energy*msk_input,
"charge": out_charge*msk_input_outtype,
"pt": pred_pt*msk_input_outtype,
"eta": pred_eta*msk_input_outtype,
"sin_phi": pred_sin_phi*msk_input_outtype,
"cos_phi": pred_cos_phi*msk_input_outtype,
"energy": pred_energy*msk_input_outtype,

#per-event sum of energy and pt
"sum_energy": tf.reduce_sum(pred_energy*msk_input*msk_output, axis=-2),
"sum_pt": tf.reduce_sum(pred_pt*msk_input*msk_output, axis=-2),
"sum_energy": tf.reduce_sum(pred_energy*msk_input_outtype*msk_output, axis=-2),
"sum_pt": tf.reduce_sum(pred_pt*msk_input_outtype*msk_output, axis=-2),
}

return ret
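
The recurring edit in this hunk is `msk_input_outtype`: the padding mask is cast to the network output dtype once and reused for every output head. A minimal sketch of why the cast is needed, assuming a mixed-precision policy (the policy call is standard Keras, not something this PR sets):

```python
import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)
X = tf.random.uniform((1, 8, 16))
msk_input = tf.cast(tf.random.uniform((1, 8, 1)) > 0.5, tf.float32)

out = dense(X)                               # float16 activations under this policy
msk_outtype = tf.cast(msk_input, out.dtype)  # cast once...
masked = out * msk_outtype                   # ...reuse for charge, pt, eta, phi, energy heads
```

Multiplying the float16 head outputs by the float32 mask directly would raise a dtype mismatch, which is why each head previously needed its own `tf.cast`.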
@@ -695,7 +688,8 @@ def __init__(self, *args, **kwargs):
self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation", "linear"))

if self.do_layernorm:
self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm")
self.layernorm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm1")
self.layernorm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm2")

#self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
self.ffn_dist = point_wise_feed_forward_network(
@@ -730,10 +724,12 @@ def __init__(self, *args, **kwargs):
def call(self, x, msk, training=False):

if self.do_layernorm:
x = self.layernorm(x, training=training)
x = self.layernorm1(x, training=training)

#compute node features for graph building
x_dist = self.dist_activation(self.ffn_dist(x, training=training))
if self.do_layernorm:
x_dist = self.layernorm2(x_dist, training=training)

#compute the element-to-element messages / distance matrix / graph structure
if self.do_lsh:
@@ -752,9 +748,7 @@ def call(self, x, msk, training=False):

#run the node update with message passing
for msg in self.message_passing_layers:

x = msg((x, dm, msk_f))

if self.dropout_layer:
x = self.dropout_layer(x, training=training)
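
With this hunk `CombinedGraphLayer` normalizes both the incoming node features and the graph-building features from `ffn_dist`, using two separate `LayerNormalization` instances. Separate instances matter because each learns its own gain and bias; an illustrative sketch with hypothetical sizes:

```python
import tensorflow as tf

ln1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name="cg_demo_layernorm1")
ln2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name="cg_demo_layernorm2")

x = tf.random.normal((2, 256, 128))      # node features
x_dist = tf.random.normal((2, 256, 16))  # graph-building features from ffn_dist
y, y_dist = ln1(x), ln2(x_dist)          # independent trainable scale/offset per feature space
```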

@@ -814,8 +808,8 @@ def __init__(self,
elif input_encoding == "default":
self.enc = InputEncoding(num_input_classes)

self.cg = [CombinedGraphLayer(name="cg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
self.cg_energy = [CombinedGraphLayer(name="cg_energy_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]
self.cg_id = [CombinedGraphLayer(name="cg_id_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_common)]
self.cg_reg = [CombinedGraphLayer(name="cg_reg_{}".format(i), **combined_graph_layer) for i in range(num_graph_layers_energy)]

output_decoding["schema"] = schema
output_decoding["num_output_classes"] = num_output_classes
@@ -832,59 +826,59 @@ def call(self, inputs, training=False):
msk = X[:, :, 0] != 0
msk_input = tf.expand_dims(tf.cast(msk, X_enc.dtype), -1)

encs = []
encs_id = []
if self.skip_connection:
encs.append(X_enc)
encs_id.append(X_enc)

X_enc_cg = X_enc
if self.do_node_encoding:
X_enc_ffn = self.activation(self.node_encoding(X_enc_cg, training=training))
X_enc_cg = X_enc_ffn

for cg in self.cg:
for cg in self.cg_id:
enc_all = cg(X_enc_cg, msk, training=training)

if self.node_update_mode == "additive":
X_enc_cg += enc_all["enc"]
elif self.node_update_mode == "concat":
X_enc_cg = enc_all["enc"]
encs.append(X_enc_cg)
encs_id.append(X_enc_cg)

if self.debug:
debugging_data[cg.name] = enc_all

if self.node_update_mode == "concat":
dec_output = tf.concat(encs, axis=-1)*msk_input
dec_output = tf.concat(encs_id, axis=-1)*msk_input
elif self.node_update_mode == "additive":
dec_output = X_enc_cg

X_enc_cg = X_enc
if self.do_node_encoding:
X_enc_cg = X_enc_ffn

encs_energy = []
for cg in self.cg_energy:
encs_reg = []
for cg in self.cg_reg:
enc_all = cg(X_enc_cg, msk, training=training)
if self.node_update_mode == "additive":
X_enc_cg += enc_all["enc"]
elif self.node_update_mode == "concat":
X_enc_cg = enc_all["enc"]
encs_energy.append(X_enc_cg)
encs_reg.append(X_enc_cg)

if self.debug:
debugging_data[cg.name] = enc_all
encs_energy.append(X_enc_cg)
encs_reg.append(X_enc_cg)

if self.node_update_mode == "concat":
dec_output_energy = tf.concat(encs_energy, axis=-1)*msk_input
dec_output_energy = tf.concat(encs_reg, axis=-1)*msk_input
elif self.node_update_mode == "additive":
dec_output_energy = X_enc_cg

if self.debug:
debugging_data["dec_output"] = dec_output
debugging_data["dec_output_energy"] = dec_output_energy

ret = self.output_dec([X_enc, dec_output, dec_output_energy, msk_input], training=training)
ret = self.output_dec([X, dec_output, dec_output_energy, msk_input], training=training)
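
This line is the fix the PR title refers to: `OutputDecoding` expects raw detector features in its first slot (see the `orig_eta`/`orig_sin_phi` skip connections earlier in the diff), but previously received the encoded features:

```python
# Before: the decoder's raw-input slot got encoded features, so slices like
# X_input[:, :, 2:3] (eta) did not hold physical eta.
# ret = self.output_dec([X_enc, dec_output, dec_output_energy, msk_input], training=training)

# After: raw X goes in, so the eta/phi/energy skip connections read real detector values.
ret = self.output_dec([X, dec_output, dec_output_energy, msk_input], training=training)
```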

if self.debug:
for k in debugging_data.keys():
@@ -955,10 +949,14 @@ def __init__(self, *args, **kwargs):

def call(self, args, training=False):
X, mask = args
attn_output = self.attn(query=X, value=X, key=X, training=training)
msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)

attn_output = self.attn(query=X, value=X, key=X, training=training, attention_mask=mask)*msk_input
out1 = self.norm1(X + attn_output)

ffn_output = self.ffn(out1)
out2 = self.norm2(out1 + ffn_output)

return out2
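
The encoder now both passes an `attention_mask` into `MultiHeadAttention` and re-masks the output. A standalone sketch of the pattern using Keras's documented `(batch, target, source)` mask shape, built here from the per-element mask via an outer product; shapes and sizes are illustrative:

```python
import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)

X = tf.random.normal((1, 6, 32))
msk = tf.constant([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]])  # last two elements are padding
attn_mask = msk[:, :, None] * msk[:, None, :] > 0    # (batch, target, source) boolean
msk_input = tf.expand_dims(msk, -1)                  # (1, 6, 1)

attn = mha(query=X, value=X, key=X, attention_mask=attn_mask)
attn = attn * msk_input                              # zero out padded rows of the output
```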

class KernelDecoder(tf.keras.layers.Layer):
@@ -979,11 +977,12 @@ def __init__(self, *args, **kwargs):

def call(self, args, training=False):
X, enc_output, mask = args
msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)

attn1 = self.attn1(query=X, value=X, key=X, training=training)
attn1 = self.attn1(query=X, value=X, key=X, training=training, attention_mask=mask)*msk_input
out1 = self.norm1(attn1 + X, training=training)

attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training)
attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training, attention_mask=mask)*msk_input
out2 = self.norm2(attn2 + out1)

ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model)
@@ -1007,14 +1006,15 @@ def __init__(self, *args, **kwargs):
super(Transformer, self).__init__(*args, **kwargs)

def call(self, inputs, training=False):
X, msk_input = inputs
X, mask = inputs
msk_input = tf.expand_dims(tf.cast(mask, tf.float32), -1)

for enc in self.encoders:
X = enc([X, msk_input], training=training)*msk_input
X = enc([X, mask], training=training)*msk_input

X_dec = X
for dec in self.decoders:
X_dec = dec([X_dec, X, msk_input], training=training)*msk_input
X_dec = dec([X_dec, X, mask], training=training)*msk_input

return X_dec
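
After this refactor the block takes the 2D element mask and derives the multiplicative `msk_input` itself, mirroring how `PFNetTransformer` calls it below. An illustrative usage with hypothetical sizes, assuming the `Transformer(key_dim=..., num_layers=..., name=...)` constructor seen in the next hunk:

```python
import tensorflow as tf

tf_block = Transformer(key_dim=64, num_layers=2, name="tf_demo")
X = tf.random.normal((2, 16, 64))
msk = tf.cast(tf.random.uniform((2, 16)) > 0.3, tf.float32)  # 1 = real element, 0 = padding
out = tf_block([X, msk], training=False)                     # padded rows zeroed by msk_input
```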

@@ -1039,8 +1039,8 @@ def __init__(self,
key_dim = 64
self.ffn = point_wise_feed_forward_network(key_dim, key_dim, "ffn", num_layers=1, activation="elu")

self.tf1 = Transformer(key_dim=key_dim, num_layers=2, name="tf1")
self.tf2 = Transformer(key_dim=key_dim, num_layers=2, name="tf2")
self.tf1 = Transformer(key_dim=key_dim, num_layers=5, name="tf1")
self.tf2 = Transformer(key_dim=key_dim, num_layers=5, name="tf2")

output_decoding["schema"] = schema
output_decoding["num_output_classes"] = num_output_classes
@@ -1051,14 +1051,14 @@ def call(self, inputs, training=False):
debugging_data = {}

#mask padded elements
msk = X[:, :, 0] != 0
msk = tf.cast(X[:, :, 0] != 0, tf.float32)
msk_input = tf.expand_dims(tf.cast(msk, tf.float32), -1)

X_enc = self.enc(X)
X_enc = self.ffn(X_enc)

X_enc_1 = self.tf1([X_enc, msk_input], training=training)
X_enc_2 = self.tf2([X_enc, msk_input], training=training)
X_enc_1 = self.tf1([X_enc, msk], training=training)
X_enc_2 = self.tf2([X_enc, msk], training=training)

ret = self.output_dec([X, X_enc_1, X_enc_2, msk_input], training=training)

mlpf/tfmodel/model_setup.py: 16 changes (14 additions, 2 deletions)
@@ -111,11 +111,11 @@ def __init__(self, outpath, dataset, dataset_info, plot_freq=1, comet_experiment
}

self.reg_bins = {
"pt": np.linspace(-100, 1000, 100),
"pt": np.linspace(-100, 200, 100),
"eta": np.linspace(-6, 6, 100),
"sin_phi": np.linspace(-1,1,100),
"cos_phi": np.linspace(-1,1,100),
"energy": np.linspace(-100, 5000, 100),
"energy": np.linspace(-100, 1000, 100),
}

def plot_cm(self, epoch, outpath, ypred_id, msk):
@@ -268,9 +268,21 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable):
#save scatterplot of raw values
plt.figure(figsize=(6,5))
bins = self.reg_bins[reg_variable]

if bins is None:
bins = 100

if reg_variable == "pt" or reg_variable == "energy":
bins = np.logspace(-2,3,100)
vals_true = np.log10(vals_true)
vals_pred = np.log10(vals_pred)
vals_true[np.isnan(vals_true)] = 0.0
vals_pred[np.isnan(vals_pred)] = 0.0

plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1, cmap="Blues", norm=matplotlib.colors.LogNorm())
if reg_variable == "pt" or reg_variable == "energy":
plt.xscale("log")
plt.yscale("log")
plt.colorbar()
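
For pt and energy the values span several decades, so the correlation plot switches to logarithmic bins and log axes. A self-contained sketch of the plotting recipe with synthetic data (the NaN handling above guards real data with non-positive entries; it is omitted here because lognormal samples are strictly positive):

```python
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

vals_true = np.random.lognormal(mean=2.0, sigma=1.0, size=10000)
vals_pred = vals_true * np.random.normal(1.0, 0.1, size=10000)

bins = np.logspace(-2, 3, 100)  # 0.01 to 1000, evenly spaced in log10
plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1,
           cmap="Blues", norm=matplotlib.colors.LogNorm())
plt.xscale("log")
plt.yscale("log")
plt.colorbar()
plt.xlabel("true")
plt.ylabel("predicted")
plt.show()
```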

if len(vals_true) > 0: