Better CMS dataset, fix f16, fix transformer #105

Merged 1 commit on May 2, 2022
2 changes: 1 addition & 1 deletion hep_tfds (submodule pointer update; diff not shown)
2 changes: 1 addition & 1 deletion mlpf/tallinn/cms-gen.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 #SBATCH -p gpu
-#SBATCH --gpus 4
+#SBATCH --gpus 8
 #SBATCH --mem-per-gpu=8G
 
 IMG=/home/software/singularity/tf-2.8.0.simg
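This doubles the GPU request from 4 to 8; since SLURM's `--mem-per-gpu` allocates host memory per requested GPU, the job's total host-memory request grows from 32 GB to 64 GB.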
143 changes: 111 additions & 32 deletions mlpf/tfmodel/model.py
@@ -564,30 +564,35 @@ def call(self, args, training=False):
         if self.do_layernorm:
             X_encoded = self.layernorm(X_encoded)
 
-        out_id_logits = self.ffn_id(X_encoded, training=training)*msk_input
+        out_id_logits = self.ffn_id(X_encoded, training=training)
+        out_id_logits = out_id_logits * tf.cast(msk_input, out_id_logits.dtype)
 
         out_id_softmax = tf.nn.softmax(out_id_logits, axis=-1)
         out_id_hard_softmax = tf.stop_gradient(tf.nn.softmax(100*out_id_logits, axis=-1))
-        out_charge = self.ffn_charge(X_encoded, training=training)*msk_input
-
-        orig_eta = X_input[:, :, 2:3]
+        out_charge = self.ffn_charge(X_encoded, training=training)
+        out_charge = out_charge * tf.cast(msk_input, out_charge.dtype)
+
+        orig_eta = tf.cast(X_input[:, :, 2:3], out_id_logits.dtype)
 
         #FIXME: better schema propagation
         #skip connection from raw input values
         if self.schema == "cms":
-            orig_sin_phi = tf.math.sin(X_input[:, :, 3:4])*msk_input
-            orig_cos_phi = tf.math.cos(X_input[:, :, 3:4])*msk_input
-            orig_energy = X_input[:, :, 4:5]*msk_input
+            orig_sin_phi = tf.cast(tf.math.sin(X_input[:, :, 3:4])*msk_input, out_id_logits.dtype)
+            orig_cos_phi = tf.cast(tf.math.cos(X_input[:, :, 3:4])*msk_input, out_id_logits.dtype)
+            orig_energy = tf.cast(X_input[:, :, 4:5]*msk_input, out_id_logits.dtype)
         elif self.schema == "delphes":
-            orig_sin_phi = X_input[:, :, 3:4]*msk_input
-            orig_cos_phi = X_input[:, :, 4:5]*msk_input
-            orig_energy = X_input[:, :, 5:6]*msk_input
+            orig_sin_phi = tf.cast(X_input[:, :, 3:4]*msk_input, out_id_logits.dtype)
+            orig_cos_phi = tf.cast(X_input[:, :, 4:5]*msk_input, out_id_logits.dtype)
+            orig_energy = tf.cast(X_input[:, :, 5:6]*msk_input, out_id_logits.dtype)
 
         if self.regression_use_classification:
-            X_encoded = tf.concat([X_encoded, tf.stop_gradient(out_id_logits)], axis=-1)
+            X_encoded = tf.concat([X_encoded, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)
 
-        pred_eta_corr = self.ffn_eta(X_encoded, training=training)*msk_input
-        pred_phi_corr = self.ffn_phi(X_encoded, training=training)*msk_input
+        pred_eta_corr = self.ffn_eta(X_encoded, training=training)
+        pred_eta_corr = pred_eta_corr*tf.cast(msk_input, pred_eta_corr.dtype)
+        pred_phi_corr = self.ffn_phi(X_encoded, training=training)
+        pred_phi_corr = pred_phi_corr*tf.cast(msk_input, pred_phi_corr.dtype)
 
         if self.eta_skip_gate:
             eta_gate = tf.keras.activations.sigmoid(pred_eta_corr[:, :, 0:1])
@@ -606,9 +611,10 @@
 
         X_encoded_energy = tf.concat([X_encoded, X_encoded_energy], axis=-1)
         if self.regression_use_classification:
-            X_encoded_energy = tf.concat([X_encoded_energy, tf.stop_gradient(out_id_logits)], axis=-1)
+            X_encoded_energy = tf.concat([X_encoded_energy, tf.cast(tf.stop_gradient(out_id_logits), X_encoded.dtype)], axis=-1)
 
-        pred_energy_corr = self.ffn_energy(X_encoded_energy, training=training)*msk_input
+        pred_energy_corr = self.ffn_energy(X_encoded_energy, training=training)
+        pred_energy_corr = pred_energy_corr*tf.cast(msk_input, pred_energy_corr.dtype)
 
         #In case of a multimodal prediction, weight the per-class energy predictions by the approximately one-hot vector
         if self.energy_multimodal:
@@ -619,7 +625,9 @@
         #compute pt=E/cosh(eta)
         orig_pt = tf.stop_gradient(pred_energy/tf.math.cosh(tf.clip_by_value(pred_eta, -8, 8)))
 
-        pred_pt_corr = self.ffn_pt(X_encoded_energy, training=training)*msk_input
+        pred_pt_corr = self.ffn_pt(X_encoded_energy, training=training)
+        pred_pt_corr = pred_pt_corr*tf.cast(msk_input, pred_pt_corr.dtype)
 
         if self.pt_skip_gate:
             pt_gate = tf.keras.activations.sigmoid(pred_pt_corr[:, :, 0:1])
             pred_pt = orig_pt + pt_gate*pred_pt_corr[:, :, 1:2]
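The pattern repeated throughout this hunk is the substance of the "fix f16" part of the PR: under a mixed-precision policy the feed-forward heads emit float16 activations, while the padding mask `msk_input` stays float32, and TensorFlow refuses to multiply tensors of different dtypes. Casting the mask to each output's dtype keeps the masking dtype-agnostic. A minimal standalone sketch of the failure mode and the fix (toy layer and shapes, not the MLPF code):

```python
import tensorflow as tf

# Assumed setup: the same kind of mixed-precision policy the training would use.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

ffn = tf.keras.layers.Dense(4)               # compute dtype is float16 under this policy
X = tf.random.normal((2, 5, 3))              # (batch, elements, features), float32
msk_input = tf.ones((2, 5, 1), tf.float32)   # padding mask kept in float32

out = ffn(X)                                 # dtype: float16
# out * msk_input                            # raises: float16 * float32 dtype mismatch
out = out * tf.cast(msk_input, out.dtype)    # the fix: cast the mask, not the activations
```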
@@ -937,38 +945,105 @@ class KernelEncoder(tf.keras.layers.Layer):
     def __init__(self, *args, **kwargs):
         from official.nlp.modeling.layers.kernel_attention import KernelAttention
         self.key_dim = kwargs.pop("key_dim")
-        self.attn = KernelAttention(feature_transform="elu", num_heads=4, key_dim=self.key_dim, name=kwargs.get("name") + "_attention")
+        num_heads = 2
+
+        self.attn = KernelAttention(feature_transform="elu", num_heads=num_heads, key_dim=self.key_dim, name=kwargs.get("name") + "_attention")
         self.ffn = point_wise_feed_forward_network(self.key_dim, self.key_dim, kwargs.get("name") + "_ffn", num_layers=1, activation="elu")
-        self.norm0 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0")
-        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1")
+        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0")
+        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1")
         super(KernelEncoder, self).__init__(*args, **kwargs)
 
     def call(self, args, training=False):
         X, mask = args
-        X_attended = self.attn(X, X, training=training)
-        X = self.norm0(X + X_attended, training=training)
-        X = self.norm1(X + self.ffn(X), training=training)
-        return X
+        attn_output = self.attn(query=X, value=X, key=X, training=training)
+        out1 = self.norm1(X + attn_output)
+        ffn_output = self.ffn(out1)
+        out2 = self.norm2(out1 + ffn_output)
+        return out2
+
+class KernelDecoder(tf.keras.layers.Layer):
+    def __init__(self, *args, **kwargs):
+        from official.nlp.modeling.layers.kernel_attention import KernelAttention
+        self.key_dim = kwargs.pop("key_dim")
+        num_heads = 2
+
+        self.attn1 = KernelAttention(feature_transform="elu", num_heads=num_heads, key_dim=self.key_dim, name=kwargs.get("name") + "_attention1")
+        self.attn2 = KernelAttention(feature_transform="elu", num_heads=num_heads, key_dim=self.key_dim, name=kwargs.get("name") + "_attention2")
+
+        self.ffn = point_wise_feed_forward_network(self.key_dim, self.key_dim, kwargs.get("name") + "_ffn", num_layers=1, activation="elu")
+
+        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0")
+        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1")
+        self.norm3 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln2")
+        super(KernelDecoder, self).__init__(*args, **kwargs)
+
+    def call(self, args, training=False):
+        X, enc_output, mask = args
+
+        attn1 = self.attn1(query=X, value=X, key=X, training=training)
+        out1 = self.norm1(attn1 + X, training=training)
+
+        attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training)
+        out2 = self.norm2(attn2 + out1)
+
+        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
+        out3 = self.norm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
+
+        return out3
+
+class Transformer(tf.keras.layers.Layer):
+    def __init__(self, *args, **kwargs):
+        self.encoders = []
+
+        key_dim = kwargs.pop("key_dim")
+        num_layers = kwargs.pop("num_layers")
+
+        for i in range(num_layers):
+            self.encoders.append(KernelEncoder(key_dim=key_dim, name="{}-enc{}".format(kwargs.get("name"), i)))
+
+        self.decoders = []
+        for i in range(num_layers):
+            self.decoders.append(KernelDecoder(key_dim=key_dim, name="{}-dec{}".format(kwargs.get("name"), i)))
+        super(Transformer, self).__init__(*args, **kwargs)
+
+    def call(self, inputs, training=False):
+        X, msk_input = inputs
+
+        for enc in self.encoders:
+            X = enc([X, msk_input], training=training)*msk_input
+
+        X_dec = X
+        for dec in self.decoders:
+            X_dec = dec([X_dec, X, msk_input], training=training)*msk_input
+
+        return X_dec
 
 
 class PFNetTransformer(tf.keras.Model):
     def __init__(self,
         num_input_classes=8,
         num_output_classes=3,
         input_encoding="cms",
-        output_decoding={}):
+        schema="cms",
+        output_decoding={},
+        multi_output=True):
         super(PFNetTransformer, self).__init__()
 
+        self.multi_output = multi_output
+
         if input_encoding == "cms":
             self.enc = InputEncodingCMS(num_input_classes)
         elif input_encoding == "default":
             self.enc = InputEncoding(num_input_classes)
 
-        key_dim = 128
+        key_dim = 64
         self.ffn = point_wise_feed_forward_network(key_dim, key_dim, "ffn", num_layers=1, activation="elu")
 
-        self.encoders = []
-        for i in range(2):
-            self.encoders.append(KernelEncoder(key_dim=key_dim, name="enc{}".format(i)))
+        self.tf1 = Transformer(key_dim=key_dim, num_layers=2, name="tf1")
+        self.tf2 = Transformer(key_dim=key_dim, num_layers=2, name="tf2")
 
+        output_decoding["schema"] = schema
         output_decoding["num_output_classes"] = num_output_classes
         self.output_dec = OutputDecoding(**output_decoding)
 
     def call(self, inputs, training=False):
@@ -980,10 +1055,14 @@ def call(self, inputs, training=False):
         msk_input = tf.expand_dims(tf.cast(msk, tf.float32), -1)
 
         X_enc = self.enc(X)
+
         X_enc = self.ffn(X_enc)
-        for enc in self.encoders:
-            X_enc = enc([X_enc, msk_input], training=training)*msk_input
-        ret = self.output_dec([X, X_enc, X_enc, msk_input], training=training)
 
-        return ret
+        X_enc_1 = self.tf1([X_enc, msk_input], training=training)
+        X_enc_2 = self.tf2([X_enc, msk_input], training=training)
+
+        ret = self.output_dec([X, X_enc_1, X_enc_2, msk_input], training=training)
+
+        if self.multi_output:
+            return ret
+        else:
+            return tf.concat([ret["cls"], ret["charge"], ret["pt"], ret["eta"], ret["sin_phi"], ret["cos_phi"], ret["energy"]], axis=-1)
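Taken together, the rewritten `PFNetTransformer` replaces the previous flat stack of two `KernelEncoder` layers with two independent encoder-decoder `Transformer` stacks; judging by the `OutputDecoding` call signature, `X_enc_1` serves as the main encoding and `X_enc_2` as the energy-specific one. The blocks use post-norm residual wiring. A rough, self-contained sketch of one such encoder block, with `tf.keras.layers.MultiHeadAttention` standing in for `KernelAttention` (which lives in `tf-models-official`) so the snippet runs without extra dependencies:

```python
import tensorflow as tf

key_dim = 64
X = tf.random.normal((2, 10, key_dim))   # (batch, elements, key_dim), toy shapes

# Stand-in for KernelAttention; same post-norm residual ordering as KernelEncoder.call.
attn = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=key_dim)
ffn = tf.keras.layers.Dense(key_dim, activation="elu")
norm1 = tf.keras.layers.LayerNormalization(axis=-1)
norm2 = tf.keras.layers.LayerNormalization(axis=-1)

attn_output = attn(query=X, value=X, key=X)   # self-attention
out1 = norm1(X + attn_output)                 # residual + layer norm
out2 = norm2(out1 + ffn(out1))                # feed-forward + residual + layer norm

# A decoder block adds a second, cross-attention step against the encoder
# output before its feed-forward (see KernelDecoder.call above).
print(out2.shape)  # (2, 10, 64)
```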
38 changes: 11 additions & 27 deletions mlpf/tfmodel/model_setup.py
@@ -207,7 +207,6 @@ def plot_event_visualization(self, epoch, outpath, ypred, ypred_id, msk, ievent=
 
         #Plot the target particles
         plt.axes(ax2)
-
         msk = self.ytrue_id[ievent] != 0
         eta = self.ytrue["eta"][ievent][msk]
         sphi = self.ytrue["sin_phi"][ievent][msk]
@@ -266,47 +265,28 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable):
         vals_pred = ypred[reg_variable][sel].flatten()
         vals_true = self.ytrue[reg_variable][sel].flatten()
 
-        loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
-        loss_vals = loss(np.expand_dims(vals_true, -1), np.expand_dims(vals_pred, axis=-1)).numpy()
-
         #save scatterplot of raw values
-        plt.figure()
+        plt.figure(figsize=(6,5))
         bins = self.reg_bins[reg_variable]
         if bins is None:
             bins = 100
-        plt.hist2d(vals_true, vals_pred, bins=100, cmap="Blues")
-
+        plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1, cmap="Blues", norm=matplotlib.colors.LogNorm())
         plt.colorbar()
 
         if len(vals_true) > 0:
             minval = np.min(vals_true)
             maxval = np.max(vals_true)
             if not (math.isnan(minval) or math.isnan(maxval) or math.isinf(minval) or math.isinf(maxval)):
                 plt.plot([minval, maxval], [minval, maxval], color="black", ls="--", lw=0.5)
         plt.xlabel("true")
         plt.ylabel("predicted")
-        plt.title("{}, particle weighted, L={:.4f}".format(reg_variable, np.sum(loss_vals)))
+        plt.title(reg_variable)
         image_path = str(outpath / "{}_cls{}_corr.png".format(reg_variable, icls))
         plt.savefig(image_path, bbox_inches="tight")
         if self.comet_experiment:
             self.comet_experiment.log_image(image_path, step=epoch)
         plt.close("all")
 
-        #save loss-weighted correlation histogram
-        plt.figure()
-        plt.hist2d(vals_true, vals_pred, bins=(bins, bins), weights=loss_vals, cmap="Blues")
-        plt.colorbar()
-        if len(vals_true) > 0:
-            minval = np.min(vals_true)
-            maxval = np.max(vals_true)
-            if not (math.isnan(minval) or math.isnan(maxval) or math.isinf(minval) or math.isinf(maxval)):
-                plt.plot([minval, maxval], [minval, maxval], color="black", ls="--", lw=0.5)
-        plt.xlabel("true")
-        plt.ylabel("predicted")
-        plt.title("{}, loss weighted, L={:.4f}".format(reg_variable, np.sum(loss_vals)))
-        image_path = str(outpath / "{}_cls{}_corr_weighted.png".format(reg_variable, icls))
-        plt.savefig(image_path, bbox_inches="tight")
-        if self.comet_experiment:
-            self.comet_experiment.log_image(image_path, step=epoch)
-
         #Also plot the residuals, as we have the true and predicted values already available here
         plt.figure()
         residual = vals_true - vals_pred
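The reworked scatterplot uses per-variable binning from `self.reg_bins` and a logarithmic color scale; `cmin=1` leaves empty cells unfilled, so single-count bins stay distinguishable from zero. A minimal sketch of the same plotting pattern on synthetic data:

```python
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Toy true/predicted values with multiplicative smearing
vals_true = np.random.exponential(10.0, 10000)
vals_pred = vals_true * np.random.normal(1.0, 0.1, 10000)

plt.figure(figsize=(6, 5))
bins = np.linspace(0, 50, 100)
# cmin=1 hides empty bins; LogNorm keeps rare high-value cells visible
plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmin=1,
           cmap="Blues", norm=matplotlib.colors.LogNorm())
plt.colorbar()
plt.plot([0, 50], [0, 50], color="black", ls="--", lw=0.5)  # ideal diagonal
plt.xlabel("true")
plt.ylabel("predicted")
```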
@@ -326,7 +306,6 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable):
         if self.comet_experiment:
             self.comet_experiment.log_metric('residual_{}_cls{}_mean'.format(reg_variable, icls), np.mean(residual), step=epoch)
             self.comet_experiment.log_metric('residual_{}_cls{}_std'.format(reg_variable, icls), np.std(residual), step=epoch)
-            self.comet_experiment.log_metric('val_loss_{}_cls{}'.format(reg_variable, icls), np.sum(loss_vals), step=epoch)
 
     def plot_elem_to_pred(self, epoch, cp_dir, msk, ypred_id):
         X_id = self.X[msk][:, 0]
@@ -480,7 +459,10 @@ def on_epoch_end(self, epoch, logs=None):
 
                 for variable in ["pt", "eta", "sin_phi", "cos_phi", "energy"]:
                     self.plot_reg_distribution(epoch, cp_dir_cls, ypred, ypred_id, icls, variable)
-                    self.plot_corr(epoch, cp_dir_cls, ypred, ypred_id, icls, variable)
+                    try:
+                        self.plot_corr(epoch, cp_dir_cls, ypred, ypred_id, icls, variable)
+                    except ValueError as e:
+                        print("Could not draw corr plot: {}".format(e))
 
 def prepare_callbacks(
         callbacks_cfg, outdir,
@@ -604,8 +586,10 @@ def make_transformer(config, dtype):
         kwargs[par] = config['parameters'][par]
 
     model = PFNetTransformer(
+        multi_output=config["setup"]["multi_output"],
         num_input_classes=config["dataset"]["num_input_classes"],
         num_output_classes=config["dataset"]["num_output_classes"],
+        schema=config["dataset"]["schema"],
         **kwargs
     )
     return model
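`make_transformer` now threads two additional settings from the training config into the model. A hypothetical, minimal config fragment containing just the keys the function reads above (the real configs in the repo carry many more entries):

```python
# Hypothetical minimal config; only the keys make_transformer accesses above.
config = {
    "setup": {"multi_output": True},     # False -> single concatenated output tensor
    "dataset": {
        "num_input_classes": 8,
        "num_output_classes": 3,
        "schema": "cms",                 # or "delphes"
    },
    "parameters": {},                    # extra kwargs copied into PFNetTransformer
}
```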
1 change: 1 addition & 0 deletions mlpf/tfmodel/utils.py
@@ -392,6 +392,7 @@ def load_and_interleave(dataset_names, config, num_gpus, split, batch_size):
     steps = []
     for ds_name in dataset_names:
         ds, _ = get_heptfds_dataset(ds_name, config, num_gpus, split)
+        #ds = ds.take(500)
         num_steps = ds.cardinality().numpy()
         assert(num_steps > 0)
         print("Loaded {}:{} with {} steps".format(ds_name, split, num_steps))