From 4f1f647f69e093ae89ba033e118a06c0068653c7 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Tue, 31 Aug 2021 11:48:41 +0300 Subject: [PATCH 01/17] update cms-dev model Former-commit-id: dff2f1dde6061f0088a54fff82fd63270b6ca1f0 --- mlpf/tfmodel/model.py | 11 +++--- parameters/cms-dev.yaml | 77 +++++++++++++++++++++-------------------- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index a40451eeb..58132546b 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -790,9 +790,10 @@ def set_trainable_named(self, layer_names): # msk2 = (pred_cls==icls) # import matplotlib # import matplotlib.pyplot as plt - # bins = np.linspace(0,6,100) # plt.figure(figsize=(4,4)) + # minval = np.min(y["energy"][msk1].numpy().flatten()) + # maxval = np.max(y["energy"][msk1].numpy().flatten()) # plt.scatter( # y["energy"][msk1&msk2].numpy().flatten(), # y_pred["energy"][msk1&msk2].numpy().flatten(), @@ -800,7 +801,7 @@ def set_trainable_named(self, layer_names): # ) # plt.xlabel("true") # plt.ylabel("pred") - # plt.plot([0,6], [0,6]) + # plt.plot([minval,maxval], [minval,maxval], color="black", ls="--", lw=1.0) # plt.savefig("train_cls{}_{}.png".format(icls, self.step), bbox_inches="tight") # plt.close("all") @@ -832,9 +833,10 @@ def set_trainable_named(self, layer_names): # msk2 = (pred_cls==icls) # import matplotlib # import matplotlib.pyplot as plt - # bins = np.linspace(0,6,100) # plt.figure(figsize=(4,4)) + # minval = np.min(y["energy"][msk1].numpy().flatten()) + # maxval = np.max(y["energy"][msk1].numpy().flatten()) # plt.scatter( # y["energy"][msk1&msk2].numpy().flatten(), # y_pred["energy"][msk1&msk2].numpy().flatten(), @@ -842,10 +844,11 @@ def set_trainable_named(self, layer_names): # ) # plt.xlabel("true") # plt.ylabel("pred") - # plt.plot([0,6], [0,6]) + # plt.plot([minval,maxval], [minval,maxval], color="black", ls="--", lw=1.0) # plt.savefig("test_cls{}_{}.png".format(icls, self.step), bbox_inches="tight") # plt.close("all") + # # Updates the metrics tracking the loss # self.compiled_loss(y, y_pred, sample_weights, regularization_losses=self.losses) # # Update the metrics. diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index d541a3f58..9cb4a785b 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -24,22 +24,20 @@ dataset: #(pt, eta, sin phi, cos phi, E) num_momentum_outputs: 5 classification_loss_coef: 1.0 - charge_loss_coef: 1.0 - pt_loss_coef: 100.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 eta_loss_coef: 100.0 - sin_phi_loss_coef: 100.0 - cos_phi_loss_coef: 100.0 - energy_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* - processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_gen/*.tfrecords + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords num_files_per_chunk: 1 validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* energy_loss: type: Huber - delta: 1.0 pt_loss: type: Huber - delta: 1.0 sin_phi_loss: type: Huber delta: 0.1 @@ -57,7 +55,7 @@ setup: train: yes weights: weights_config: - lr: 1e-3 + lr: 1e-4 batch_size: 4 num_events_train: 80000 num_events_test: 9000 @@ -80,33 +78,36 @@ sample_weights: parameters: model: gnn_dense input_encoding: cms - activation: gelu - layernorm: yes - hidden_dim: 256 - bin_size: 32 - distance_dim: 8 - dropout: 0.0 - graph_kernel: - type: NodePairTrainableKernel - output_dim: 8 - hidden_dim: 32 - num_layers: 2 - activation: gelu - num_graph_layers: 3 - node_message: - type: NodeMessageLearnable - output_dim: 256 + combined_graph_layer: + bin_size: 32 + max_num_bins: 500 + distance_dim: 8 + layernorm: no + dropout: 0.2 + kernel: + type: NodePairTrainableKernel + output_dim: 8 + hidden_dim: 32 + num_layers: 2 + activation: gelu + num_node_messages: 1 + node_message: + type: NodeMessageLearnable + output_dim: 128 + activation: gelu + hidden_dim: 128 + num_layers: 2 + aggregation_direction: dst hidden_dim: 128 - num_layers: 2 activation: gelu - aggregation_direction: dst - num_node_messages: 1 + num_graph_layers_common: 2 + num_graph_layers_energy: 2 output_decoding: activation: gelu regression_use_classification: yes - dropout: 0.0 + dropout: 0.2 - pt_skip_gate: yes + pt_skip_gate: no eta_skip_gate: yes phi_skip_gate: yes @@ -124,13 +125,13 @@ parameters: phi_hidden_dim: 256 energy_hidden_dim: 256 - id_num_layers: 4 + id_num_layers: 2 charge_num_layers: 2 - pt_num_layers: 3 - eta_num_layers: 3 - phi_num_layers: 3 - energy_num_layers: 3 - layernorm: yes + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no mask_reg_cls0: no skip_connection: yes @@ -141,6 +142,6 @@ timing: num_iter: 3 exponentialdecay: - decay_steps: 10000 - decay_rate: 0.99 + decay_steps: 1000 + decay_rate: 0.98 staircase: yes From 8aa60ade75a9263326fbc992688b54c125d365b1 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Tue, 31 Aug 2021 13:05:30 +0300 Subject: [PATCH 02/17] update cms-dev Former-commit-id: 7cb3d490a2ed7a7ca884bfd7c6d9c62caaa128b7 --- parameters/cms-dev.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index 9cb4a785b..2a7fed42b 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -2,7 +2,7 @@ backend: tensorflow dataset: schema: cms - target_particles: gen + target_particles: cand num_input_features: 15 num_output_features: 7 # NONE = 0, @@ -58,7 +58,7 @@ setup: lr: 1e-4 batch_size: 4 num_events_train: 80000 - num_events_test: 9000 + num_events_test: 10000 num_epochs: 100 num_val_files: 10 dtype: float32 From 84d6a64e73f493c4caaa8b297f6fe1797398e658 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Tue, 31 Aug 2021 22:56:06 +0300 Subject: [PATCH 03/17] add dist activation as configurable Former-commit-id: 9ce68d092ddee7bbfb8b4fe83cfec6e546908bd2 --- mlpf/tfmodel/model.py | 3 ++- parameters/cms-dev.yaml | 35 ++++++++++++++++------------------- parameters/cms.yaml | 1 + 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index 58132546b..f6c90a0fd 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -615,6 +615,7 @@ def __init__(self, *args, **kwargs): self.node_message = kwargs.pop("node_message") self.hidden_dim = kwargs.pop("hidden_dim") self.activation = getattr(tf.keras.activations, kwargs.pop("activation")) + self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation")) if self.do_layernorm: self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm") @@ -648,7 +649,7 @@ def call(self, x, msk, training=False): x = self.layernorm(x, training=training) #compute node features for graph building - x_dist = self.activation(self.ffn_dist(x, training=training)) + x_dist = self.dist_activation(self.ffn_dist(x, training=training)) #x_dist = self.gaussian_noise(x_dist, training=training) #compute the element-to-element messages / distance matrix / graph structure diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index 2a7fed42b..58578bce8 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -55,8 +55,8 @@ setup: train: yes weights: weights_config: - lr: 1e-4 - batch_size: 4 + lr: 1e-3 + batch_size: 5 num_events_train: 80000 num_events_test: 10000 num_epochs: 100 @@ -79,33 +79,30 @@ parameters: model: gnn_dense input_encoding: cms combined_graph_layer: - bin_size: 32 - max_num_bins: 500 - distance_dim: 8 + bin_size: 640 + max_num_bins: 100 + distance_dim: 128 layernorm: no - dropout: 0.2 + dropout: 0.0 + dist_activation: linear kernel: - type: NodePairTrainableKernel - output_dim: 8 - hidden_dim: 32 - num_layers: 2 - activation: gelu + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 num_node_messages: 1 node_message: - type: NodeMessageLearnable + type: GHConvDense output_dim: 128 activation: gelu - hidden_dim: 128 - num_layers: 2 - aggregation_direction: dst + normalize_degrees: yes hidden_dim: 128 activation: gelu - num_graph_layers_common: 2 - num_graph_layers_energy: 2 + num_graph_layers_common: 3 + num_graph_layers_energy: 3 output_decoding: activation: gelu regression_use_classification: yes - dropout: 0.2 + dropout: 0.0 pt_skip_gate: no eta_skip_gate: yes @@ -125,7 +122,7 @@ parameters: phi_hidden_dim: 256 energy_hidden_dim: 256 - id_num_layers: 2 + id_num_layers: 3 charge_num_layers: 2 pt_num_layers: 2 eta_num_layers: 2 diff --git a/parameters/cms.yaml b/parameters/cms.yaml index c0f5d5258..07bfb8d08 100644 --- a/parameters/cms.yaml +++ b/parameters/cms.yaml @@ -84,6 +84,7 @@ parameters: distance_dim: 128 layernorm: no dropout: 0.0 + dist_activation: gelu kernel: type: NodePairGaussianKernel dist_mult: 0.1 From 7e85121685d0ccb31f4c64ae99116de4b4d950dc Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Wed, 1 Sep 2021 11:27:29 +0300 Subject: [PATCH 04/17] more monitoring Former-commit-id: a2356f36367417eec38c6930c07cf92f90c15b7a --- mlpf/tfmodel/data.py | 2 +- mlpf/tfmodel/model.py | 34 +++- mlpf/tfmodel/model_setup.py | 118 +++++++++++- notebooks/cms-mlpf.ipynb | 358 ++++++++++++++++++++++++++++-------- notebooks/cmssw.ipynb | 4 +- parameters/cms-dev.yaml | 20 +- parameters/cms.yaml | 4 + 7 files changed, 447 insertions(+), 93 deletions(-) diff --git a/mlpf/tfmodel/data.py b/mlpf/tfmodel/data.py index d89786b8b..b9e0c0f57 100644 --- a/mlpf/tfmodel/data.py +++ b/mlpf/tfmodel/data.py @@ -237,7 +237,7 @@ def serialize_chunk(self, path, files, ichunk): Xs = np.concatenate(Xs) ys = np.concatenate(ys) - #set weights for each sample to be equal to the number of samples of this type + #set weights for each sample to be equal to the number of target particles of this type #in the training script, this can be used to compute either inverse or class-balanced weights uniq_vals, uniq_counts = np.unique(np.concatenate([y[:, 0] for y in ys]), return_counts=True) for i in range(len(ys)): diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index f6c90a0fd..d835deab1 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -601,6 +601,16 @@ def set_trainable_regression(self): self.ffn_eta.trainable = False self.ffn_pt.trainable = False self.ffn_energy.trainable = True + self.ffn_energy_classwise.trainable = True + + def set_trainable_classification(self): + self.ffn_id.trainable = True + self.ffn_charge.trainable = True + self.ffn_phi.trainable = False + self.ffn_eta.trainable = False + self.ffn_pt.trainable = False + self.ffn_energy.trainable = False + self.ffn_energy_classwise.trainable = False class CombinedGraphLayer(tf.keras.layers.Layer): def __init__(self, *args, **kwargs): @@ -670,6 +680,10 @@ def call(self, x, msk, training=False): class PFNetDense(tf.keras.Model): def __init__(self, + do_node_encoding=False, + hidden_dim=128, + dropout=0.0, + activation="gelu", multi_output=False, num_input_classes=8, num_output_classes=3, @@ -690,6 +704,21 @@ def __init__(self, self.debug = debug self.skip_connection = skip_connection + + self.do_node_encoding = do_node_encoding + self.hidden_dim = hidden_dim + self.dropout = dropout + self.activation = getattr(tf.keras.activations, activation) + + if self.do_node_encoding: + self.node_encoding = point_wise_feed_forward_network( + self.hidden_dim, + self.hidden_dim, + "node_encoding", + num_layers=2, + activation=self.activation, + dropout=self.dropout + ) if input_encoding == "cms": self.enc = InputEncodingCMS(num_input_classes) @@ -714,10 +743,13 @@ def call(self, inputs, training=False): #encode the elements for classification (id) enc = self.enc(X) - enc_cg = enc + encs = [] if self.skip_connection: encs.append(enc) + enc_cg = enc + if self.do_node_encoding: + enc_cg = self.node_encoding(enc_cg, training=training) for cg in self.cg: enc_all = cg(enc_cg, msk, training=training) enc_cg = enc_all["enc"] diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index e42fdbc18..a44d60a98 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -21,6 +21,7 @@ import random import math import platform +import mplhep from tqdm import tqdm from pathlib import Path from tfmodel.onecycle_scheduler import OneCycleScheduler, MomentumOneCycleScheduler @@ -180,7 +181,7 @@ def plot_event_visualization(self, epoch, outpath, ypred, ypred_id, msk, ievent= if self.comet_experiment: self.comet_experiment.log_image(image_path, step=epoch) - def plot_reg_distribution(self, outpath, ypred, ypred_id, icls, reg_variable): + def plot_reg_distribution(self, epoch, outpath, ypred, ypred_id, icls, reg_variable): if icls==0: vals_pred = ypred[reg_variable][ypred_id!=icls].flatten() @@ -203,8 +204,11 @@ def plot_reg_distribution(self, outpath, ypred, ypred_id, icls, reg_variable): plt.ylabel("Number of particles") plt.legend(loc="best") plt.title("Regression output, cls {}".format(icls)) - plt.savefig(str(outpath / "{}_cls{}.png".format(reg_variable, icls)), bbox_inches="tight") + image_path = str(outpath / "{}_cls{}.png".format(reg_variable, icls)) + plt.savefig(image_path, bbox_inches="tight") plt.close("all") + if self.comet_experiment: + self.comet_experiment.log_image(image_path, step=epoch) def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable): @@ -219,13 +223,13 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable): loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) loss_vals = loss(np.expand_dims(vals_true, -1), np.expand_dims(vals_pred, axis=-1)).numpy() - #save correlation histogram + #save scatterplot of raw values plt.figure() bins = self.reg_bins[reg_variable] if bins is None: bins = 100 - plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmap="Blues") - plt.colorbar() + plt.scatter(vals_true, vals_pred, marker=".", alpha=0.4) + if len(vals_true) > 0: minval = np.min(vals_true) maxval = np.max(vals_true) @@ -278,6 +282,91 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable): self.comet_experiment.log_metric('residual_{}_cls{}_std'.format(reg_variable, icls), np.std(residual), step=epoch) self.comet_experiment.log_metric('val_loss_{}_cls{}'.format(reg_variable, icls), np.sum(loss_vals), step=epoch) + def plot_elem_to_pred(self, epoch, cp_dir, msk, ypred_id): + X_id = self.X[msk][:, 0] + max_elem = int(np.max(X_id)) + cand_id = self.ytrue_id[msk] + pred_id = ypred_id[msk] + cm1 = sklearn.metrics.confusion_matrix(X_id, cand_id, labels=range(max_elem)) + cm2 = sklearn.metrics.confusion_matrix(X_id, pred_id, labels=range(max_elem)) + + plt.figure(figsize=(10,4)) + + ax = plt.subplot(1,2,1) + plt.title("Targets") + plt.imshow(cm1, cmap="Blues", norm=matplotlib.colors.LogNorm()) + plt.xticks(range(12)); + plt.yticks(range(12)); + plt.xlabel("Particle id") + plt.ylabel("PFElement id") + plt.colorbar() + + ax = plt.subplot(1,2,2) + plt.title("Predictions") + plt.imshow(cm2, cmap="Blues", norm=matplotlib.colors.LogNorm()) + plt.xticks(range(12)); + plt.yticks(range(12)); + plt.xlabel("Particle id") + plt.ylabel("PFElement id") + plt.colorbar() + + image_path = str(cp_dir / "elem_to_pred.png") + plt.savefig(image_path, bbox_inches="tight") + + if self.comet_experiment: + self.comet_experiment.log_image(image_path, step=epoch) + + def plot_eff_and_fake_rate( + self, + epoch, + icls, + msk, + ypred_id, + cp_dir, + ivar=4, + bins=np.linspace(-3,6,100), + xlabel="PFElement log[E/GeV]", + log=True + ): + + values = self.X[msk][:, ivar] + cand_id = self.ytrue_id[msk] + pred_id = ypred_id[msk] + + if log: + values = np.log(values) + + hist_cand = np.histogram(values[(cand_id==icls)], bins=bins); + hist_cand_true = np.histogram(values[(cand_id==icls) & (pred_id==icls)], bins=bins); + + hist_pred = np.histogram(values[(pred_id==icls)], bins=bins); + hist_pred_fake = np.histogram(values[(cand_id!=icls) & (pred_id==icls)], bins=bins); + + eff = hist_cand_true[0]/hist_cand[0] + fake = hist_pred_fake[0]/hist_pred[0] + + plt.figure(figsize=(8,8)) + ax = plt.subplot(2,1,1) + mplhep.histplot(hist_cand, label="PF") + mplhep.histplot(hist_pred, label="MLPF") + plt.legend() + plt.xlabel(xlabel) + plt.ylabel("Number of particles") + + ax = plt.subplot(2,1,2, sharex=ax) + mplhep.histplot(eff, bins=hist_cand[1], label="efficiency", color="black") + mplhep.histplot(fake, bins=hist_cand[1], label="fake rate", color="red") + plt.legend(frameon=False) + plt.ylim(0,1.4) + plt.xlabel(xlabel) + plt.ylabel("Fraction of particles / bin") + + image_path = str(cp_dir / "eff_fake_cls{}.png".format(icls)) + plt.savefig(image_path, bbox_inches="tight") + + if self.comet_experiment: + self.comet_experiment.log_image(image_path, step=epoch) + def on_epoch_end(self, epoch, logs=None): #save the training logs (losses) for this epoch @@ -302,6 +391,8 @@ def on_epoch_end(self, epoch, logs=None): #exclude padded elements from the plotting msk = self.X[:, :, 0] != 0 + self.plot_elem_to_pred(epoch, cp_dir, msk, ypred_id) + self.plot_cm(epoch, cp_dir, ypred_id, msk) for ievent in range(min(5, self.X.shape[0])): self.plot_event_visualization(epoch, cp_dir, ypred, ypred_id, msk, ievent=ievent) @@ -309,8 +400,12 @@ def on_epoch_end(self, epoch, logs=None): for icls in range(self.num_output_classes): cp_dir_cls = cp_dir / "cls_{}".format(icls) cp_dir_cls.mkdir(parents=True, exist_ok=True) + + if icls!=0: + self.plot_eff_and_fake_rate(epoch, icls, msk, ypred_id, cp_dir_cls) + for variable in ["pt", "eta", "sin_phi", "cos_phi", "energy"]: - self.plot_reg_distribution(cp_dir_cls, ypred, ypred_id, icls, variable) + self.plot_reg_distribution(epoch, cp_dir_cls, ypred, ypred_id, icls, variable) self.plot_corr(epoch, cp_dir_cls, ypred, ypred_id, icls, variable) np.savez(str(cp_dir/"pred.npz"), X=self.X, ytrue=self.y, **ypred) @@ -395,6 +490,10 @@ def make_model(config, dtype): def make_gnn_dense(config, dtype): parameters = [ + "do_node_encoding", + "hidden_dim", + "dropout", + "activation", "num_graph_layers_common", "num_graph_layers_energy", "input_encoding", @@ -548,6 +647,13 @@ def configure_model_weights(model, trainable_layers): cg.trainable = True model.output_dec.set_trainable_regression() + elif trainable_layers == "classification": + for cg in model.cg: + cg.trainable = True + for cg in model.cg_energy: + cg.trainable = False + + model.output_dec.set_trainable_classification() else: if isinstance(trainable_layers, str): trainable_layers = [trainable_layers] diff --git a/notebooks/cms-mlpf.ipynb b/notebooks/cms-mlpf.ipynb index b77607840..c78738457 100644 --- a/notebooks/cms-mlpf.ipynb +++ b/notebooks/cms-mlpf.ipynb @@ -130,7 +130,7 @@ "metadata": {}, "outputs": [], "source": [ - "path = \"../experiments/cms_20210830_181309_198541.joosep-desktop-work/evaluation/\"" + "path = \"../experiments/cms-dev_20210831_225815_541048.gpu0.local/evaluation/\"" ] }, { @@ -167,150 +167,358 @@ "ygen_f = ygen.reshape((ygen.shape[0]*ygen.shape[1], ygen.shape[2]))\n", "ycand_f = ycand.reshape((ycand.shape[0]*ycand.shape[1], ycand.shape[2]))\n", "ypred_f = ypred.reshape((ypred.shape[0]*ypred.shape[1], ypred.shape[2]))\n", - "ypred_raw_f = ypred_raw.reshape((ypred_raw.shape[0]*ypred_raw.shape[1], ypred_raw.shape[2]))" + "\n", + "# ypred_raw[X[:, :, 0]==1, 6] = 0.0\n", + "\n", + "# ypred_raw[X[:, :, 0]==4, 1] = 0.0\n", + "# #ypred_raw[X[:, :, 0]==4, 6] = 0.0\n", + "# ypred_raw[X[:, :, 0]==5, 1] = 0.0\n", + "# ypred_raw[X[:, :, 0]==5, 7] = 0.0\n", + "\n", + "# ypred_raw[X[:, :, 0]==8, 1] = 0.0\n", + "# ypred_raw[X[:, :, 0]==9, 1] = 0.0\n", + "\n", + "# ypred_raw[X[:, :, 0]==8, 2] = 0.0\n", + "# ypred_raw[X[:, :, 0]==9, 2] = 0.0\n", + "\n", + "ypred_raw_f = ypred_raw.reshape((ypred_raw.shape[0]*ypred_raw.shape[1], ypred_raw.shape[2]))\n", + "\n", + "ypred_id = np.argmax(ypred_raw, axis=-1)\n", + "\n", + "ypred_id_f = ypred_id.flatten()" ] }, { "cell_type": "code", "execution_count": null, - "id": "opened-lyric", + "id": "floral-people", "metadata": {}, "outputs": [], "source": [ - "icls = 2\n", - "msk = (ycand_f[:, 0]==icls) & (ypred_f[:, 0]==icls)\n", - "plt.scatter(ycand_f[msk, 2], 2*ypred_f[msk, 2], marker=\".\", alpha=0.4)\n", - "plt.plot([0,100], [0,100], color=\"black\", ls=\"--\")" + "np.unique(ypred_id[X[:, :, 0]==4], return_counts=True)" ] }, { "cell_type": "code", "execution_count": null, - "id": "linear-eleven", + "id": "cooked-bullet", "metadata": {}, "outputs": [], "source": [ - "np.std(ycand_f[ycand_f[:, 0]!=0, 4])" + "# thresholds = [0.6, 0.7, 0.0, 0, 0, 0, 0]\n", + "# ypred_id = apply_thresholds(ypred_raw, thresholds)\n", + "# ypred_id_f = apply_thresholds_f(ypred_raw_f, thresholds)" ] }, { "cell_type": "code", "execution_count": null, - "id": "based-wrestling", + "id": "virgin-nicaragua", "metadata": {}, "outputs": [], "source": [ - "plt.hist(np.log(ycand_f[ycand_f[:, 0]!=0, 6]), bins=100);" + "for icls in range(1,8):\n", + " npred = np.sum(ypred_id == icls, axis=1)\n", + " ncand = np.sum(ycand[:, :, 0] == icls, axis=1)\n", + " plt.figure(figsize=(6,6))\n", + " plt.scatter(ncand, npred, marker=\".\", alpha=0.8)\n", + " a = 0.5*min(np.min(npred), np.min(ncand))\n", + " b = 1.5*max(np.max(npred), np.max(ncand))\n", + " plt.xlim(a,b)\n", + " plt.ylim(a,b)\n", + " plt.plot([a,b],[a,b], color=\"black\", ls=\"--\")\n", + " plt.title(pid_names_long[icls],y=1.05)\n", + " plt.xlabel(\"number of PFCandidates\")\n", + " plt.ylabel(\"number of MLPFCandidates\")\n", + " cms_label(x2=0.6, y=0.89)\n", + " plt.savefig(\"num_cls{}.pdf\".format(icls))\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "documented-savage", + "id": "reported-button", "metadata": {}, "outputs": [], "source": [ - "glob_iter = 0\n", - "def multiplicity_score(thresholds):\n", - " global glob_iter\n", - " ypred_id = apply_thresholds(ypred_raw, thresholds)\n", - " total_scores = []\n", - " for icls in range(1,8):\n", - " ntrue = np.sum((ycand[:, :, 0]==icls)*msk_X, axis=1)\n", - " npred = np.sum((ypred_id==icls)*msk_X, axis=1)\n", - " diff = np.sqrt(np.sum((ntrue-npred)**2))/np.mean(ntrue)\n", - " total_scores.append(diff)\n", - " #print(\" \", icls, np.mean(ntrue), np.mean(npred), diff)\n", - " glob_iter += 1\n", - " if glob_iter%10 == 0:\n", - " print(glob_iter, np.sum(total_scores))\n", - " print(\",\\t\".join([\"{:.2f}\".format(x) for x in thresholds]))\n", - " print(\",\\t\".join([\"{:.2f}\".format(x) for x in total_scores]))\n", - " return np.sum(total_scores)\n", + "energy_bins_classwise = {\n", + " 1: [-2, 5],\n", + " 2: [-2, 6],\n", + " 3: [1, 7],\n", + " 4: [2, 5],\n", + " 5: [2, 5],\n", + " 6: [2, 5],\n", + " 7: [2, 5],\n", + "}\n", "\n", - "ret = scipy.optimize.minimize(\n", - " multiplicity_score,\n", - " 0.5*np.ones(7),\n", - " tol=1e-5,\n", - " method=\"Powell\",\n", - " bounds=[(0,1) for i in range(7)],\n", - " #options={\"ftol\": 1e-6, \"disp\":True}\n", - ")" + "energy_correction_factors = {\n", + " 1: [1, 1],\n", + " 2: [1, 1],\n", + " 3: [1.0, 1.2],\n", + " 4: [1, 1],\n", + " 5: [1, 1],\n", + " 6: [1, 1],\n", + " 7: [1, 1],\n", + "}" ] }, { "cell_type": "code", "execution_count": null, - "id": "resistant-abraham", + "id": "chronic-discovery", "metadata": {}, "outputs": [], "source": [ - "thresholds = 0.0*ret.x" + "b = np.linspace(0,1,101)\n", + "plt.figure(figsize=(4,4))\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==1) & (ycand_f[:, 0]==0), 1], bins=b, histtype=\"step\", lw=2, label=\"no PFCandidate\", density=True);\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==1) & (ycand_f[:, 0]==1), 1], bins=b, histtype=\"step\", lw=2, label=\"charged PFCandidate\", density=True);\n", + "plt.legend(loc=2, frameon=False)\n", + "plt.xlabel(\"Charged hadron probability\")\n", + "plt.title(\"Tracks\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "involved-tobago", + "id": "multiple-disco", "metadata": {}, "outputs": [], "source": [ - "thresholds = [0.5, 0.6, 0.45, 0.56, 0.2, 0.85, 0.19]" + "b = np.linspace(0,1,101)\n", + "plt.figure(figsize=(4,4))\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==1) & (ycand_f[:, 0]==0), 0], bins=b, histtype=\"step\", lw=2, label=\"no PFCandidate\", density=True);\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==1) & (ycand_f[:, 0]==1), 0], bins=b, histtype=\"step\", lw=2, label=\"charged PFCandidate\", density=True);\n", + "plt.legend(loc=1, frameon=False)\n", + "plt.xlabel(\"No particle probability\")\n", + "plt.title(\"Tracks\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "complex-difficulty", + "id": "innocent-black", "metadata": {}, "outputs": [], "source": [ - "ypred_id = apply_thresholds(ypred_raw, thresholds)\n", - "ypred_id_f = apply_thresholds_f(ypred_raw_f, thresholds)" + "b = np.linspace(0,1,101)\n", + "plt.figure(figsize=(4,4))\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==5) & (ycand_f[:, 0]==0), 2], bins=b, histtype=\"step\", lw=2, label=\"no PFCandidate\", density=True);\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==5) & (ycand_f[:, 0]==2), 2], bins=b, histtype=\"step\", lw=2, label=\"neutral PFCandidate\", density=True);\n", + "plt.legend(loc=2, frameon=False)\n", + "plt.xlabel(\"Neutral probability\")\n", + "plt.title(\"HCAL clusters\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "patient-thunder", + "id": "flying-mason", "metadata": {}, "outputs": [], "source": [ - "sklearn.metrics.balanced_accuracy_score(ycand_f[msk_X_f, 0], ypred_f[:, 0][msk_X_f])" + "b = np.linspace(0,1,101)\n", + "plt.figure(figsize=(4,4))\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==5) & (ycand_f[:, 0]==0), 0], bins=b, histtype=\"step\", lw=2, label=\"no PFCandidate\", density=True);\n", + "plt.hist(ypred_raw_f[(X_f[:, 0]==5) & (ycand_f[:, 0]==2), 0], bins=b, histtype=\"step\", lw=2, label=\"neutral PFCandidate\", density=True);\n", + "plt.legend(loc=\"best\", frameon=False)\n", + "plt.xlabel(\"No particle probability\")\n", + "plt.title(\"HCAL clusters\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "revised-pollution", + "id": "equipped-subcommittee", "metadata": {}, "outputs": [], "source": [ - "sklearn.metrics.balanced_accuracy_score(ycand_f[msk_X_f, 0], ypred_id_f[msk_X_f])" + "elem_type = 5\n", + "icls = 2\n", + "\n", + "def plot_elem_energy_cls_prob(elem_type):\n", + " plt.figure(figsize=(4*5,2*4))\n", + " plt.suptitle(\"PFElement type {}\".format(elem_type))\n", + " \n", + " for icls in range(8):\n", + " plt.subplot(2,4,icls+1)\n", + " plt.hist2d(\n", + " np.log(X_f[X_f[:, 0]==elem_type, 4]),\n", + " ypred_raw_f[X_f[:, 0]==elem_type, icls],\n", + " bins=(np.linspace(-2,6,100), np.linspace(0,1,100)), cmap=\"Blues\");\n", + " plt.colorbar()\n", + " plt.xlabel(\"PFElement log[E/GeV]\")\n", + " plt.ylabel(\"MLPF probability for class {}\".format(icls))\n", + " plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, - "id": "virgin-nicaragua", + "id": "worst-coating", "metadata": {}, "outputs": [], "source": [ - "for icls in range(1,8):\n", - " npred = np.sum(ypred_id == icls, axis=1)\n", - " ncand = np.sum(ycand[:, :, 0] == icls, axis=1)\n", - " plt.figure(figsize=(6,6))\n", - " plt.scatter(ncand, npred, marker=\".\", alpha=0.8)\n", - " a = 0.5*min(np.min(npred), np.min(ncand))\n", - " b = 1.5*max(np.max(npred), np.max(ncand))\n", - " plt.xlim(a,b)\n", - " plt.ylim(a,b)\n", - " plt.plot([a,b],[a,b], color=\"black\", ls=\"--\")\n", - " plt.title(pid_names_long[icls],y=1.05)\n", - " plt.xlabel(\"number of PFCandidates\")\n", - " plt.ylabel(\"number of MLPFCandidates\")\n", - " cms_label(x2=0.6, y=0.89)\n", - " plt.savefig(\"num_cls{}.pdf\".format(icls))\n" + "plot_elem_energy_cls_prob(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "responsible-georgia", + "metadata": {}, + "outputs": [], + "source": [ + "plot_elem_energy_cls_prob(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "noble-guess", + "metadata": {}, + "outputs": [], + "source": [ + "plot_elem_energy_cls_prob(4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "honest-tackle", + "metadata": {}, + "outputs": [], + "source": [ + "reco_label = X_f[X_f[:, 0]!=0, 0]\n", + "cand_label = ycand_f[X_f[:, 0]!=0, 0]\n", + "pred_label = ypred_id_f[X_f[:, 0]!=0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fourth-approval", + "metadata": {}, + "outputs": [], + "source": [ + "cm1 = sklearn.metrics.confusion_matrix(reco_label, cand_label, labels=range(12))\n", + "cm2 = sklearn.metrics.confusion_matrix(reco_label, pred_label, labels=range(12))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abroad-wallpaper", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(cm1, cmap=\"Blues\", norm=matplotlib.colors.LogNorm())\n", + "plt.xticks(range(12));\n", + "plt.yticks(range(12));\n", + "plt.colorbar()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "honest-runner", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(cm2, cmap=\"Blues\", norm=matplotlib.colors.LogNorm())\n", + "plt.xticks(range(12));\n", + "plt.yticks(range(12));\n", + "plt.colorbar()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "incorporated-prerequisite", + "metadata": {}, + "outputs": [], + "source": [ + "ycand_id_f = ycand_f[:, 0]\n", + "\n", + "b = np.linspace(-3,6,100)\n", + "\n", + "icls = 2\n", + "\n", + "def plot_eff_and_fake_rate(\n", + " icls,\n", + " ivar=4,\n", + " bins=np.linspace(-3,6,100),\n", + " xlabel=\"PFElement log[E/GeV]\", log=True\n", + " ):\n", + " \n", + " values = X_f[:, ivar]\n", + " if log:\n", + " values = np.log(values)\n", + " \n", + " hist_cand = np.histogram(values[(ycand_id_f==icls)], bins=bins);\n", + " hist_cand_true = np.histogram(values[(ycand_id_f==icls) & (ypred_id_f==icls)], bins=bins);\n", + "\n", + " hist_pred = np.histogram(values[(ypred_id_f==icls)], bins=bins);\n", + " hist_pred_fake = np.histogram(values[(ycand_id_f!=icls) & (ypred_id_f==icls)], bins=bins);\n", + "\n", + " eff = hist_cand_true[0]/hist_cand[0]\n", + " fake = hist_pred_fake[0]/hist_pred[0]\n", + "\n", + " plt.figure(figsize=(8,8))\n", + " ax = plt.subplot(2,1,1)\n", + " mplhep.histplot(hist_cand, label=\"PF\")\n", + " mplhep.histplot(hist_pred, label=\"MLPF\")\n", + " plt.legend()\n", + " plt.xlabel(xlabel)\n", + " plt.ylabel(\"Number of particles\")\n", + "\n", + " ax = plt.subplot(2,1,2, sharex=ax)\n", + " mplhep.histplot(eff, bins=hist_cand[1], label=\"efficiency\", color=\"black\")\n", + " mplhep.histplot(fake, bins=hist_cand[1], label=\"fake rate\", color=\"red\")\n", + " plt.legend(frameon=False)\n", + " plt.ylim(0,1.4)\n", + " plt.xlabel(xlabel)\n", + " plt.ylabel(\"Fraction of particles / bin\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dedicated-indonesia", + "metadata": {}, + "outputs": [], + "source": [ + "plot_eff_and_fake_rate(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "variable-potter", + "metadata": {}, + "outputs": [], + "source": [ + "plot_eff_and_fake_rate(2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "hybrid-chuck", + "metadata": {}, + "outputs": [], + "source": [ + "plot_eff_and_fake_rate(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "golden-catalyst", + "metadata": {}, + "outputs": [], + "source": [ + "plot_eff_and_fake_rate(4)" ] }, { @@ -479,8 +687,8 @@ " msk = (ycand_f[:, 0] == icls)\n", " plt.hist(ypred_raw_f[msk & (X_f[:, 0] != 0), icls], bins=100, density=1, histtype=\"step\", lw=2, color=\"blue\", label=\"true \"+pid_names[icls]);\n", " plt.hist(ypred_raw_f[~msk & (X_f[:, 0] != 0), icls], bins=100, density=1, histtype=\"step\", lw=2, color=\"red\", label=\"other particles\");\n", - " #plt.axvline(ret.x[icls-1], 0, 0.7, ls=\"--\",\n", - " # color=\"black\", label=\"threshold: {:.2f}\".format(ret.x[icls-1]), lw=1)\n", + " plt.axvline(ret.x[icls-1], 0, 0.7, ls=\"--\",\n", + " color=\"black\", label=\"threshold: {:.2f}\".format(ret.x[icls-1]), lw=1)\n", " plt.yscale(\"log\")\n", " plt.title(\"Particle reconstruction for {}\".format(pid_names[icls]), y=1.05)\n", " plt.xlabel(\"Classification output {}\".format(icls))\n", @@ -501,16 +709,16 @@ "#perm = np.random.permutation(ycand_f[msk_X].shape[0])[:100000]\n", "\n", "cm_norm = sklearn.metrics.confusion_matrix(\n", - " ycand_f[msk_X_f & (ycand_f[:, 0]!=0), 0],\n", - " ypred_id_f[msk_X_f & (ycand_f[:, 0]!=0)],\n", - " labels=range(1,8),\n", + " ycand_f[msk_X_f, 0],\n", + " ypred_id_f[msk_X_f],\n", + " labels=range(0,8),\n", " normalize=\"true\"\n", ")\n", "\n", "cm = sklearn.metrics.confusion_matrix(\n", - " ycand_f[msk_X_f & (ycand_f[:, 0]!=0), 0],\n", - " ypred_id_f[msk_X_f & (ycand_f[:, 0]!=0)],\n", - " labels=range(1,8),\n", + " ycand_f[msk_X_f, 0],\n", + " ypred_id_f[msk_X_f],\n", + " labels=range(0,8),\n", ")" ] }, @@ -528,8 +736,8 @@ "\n", "cms_label(x1=0.18, x2=0.52, y=0.82)\n", "#sample_label(ax, x=0.8, y=1.0)\n", - "plt.xticks(range(len(y_labels)), y_labels);\n", - "plt.yticks(range(len(y_labels)), y_labels);\n", + "#plt.xticks(range(len(y_labels)), y_labels);\n", + "#plt.yticks(range(len(y_labels)), y_labels);\n", "plt.xlabel(\"Predicted PFCandidate\")\n", "plt.ylabel(\"True PFCandidate\")\n", "plt.title(\"MLPF trained on PF\", y=1.03)\n", @@ -656,7 +864,7 @@ { "cell_type": "code", "execution_count": null, - "id": "linear-ceramic", + "id": "dirty-rebecca", "metadata": {}, "outputs": [], "source": [] diff --git a/notebooks/cmssw.ipynb b/notebooks/cmssw.ipynb index da95c7bd2..be3f9855c 100644 --- a/notebooks/cmssw.ipynb +++ b/notebooks/cmssw.ipynb @@ -231,7 +231,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -245,7 +245,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index 58578bce8..2fbcc59dc 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -23,7 +23,7 @@ dataset: padded_num_elem_size: 6400 #(pt, eta, sin phi, cos phi, E) num_momentum_outputs: 5 - classification_loss_coef: 1.0 + classification_loss_coef: 100.0 charge_loss_coef: 0.01 pt_loss_coef: 0.0001 eta_loss_coef: 100.0 @@ -62,13 +62,13 @@ setup: num_epochs: 100 num_val_files: 10 dtype: float32 - trainable: - classification_loss_type: categorical_cross_entropy + trainable: classification + classification_loss_type: categorical_cross_entropy #categorical_cross_entropy, sigmoid_focal_crossentropy lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: - cls: inverse_sqrt - charge: signal_only + cls: none + charge: none pt: signal_only eta: signal_only sin_phi: signal_only @@ -78,6 +78,10 @@ sample_weights: parameters: model: gnn_dense input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu combined_graph_layer: bin_size: 640 max_num_bins: 100 @@ -97,7 +101,7 @@ parameters: normalize_degrees: yes hidden_dim: 128 activation: gelu - num_graph_layers_common: 3 + num_graph_layers_common: 4 num_graph_layers_energy: 3 output_decoding: activation: gelu @@ -108,14 +112,14 @@ parameters: eta_skip_gate: yes phi_skip_gate: yes - id_dim_decrease: yes + id_dim_decrease: no charge_dim_decrease: yes pt_dim_decrease: yes eta_dim_decrease: yes phi_dim_decrease: yes energy_dim_decrease: yes - id_hidden_dim: 256 + id_hidden_dim: 512 charge_hidden_dim: 256 pt_hidden_dim: 256 eta_hidden_dim: 256 diff --git a/parameters/cms.yaml b/parameters/cms.yaml index 07bfb8d08..7341bc09e 100644 --- a/parameters/cms.yaml +++ b/parameters/cms.yaml @@ -78,6 +78,10 @@ sample_weights: parameters: model: gnn_dense input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu combined_graph_layer: bin_size: 160 max_num_bins: 100 From c375da5a19f0df731358d91f1c12ae9ba6e8de31 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Wed, 1 Sep 2021 17:50:01 +0300 Subject: [PATCH 05/17] lsh configurable Former-commit-id: bc0c94013e36b811545430a200f452b1e93f5461 --- mlpf/tfmodel/model.py | 68 ++++++++++++--- mlpf/tfmodel/model_setup.py | 18 ++-- notebooks/cms-mlpf.ipynb | 161 ++++++++++++++++++++++++++---------- parameters/cms-dev.yaml | 17 ++-- 4 files changed, 193 insertions(+), 71 deletions(-) diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index d835deab1..d0dbd85a2 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -392,6 +392,33 @@ def call(self, x_msg, x_node, msk, training=False): return bins_split, x_features_binned, dm, msk_f_binned +class MessageBuildingLayerFull(tf.keras.layers.Layer): + def __init__(self, distance_dim=128, kernel=NodePairGaussianKernel(), **kwargs): + self.distance_dim = distance_dim + self.kernel = kernel + + super(MessageBuildingLayerFull, self).__init__(**kwargs) + + """ + x_msg: (n_batch, n_points, n_msg_features) + """ + def call(self, x_msg, msk, training=False): + msk_f = tf.expand_dims(tf.cast(msk, x_msg.dtype), -1) + + shp = tf.shape(x_msg) + n_batches = shp[0] + n_points = shp[1] + n_message_features = shp[2] + + #Run the node-to-node kernel (distance computation / graph building / attention) + dm = self.kernel(x_msg, training=training) + + #remove the masked points row-wise and column-wise + dm = tf.einsum("bijk,bi->bijk", dm, tf.squeeze(msk_f, axis=-1)) + dm = tf.einsum("bijk,bj->bijk", dm, tf.squeeze(msk_f, axis=-1)) + + return dm + class OutputDecoding(tf.keras.Model): def __init__(self, activation="elu", @@ -624,6 +651,7 @@ def __init__(self, *args, **kwargs): self.kernel = kwargs.pop("kernel") self.node_message = kwargs.pop("node_message") self.hidden_dim = kwargs.pop("hidden_dim") + self.do_lsh = kwargs.pop("do_lsh") self.activation = getattr(tf.keras.activations, kwargs.pop("activation")) self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation")) @@ -638,12 +666,20 @@ def __init__(self, *args, **kwargs): num_layers=2, activation=self.activation, dropout=self.dropout ) - self.message_building_layer = MessageBuildingLayerLSH( - distance_dim=self.distance_dim, - max_num_bins=self.max_num_bins, - bin_size=self.bin_size, - kernel=build_kernel_from_conf(self.kernel, kwargs.get("name")+"_kernel") - ) + + if self.do_lsh: + self.message_building_layer = MessageBuildingLayerLSH( + distance_dim=self.distance_dim, + max_num_bins=self.max_num_bins, + bin_size=self.bin_size, + kernel=build_kernel_from_conf(self.kernel, kwargs.get("name")+"_kernel") + ) + else: + self.message_building_layer = MessageBuildingLayerFull( + distance_dim=self.distance_dim, + kernel=build_kernel_from_conf(self.kernel, kwargs.get("name")+"_kernel") + ) + self.message_passing_layers = [ get_message_layer(self.node_message, "{}_msg_{}".format(kwargs.get("name"), iconv)) for iconv in range(self.num_node_messages) ] @@ -662,21 +698,29 @@ def call(self, x, msk, training=False): x_dist = self.dist_activation(self.ffn_dist(x, training=training)) #x_dist = self.gaussian_noise(x_dist, training=training) + #compute the element-to-element messages / distance matrix / graph structure - bins_split, x_binned, dm, msk_binned = self.message_building_layer(x_dist, x, msk) + if self.do_lsh: + bins_split, x, dm, msk_f = self.message_building_layer(x_dist, x, msk) + else: + dm = self.message_building_layer(x_dist, msk) + msk_f = tf.expand_dims(tf.cast(msk, x.dtype), axis=-1) + bins_split = None #run the node update with message passing for msg in self.message_passing_layers: - x_binned = msg((x_binned, dm, msk_binned)) + x = msg((x, dm, msk_f)) - #x_binned = self.gaussian_noise(x_binned, training=training) + #x = self.gaussian_noise(x, training=training) if self.dropout_layer: - x_binned = self.dropout_layer(x_binned, training=training) + x = self.dropout_layer(x, training=training) - x_enc = reverse_lsh(bins_split, x_binned) + #undo the binning according to the element-to-bin indices + if self.do_lsh: + x = reverse_lsh(bins_split, x) - return {"enc": x_enc, "dist": x_dist, "bins": bins_split, "dm": dm} + return {"enc": x, "dist": x_dist, "bins": bins_split, "dm": dm} class PFNetDense(tf.keras.Model): def __init__(self, diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index a44d60a98..b47822dde 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -324,16 +324,17 @@ def plot_eff_and_fake_rate( ypred_id, cp_dir, ivar=4, - bins=np.linspace(-3,6,100), - xlabel="PFElement log[E/GeV]", - log=True + bins=np.linspace(0, 200, 100), + xlabel="PFElement E", + log_var=False, + do_log_y=True ): values = self.X[msk][:, ivar] cand_id = self.ytrue_id[msk] pred_id = ypred_id[msk] - if log: + if log_var: values = np.log(values) hist_cand = np.histogram(values[(cand_id==icls)], bins=bins); @@ -352,12 +353,14 @@ def plot_eff_and_fake_rate( plt.legend() plt.xlabel(xlabel) plt.ylabel("Number of particles") + if do_log_y: + ax.set_yscale("log") ax = plt.subplot(2,1,2, sharex=ax) mplhep.histplot(eff, bins=hist_cand[1], label="efficiency", color="black") mplhep.histplot(fake, bins=hist_cand[1], label="fake rate", color="red") plt.legend(frameon=False) - plt.ylim(0,1.4) + plt.ylim(0, 1.4) plt.xlabel(xlabel) plt.ylabel("Fraction of particles / bin") @@ -503,7 +506,10 @@ def make_gnn_dense(config, dtype): "debug" ] - kwargs = {par: config['parameters'][par] for par in parameters} + kwargs = {} + for par in parameters: + if par in config['parameters'].keys(): + kwargs[par] = config['parameters'][par] model = PFNetDense( multi_output=config["setup"]["multi_output"], diff --git a/notebooks/cms-mlpf.ipynb b/notebooks/cms-mlpf.ipynb index c78738457..4424b3211 100644 --- a/notebooks/cms-mlpf.ipynb +++ b/notebooks/cms-mlpf.ipynb @@ -17,7 +17,7 @@ "import sklearn.metrics\n", "import matplotlib\n", "import scipy\n", - "import mplhep as hep\n", + "import mplhep\n", "\n", "import pandas" ] @@ -130,7 +130,7 @@ "metadata": {}, "outputs": [], "source": [ - "path = \"../experiments/cms-dev_20210831_225815_541048.gpu0.local/evaluation/\"" + "path = \"../experiments/cms-dev_20210901_112919_500542.gpu0.local/evaluation/\"" ] }, { @@ -172,10 +172,22 @@ "\n", "# ypred_raw[X[:, :, 0]==4, 1] = 0.0\n", "# #ypred_raw[X[:, :, 0]==4, 6] = 0.0\n", + "\n", + "# ypred_raw[X[:, :, 0]==5, 0] += ypred_raw[X[:, :, 0]==5, 1]\n", + "# ypred_raw[X[:, :, 0]==5, 0] += ypred_raw[X[:, :, 0]==5, 7]\n", "# ypred_raw[X[:, :, 0]==5, 1] = 0.0\n", "# ypred_raw[X[:, :, 0]==5, 7] = 0.0\n", "\n", + "# ypred_raw[X[:, :, 0]==8, 3] += ypred_raw[X[:, :, 0]==8, 1]\n", + "# ypred_raw[X[:, :, 0]==8, 3] += ypred_raw[X[:, :, 0]==8, 2]\n", "# ypred_raw[X[:, :, 0]==8, 1] = 0.0\n", + "# ypred_raw[X[:, :, 0]==8, 2] = 0.0\n", + "\n", + "\n", + "# ypred_raw[X[:, :, 0]==9, 3] += ypred_raw[X[:, :, 0]==9, 1]\n", + "# ypred_raw[X[:, :, 0]==9, 3] += ypred_raw[X[:, :, 0]==9, 2]\n", + "# ypred_raw[X[:, :, 0]==9, 1] = 0.0\n", + "# ypred_raw[X[:, :, 0]==9, 2] = 0.0\n", "# ypred_raw[X[:, :, 0]==9, 1] = 0.0\n", "\n", "# ypred_raw[X[:, :, 0]==8, 2] = 0.0\n", @@ -184,14 +196,13 @@ "ypred_raw_f = ypred_raw.reshape((ypred_raw.shape[0]*ypred_raw.shape[1], ypred_raw.shape[2]))\n", "\n", "ypred_id = np.argmax(ypred_raw, axis=-1)\n", - "\n", "ypred_id_f = ypred_id.flatten()" ] }, { "cell_type": "code", "execution_count": null, - "id": "floral-people", + "id": "corrected-tunisia", "metadata": {}, "outputs": [], "source": [ @@ -201,13 +212,52 @@ { "cell_type": "code", "execution_count": null, - "id": "cooked-bullet", + "id": "extensive-kuwait", + "metadata": {}, + "outputs": [], + "source": [ + "thresholds = [0.0, 0.0, 0.0, 0, 0, 0, 0]\n", + "ypred_id = apply_thresholds(ypred_raw, thresholds)\n", + "ypred_id_f = apply_thresholds_f(ypred_raw_f, thresholds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "interim-chosen", + "metadata": {}, + "outputs": [], + "source": [ + "icls = 2\n", + "ielem = 5\n", + "\n", + "energy_msk = (X_f[:, 4]>0)\n", + "elem_msk = (X_f[:, 0]==ielem)\n", + "\n", + "vals_sig = ypred_raw_f[energy_msk & elem_msk & (ycand_f[:, 0]==icls), icls]\n", + "vals_bkg = ypred_raw_f[energy_msk & elem_msk & (ycand_f[:, 0]!=icls), icls]\n", + "hsig = np.histogram(vals_sig, bins=b)[0]\n", + "hbkg = np.histogram(vals_bkg, bins=b)[0]\n", + "\n", + "a = np.cumsum(hsig)/np.sum(hsig)\n", + "b = np.cumsum(hbkg)/np.sum(hbkg)\n", + "\n", + "plt.figure(figsize=(4,4))\n", + "plt.plot(a, b, marker=\".\")\n", + "plt.plot([0,1], [0,1], color=\"black\", lw=0.5, ls=\"--\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "becoming-application", "metadata": {}, "outputs": [], "source": [ - "# thresholds = [0.6, 0.7, 0.0, 0, 0, 0, 0]\n", - "# ypred_id = apply_thresholds(ypred_raw, thresholds)\n", - "# ypred_id_f = apply_thresholds_f(ypred_raw_f, thresholds)" + "b = np.linspace(0,1,100)\n", + "mplhep.histplot(np.histogram(vals_sig, bins=b, density=1), label=\"sig\");\n", + "mplhep.histplot(np.histogram(vals_bkg, bins=b, density=1), label=\"bkg\");\n", + "plt.legend(loc=2)" ] }, { @@ -237,7 +287,7 @@ { "cell_type": "code", "execution_count": null, - "id": "reported-button", + "id": "funky-destination", "metadata": {}, "outputs": [], "source": [ @@ -265,7 +315,7 @@ { "cell_type": "code", "execution_count": null, - "id": "chronic-discovery", + "id": "authorized-greensboro", "metadata": {}, "outputs": [], "source": [ @@ -281,7 +331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "multiple-disco", + "id": "incorporate-vanilla", "metadata": {}, "outputs": [], "source": [ @@ -297,7 +347,7 @@ { "cell_type": "code", "execution_count": null, - "id": "innocent-black", + "id": "comic-privacy", "metadata": {}, "outputs": [], "source": [ @@ -313,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "flying-mason", + "id": "sustainable-passage", "metadata": {}, "outputs": [], "source": [ @@ -329,7 +379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "equipped-subcommittee", + "id": "funny-batch", "metadata": {}, "outputs": [], "source": [ @@ -343,9 +393,9 @@ " for icls in range(8):\n", " plt.subplot(2,4,icls+1)\n", " plt.hist2d(\n", - " np.log(X_f[X_f[:, 0]==elem_type, 4]),\n", + " np.log10(X_f[X_f[:, 0]==elem_type, 4]),\n", " ypred_raw_f[X_f[:, 0]==elem_type, icls],\n", - " bins=(np.linspace(-2,6,100), np.linspace(0,1,100)), cmap=\"Blues\");\n", + " bins=(np.linspace(-2,4,100), np.linspace(0,1,100)), cmap=\"Blues\");\n", " plt.colorbar()\n", " plt.xlabel(\"PFElement log[E/GeV]\")\n", " plt.ylabel(\"MLPF probability for class {}\".format(icls))\n", @@ -355,7 +405,7 @@ { "cell_type": "code", "execution_count": null, - "id": "worst-coating", + "id": "strange-combine", "metadata": {}, "outputs": [], "source": [ @@ -365,27 +415,27 @@ { "cell_type": "code", "execution_count": null, - "id": "responsible-georgia", + "id": "private-communication", "metadata": {}, "outputs": [], "source": [ - "plot_elem_energy_cls_prob(5)" + "plot_elem_energy_cls_prob(4)" ] }, { "cell_type": "code", "execution_count": null, - "id": "noble-guess", + "id": "differential-steal", "metadata": {}, "outputs": [], "source": [ - "plot_elem_energy_cls_prob(4)" + "plot_elem_energy_cls_prob(5)" ] }, { "cell_type": "code", "execution_count": null, - "id": "honest-tackle", + "id": "direct-crowd", "metadata": {}, "outputs": [], "source": [ @@ -397,7 +447,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fourth-approval", + "id": "fitting-thriller", "metadata": {}, "outputs": [], "source": [ @@ -408,7 +458,7 @@ { "cell_type": "code", "execution_count": null, - "id": "abroad-wallpaper", + "id": "frozen-ethnic", "metadata": {}, "outputs": [], "source": [ @@ -421,7 +471,7 @@ { "cell_type": "code", "execution_count": null, - "id": "honest-runner", + "id": "anticipated-robinson", "metadata": {}, "outputs": [], "source": [ @@ -434,7 +484,7 @@ { "cell_type": "code", "execution_count": null, - "id": "incorporated-prerequisite", + "id": "micro-saying", "metadata": {}, "outputs": [], "source": [ @@ -465,60 +515,83 @@ " fake = hist_pred_fake[0]/hist_pred[0]\n", "\n", " plt.figure(figsize=(8,8))\n", - " ax = plt.subplot(2,1,1)\n", - " mplhep.histplot(hist_cand, label=\"PF\")\n", - " mplhep.histplot(hist_pred, label=\"MLPF\")\n", + " ax1 = plt.subplot(2,1,1)\n", + " mplhep.histplot(hist_cand, label=\"with PF candidate\")\n", + " mplhep.histplot(hist_pred, label=\"with MLPF candidate\")\n", " plt.legend()\n", " plt.xlabel(xlabel)\n", " plt.ylabel(\"Number of particles\")\n", "\n", - " ax = plt.subplot(2,1,2, sharex=ax)\n", + " ax2 = plt.subplot(2,1,2, sharex=ax1)\n", " mplhep.histplot(eff, bins=hist_cand[1], label=\"efficiency\", color=\"black\")\n", " mplhep.histplot(fake, bins=hist_cand[1], label=\"fake rate\", color=\"red\")\n", " plt.legend(frameon=False)\n", " plt.ylim(0,1.4)\n", " plt.xlabel(xlabel)\n", - " plt.ylabel(\"Fraction of particles / bin\")" + " plt.ylabel(\"Fraction of particles / bin\")\n", + " \n", + " return ax1, ax2" ] }, { "cell_type": "code", "execution_count": null, - "id": "dedicated-indonesia", + "id": "inner-christianity", "metadata": {}, "outputs": [], "source": [ - "plot_eff_and_fake_rate(1)" + "b = np.linspace(0,100, 100)\n", + "plt.hist(X_f[(X_f[:, 0]==5), 4], bins=b, histtype=\"step\", lw=2, label=\"all clusters\");\n", + "plt.hist(X_f[(X_f[:, 0]==5) & (ycand_f[:, 0]==2), 4], bins=b, histtype=\"step\", lw=2, label=\"with PF candidate\");\n", + "plt.hist(X_f[(X_f[:, 0]==5) & (ypred_id_f==2), 4], bins=b, histtype=\"step\", lw=2, label=\"with MLPF candidate\");\n", + "plt.yscale(\"log\")\n", + "plt.legend()" ] }, { "cell_type": "code", "execution_count": null, - "id": "variable-potter", + "id": "automated-quarter", "metadata": {}, "outputs": [], "source": [ - "plot_eff_and_fake_rate(2)" + "ax1, ax2 = plot_eff_and_fake_rate(1, bins=np.linspace(0, 300, 100), log=False)\n", + "ax1.set_yscale(\"log\")\n", + "ax1.set_title(\"track, charged hadron predictions\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "hybrid-chuck", + "id": "military-professor", "metadata": {}, "outputs": [], "source": [ - "plot_eff_and_fake_rate(3)" + "ax1, ax2 = plot_eff_and_fake_rate(2, bins=np.linspace(0, 300, 100), log=False)\n", + "ax1.set_yscale(\"log\")\n", + "ax1.set_title(\"HCAL cluster, neutral hadron predictions\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "golden-catalyst", + "id": "characteristic-colleague", "metadata": {}, "outputs": [], "source": [ - "plot_eff_and_fake_rate(4)" + "ax1, ax2 = plot_eff_and_fake_rate(3, bins=np.linspace(0, 300, 100), log=False)\n", + "ax1.set_yscale(\"log\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "composed-principal", + "metadata": {}, + "outputs": [], + "source": [ + "ax1, ax2 = plot_eff_and_fake_rate(4, bins=np.linspace(0, 300, 100), log=False)\n", + "ax1.set_yscale(\"log\")" ] }, { @@ -687,8 +760,6 @@ " msk = (ycand_f[:, 0] == icls)\n", " plt.hist(ypred_raw_f[msk & (X_f[:, 0] != 0), icls], bins=100, density=1, histtype=\"step\", lw=2, color=\"blue\", label=\"true \"+pid_names[icls]);\n", " plt.hist(ypred_raw_f[~msk & (X_f[:, 0] != 0), icls], bins=100, density=1, histtype=\"step\", lw=2, color=\"red\", label=\"other particles\");\n", - " plt.axvline(ret.x[icls-1], 0, 0.7, ls=\"--\",\n", - " color=\"black\", label=\"threshold: {:.2f}\".format(ret.x[icls-1]), lw=1)\n", " plt.yscale(\"log\")\n", " plt.title(\"Particle reconstruction for {}\".format(pid_names[icls]), y=1.05)\n", " plt.xlabel(\"Classification output {}\".format(icls))\n", @@ -754,13 +825,13 @@ "source": [ "plt.figure(figsize=(8, 8))\n", "ax = plt.axes()\n", - "plt.imshow(cm, cmap=\"Blues\", norm=matplotlib.colors.LogNorm())\n", + "plt.imshow(cm, cmap=\"Blues\")\n", "plt.colorbar()\n", "\n", "cms_label(x1=0.18, x2=0.52, y=0.82)\n", "#sample_label(ax, x=0.8, y=1.0)\n", - "plt.xticks(range(len(y_labels)), y_labels);\n", - "plt.yticks(range(len(y_labels)), y_labels);\n", + "#plt.xticks(range(len(y_labels)), y_labels);\n", + "#plt.yticks(range(len(y_labels)), y_labels);\n", "plt.xlabel(\"Predicted PFCandidate\")\n", "plt.ylabel(\"True PFCandidate\")\n", "plt.title(\"MLPF trained on PF\", y=1.03)\n", @@ -864,7 +935,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dirty-rebecca", + "id": "scheduled-worst", "metadata": {}, "outputs": [], "source": [] diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index 2fbcc59dc..0e5cc101f 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -56,7 +56,7 @@ setup: weights: weights_config: lr: 1e-3 - batch_size: 5 + batch_size: 2 num_events_train: 80000 num_events_test: 10000 num_epochs: 100 @@ -83,8 +83,9 @@ parameters: dropout: 0.0 activation: gelu combined_graph_layer: - bin_size: 640 - max_num_bins: 100 + do_lsh: no + bin_size: 1600 + max_num_bins: 10 distance_dim: 128 layernorm: no dropout: 0.0 @@ -93,16 +94,16 @@ parameters: type: NodePairGaussianKernel dist_mult: 0.1 clip_value_low: 0.0 - num_node_messages: 1 + num_node_messages: 2 node_message: type: GHConvDense - output_dim: 128 + output_dim: 512 activation: gelu normalize_degrees: yes - hidden_dim: 128 + hidden_dim: 256 activation: gelu - num_graph_layers_common: 4 - num_graph_layers_energy: 3 + num_graph_layers_common: 1 + num_graph_layers_energy: 1 output_decoding: activation: gelu regression_use_classification: yes From 4347706280399422e87ea3193f9391dfaefdef73 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 09:51:46 +0300 Subject: [PATCH 06/17] added LSH scanning Former-commit-id: 4a60f46468b346a98d7de965599579f52805ee94 --- mlpf/tallinn/test-gnn.sh | 14 +++ mlpf/tfmodel/model_setup.py | 9 +- parameters/cms.yaml | 3 +- parameters/test-gnn/cms-0l.yaml | 149 ++++++++++++++++++++++++++ parameters/test-gnn/cms-lsh-1l.yaml | 149 ++++++++++++++++++++++++++ parameters/test-gnn/cms-lsh-2l.yaml | 149 ++++++++++++++++++++++++++ parameters/test-gnn/cms-lsh-3l.yaml | 149 ++++++++++++++++++++++++++ parameters/test-gnn/cms-nolsh-1l.yaml | 149 ++++++++++++++++++++++++++ 8 files changed, 768 insertions(+), 3 deletions(-) create mode 100755 mlpf/tallinn/test-gnn.sh create mode 100644 parameters/test-gnn/cms-0l.yaml create mode 100644 parameters/test-gnn/cms-lsh-1l.yaml create mode 100644 parameters/test-gnn/cms-lsh-2l.yaml create mode 100644 parameters/test-gnn/cms-lsh-3l.yaml create mode 100644 parameters/test-gnn/cms-nolsh-1l.yaml diff --git a/mlpf/tallinn/test-gnn.sh b/mlpf/tallinn/test-gnn.sh new file mode 100755 index 000000000..15017b885 --- /dev/null +++ b/mlpf/tallinn/test-gnn.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#SBATCH -p gpu +#SBATCH --gpus 1 +#SBATCH --mem-per-gpu=8G + +IMG=/home/software/singularity/base.simg:latest +cd ~/particleflow + +#TF training +singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-0l.yaml --plot-freq 10 +singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-lsh-1l.yaml --plot-freq 10 +singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-lsh-2l.yaml --plot-freq 10 +singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-lsh-3l.yaml --plot-freq 10 +singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-nolsh-1l.yaml --plot-freq 10 diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index b47822dde..aef53ee6e 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -193,6 +193,8 @@ def plot_reg_distribution(self, epoch, outpath, ypred, ypred_id, icls, reg_varia bins = self.reg_bins[reg_variable] if bins is None: bins = 100 + + plt.figure() plt.hist(vals_true, bins=bins, histtype="step", lw=2, label="true") plt.hist(vals_pred, bins=bins, histtype="step", lw=2, label="predicted") @@ -312,6 +314,7 @@ def plot_elem_to_pred(self, epoch, cp_dir, msk, ypred_id): image_path = str(cp_dir / "elem_to_pred.png") plt.savefig(image_path, bbox_inches="tight") + plt.close("all") if self.comet_experiment: self.comet_experiment.log_image(image_path, step=epoch) @@ -366,6 +369,7 @@ def plot_eff_and_fake_rate( image_path = str(cp_dir / "eff_fake_cls{}.png".format(icls)) plt.savefig(image_path, bbox_inches="tight") + plt.close("all") if self.comet_experiment: self.comet_experiment.log_image(image_path, step=epoch) @@ -376,8 +380,9 @@ def on_epoch_end(self, epoch, logs=None): with open("{}/history_{}.json".format(self.outpath, epoch), "w") as fi: json.dump(logs, fi) - if epoch%self.plot_freq!=0: - return + if self.plot_freq>1: + if (epoch+1)%self.plot_freq!=0 or epoch==0: + return cp_dir = Path(self.outpath) / "epoch_{}".format(epoch) cp_dir.mkdir(parents=True, exist_ok=True) diff --git a/parameters/cms.yaml b/parameters/cms.yaml index 7341bc09e..61e0b6ba3 100644 --- a/parameters/cms.yaml +++ b/parameters/cms.yaml @@ -55,7 +55,7 @@ setup: train: yes weights: weights_config: - lr: 1e-4 + lr: 1e-3 batch_size: 5 num_events_train: 80000 num_events_test: 10000 @@ -83,6 +83,7 @@ parameters: dropout: 0.0 activation: gelu combined_graph_layer: + do_lsh: yes bin_size: 160 max_num_bins: 100 distance_dim: 128 diff --git a/parameters/test-gnn/cms-0l.yaml b/parameters/test-gnn/cms-0l.yaml new file mode 100644 index 000000000..3e230c7cb --- /dev/null +++ b/parameters/test-gnn/cms-0l.yaml @@ -0,0 +1,149 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 20 + num_events_train: 1000 + num_events_test: 1000 + num_epochs: 50 + num_val_files: 20 + dtype: float32 + trainable: + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu + combined_graph_layer: + do_lsh: no + bin_size: 160 + max_num_bins: 100 + distance_dim: 128 + layernorm: no + dropout: 0.0 + dist_activation: gelu + kernel: + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 + num_node_messages: 1 + node_message: + type: GHConvDense + output_dim: 128 + activation: gelu + normalize_degrees: yes + hidden_dim: 128 + activation: gelu + num_graph_layers_common: 0 + num_graph_layers_energy: 0 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 256 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 2 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 1000 + decay_rate: 0.98 + staircase: yes diff --git a/parameters/test-gnn/cms-lsh-1l.yaml b/parameters/test-gnn/cms-lsh-1l.yaml new file mode 100644 index 000000000..bdf62a034 --- /dev/null +++ b/parameters/test-gnn/cms-lsh-1l.yaml @@ -0,0 +1,149 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 10 + num_events_train: 1000 + num_events_test: 1000 + num_epochs: 50 + num_val_files: 20 + dtype: float32 + trainable: + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu + combined_graph_layer: + do_lsh: yes + bin_size: 160 + max_num_bins: 100 + distance_dim: 128 + layernorm: no + dropout: 0.0 + dist_activation: gelu + kernel: + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 + num_node_messages: 1 + node_message: + type: GHConvDense + output_dim: 128 + activation: gelu + normalize_degrees: yes + hidden_dim: 128 + activation: gelu + num_graph_layers_common: 1 + num_graph_layers_energy: 1 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 256 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 2 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 1000 + decay_rate: 0.98 + staircase: yes diff --git a/parameters/test-gnn/cms-lsh-2l.yaml b/parameters/test-gnn/cms-lsh-2l.yaml new file mode 100644 index 000000000..69320ceba --- /dev/null +++ b/parameters/test-gnn/cms-lsh-2l.yaml @@ -0,0 +1,149 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 5 + num_events_train: 1000 + num_events_test: 1000 + num_epochs: 50 + num_val_files: 20 + dtype: float32 + trainable: + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu + combined_graph_layer: + do_lsh: yes + bin_size: 160 + max_num_bins: 100 + distance_dim: 128 + layernorm: no + dropout: 0.0 + dist_activation: gelu + kernel: + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 + num_node_messages: 1 + node_message: + type: GHConvDense + output_dim: 128 + activation: gelu + normalize_degrees: yes + hidden_dim: 128 + activation: gelu + num_graph_layers_common: 2 + num_graph_layers_energy: 2 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 256 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 2 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 1000 + decay_rate: 0.98 + staircase: yes diff --git a/parameters/test-gnn/cms-lsh-3l.yaml b/parameters/test-gnn/cms-lsh-3l.yaml new file mode 100644 index 000000000..5cf0226c0 --- /dev/null +++ b/parameters/test-gnn/cms-lsh-3l.yaml @@ -0,0 +1,149 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 5 + num_events_train: 1000 + num_events_test: 1000 + num_epochs: 50 + num_val_files: 20 + dtype: float32 + trainable: + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu + combined_graph_layer: + do_lsh: yes + bin_size: 160 + max_num_bins: 100 + distance_dim: 128 + layernorm: no + dropout: 0.0 + dist_activation: gelu + kernel: + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 + num_node_messages: 1 + node_message: + type: GHConvDense + output_dim: 128 + activation: gelu + normalize_degrees: yes + hidden_dim: 128 + activation: gelu + num_graph_layers_common: 3 + num_graph_layers_energy: 3 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 256 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 2 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 1000 + decay_rate: 0.98 + staircase: yes diff --git a/parameters/test-gnn/cms-nolsh-1l.yaml b/parameters/test-gnn/cms-nolsh-1l.yaml new file mode 100644 index 000000000..edb43d666 --- /dev/null +++ b/parameters/test-gnn/cms-nolsh-1l.yaml @@ -0,0 +1,149 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 2 + num_events_train: 1000 + num_events_test: 1000 + num_epochs: 50 + num_val_files: 20 + dtype: float32 + trainable: + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu + combined_graph_layer: + do_lsh: no + bin_size: 160 + max_num_bins: 100 + distance_dim: 128 + layernorm: no + dropout: 0.0 + dist_activation: gelu + kernel: + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 + num_node_messages: 1 + node_message: + type: GHConvDense + output_dim: 128 + activation: gelu + normalize_degrees: yes + hidden_dim: 128 + activation: gelu + num_graph_layers_common: 1 + num_graph_layers_energy: 1 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 256 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 2 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 1000 + decay_rate: 0.98 + staircase: yes From 40e08510e5471ab1d8380ea96cb9ca7e7e94e978 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 14:46:33 +0300 Subject: [PATCH 07/17] up Former-commit-id: 63446fd09ce4ed6d773b68e331b628f01935b881 --- mlpf/tfmodel/model.py | 10 ++++++---- parameters/cms-dev.yaml | 44 ++++++++++++++++++++++------------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index d0dbd85a2..08200745f 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -37,10 +37,12 @@ def pairwise_gaussian_dist(A, B): def pairwise_learnable_dist(A, B, ffn, training=False): shp = tf.shape(A) + # tf.print("shp", shp) + # import pdb;pdb.set_trace() #stack node feature vectors of src[i], dst[j] into a matrix res[i,j] = (src[i], dst[j]) - a, b, c, d = tf.meshgrid(tf.range(shp[0]), tf.range(shp[1]), tf.range(shp[2]), tf.range(shp[2]), indexing="ij") - inds1 = tf.stack([a,b,c], axis=-1) - inds2 = tf.stack([a,b,d], axis=-1) + mg = tf.meshgrid(tf.range(shp[0]), tf.range(shp[1]), tf.range(shp[2]), tf.range(shp[2]), indexing="ij") + inds1 = tf.stack([mg[0],mg[1],mg[2]], axis=-1) + inds2 = tf.stack([mg[0],mg[1],mg[3]], axis=-1) res = tf.concat([ tf.gather_nd(A, inds1), tf.gather_nd(B, inds2)], axis=-1 @@ -651,7 +653,7 @@ def __init__(self, *args, **kwargs): self.kernel = kwargs.pop("kernel") self.node_message = kwargs.pop("node_message") self.hidden_dim = kwargs.pop("hidden_dim") - self.do_lsh = kwargs.pop("do_lsh") + self.do_lsh = kwargs.pop("do_lsh", True) self.activation = getattr(tf.keras.activations, kwargs.pop("activation")) self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation")) diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index 0e5cc101f..2b7781b75 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -63,12 +63,12 @@ setup: num_val_files: 10 dtype: float32 trainable: classification - classification_loss_type: categorical_cross_entropy #categorical_cross_entropy, sigmoid_focal_crossentropy + classification_loss_type: categorical_cross_entropy lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: - cls: none - charge: none + cls: inverse_sqrt + charge: signal_only pt: signal_only eta: signal_only sin_phi: signal_only @@ -83,27 +83,31 @@ parameters: dropout: 0.0 activation: gelu combined_graph_layer: - do_lsh: no - bin_size: 1600 - max_num_bins: 10 + do_lsh: yes + bin_size: 32 + max_num_bins: 500 distance_dim: 128 layernorm: no dropout: 0.0 dist_activation: linear kernel: - type: NodePairGaussianKernel - dist_mult: 0.1 - clip_value_low: 0.0 - num_node_messages: 2 + type: NodePairTrainableKernel + output_dim: 8 + hidden_dim: 32 + num_layers: 2 + activation: gelu node_message: - type: GHConvDense - output_dim: 512 + type: NodeMessageLearnable + output_dim: 256 + hidden_dim: 128 + num_layers: 2 activation: gelu - normalize_degrees: yes + aggregation_direction: dst + num_node_messages: 1 hidden_dim: 256 activation: gelu - num_graph_layers_common: 1 - num_graph_layers_energy: 1 + num_graph_layers_common: 2 + num_graph_layers_energy: 2 output_decoding: activation: gelu regression_use_classification: yes @@ -113,21 +117,21 @@ parameters: eta_skip_gate: yes phi_skip_gate: yes - id_dim_decrease: no + id_dim_decrease: yes charge_dim_decrease: yes pt_dim_decrease: yes eta_dim_decrease: yes phi_dim_decrease: yes energy_dim_decrease: yes - id_hidden_dim: 512 + id_hidden_dim: 256 charge_hidden_dim: 256 pt_hidden_dim: 256 eta_hidden_dim: 256 phi_hidden_dim: 256 energy_hidden_dim: 256 - id_num_layers: 3 + id_num_layers: 2 charge_num_layers: 2 pt_num_layers: 2 eta_num_layers: 2 @@ -144,6 +148,6 @@ timing: num_iter: 3 exponentialdecay: - decay_steps: 1000 - decay_rate: 0.98 + decay_steps: 2000 + decay_rate: 0.99 staircase: yes From dd285bb90d83b5980bd7bc54d8a39a86d273a350 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 14:46:43 +0300 Subject: [PATCH 08/17] epoch one-based Former-commit-id: 590a1007188b87e4e71d944792d04bbb9bc8de30 --- mlpf/tfmodel/model_setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index aef53ee6e..50be3662b 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -376,12 +376,15 @@ def plot_eff_and_fake_rate( def on_epoch_end(self, epoch, logs=None): + #first epoch is 1, not 0 + epoch = epoch + 1 + #save the training logs (losses) for this epoch with open("{}/history_{}.json".format(self.outpath, epoch), "w") as fi: json.dump(logs, fi) if self.plot_freq>1: - if (epoch+1)%self.plot_freq!=0 or epoch==0: + if epoch%self.plot_freq!=0 or epoch==1: return cp_dir = Path(self.outpath) / "epoch_{}".format(epoch) From 19cfd1b106c13b95fbd52cfcafa1057652f50ef4 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 15:17:44 +0300 Subject: [PATCH 09/17] optimization with cls only Former-commit-id: 32a32ec13b1091eeb108a5dfb2f984ab134b8382 --- parameters/test-gnn/cms-0l.yaml | 2 +- parameters/test-gnn/cms-lsh-1l.yaml | 2 +- parameters/test-gnn/cms-lsh-2l.yaml | 2 +- parameters/test-gnn/cms-lsh-3l.yaml | 2 +- parameters/test-gnn/cms-nolsh-1l.yaml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/parameters/test-gnn/cms-0l.yaml b/parameters/test-gnn/cms-0l.yaml index 3e230c7cb..5977abbc6 100644 --- a/parameters/test-gnn/cms-0l.yaml +++ b/parameters/test-gnn/cms-0l.yaml @@ -62,7 +62,7 @@ setup: num_epochs: 50 num_val_files: 20 dtype: float32 - trainable: + trainable: classification classification_loss_type: categorical_cross_entropy lr_schedule: exponentialdecay # exponentialdecay, onecycle diff --git a/parameters/test-gnn/cms-lsh-1l.yaml b/parameters/test-gnn/cms-lsh-1l.yaml index bdf62a034..c8c4dfb7e 100644 --- a/parameters/test-gnn/cms-lsh-1l.yaml +++ b/parameters/test-gnn/cms-lsh-1l.yaml @@ -62,7 +62,7 @@ setup: num_epochs: 50 num_val_files: 20 dtype: float32 - trainable: + trainable: classification classification_loss_type: categorical_cross_entropy lr_schedule: exponentialdecay # exponentialdecay, onecycle diff --git a/parameters/test-gnn/cms-lsh-2l.yaml b/parameters/test-gnn/cms-lsh-2l.yaml index 69320ceba..5eb0a83f2 100644 --- a/parameters/test-gnn/cms-lsh-2l.yaml +++ b/parameters/test-gnn/cms-lsh-2l.yaml @@ -62,7 +62,7 @@ setup: num_epochs: 50 num_val_files: 20 dtype: float32 - trainable: + trainable: classification classification_loss_type: categorical_cross_entropy lr_schedule: exponentialdecay # exponentialdecay, onecycle diff --git a/parameters/test-gnn/cms-lsh-3l.yaml b/parameters/test-gnn/cms-lsh-3l.yaml index 5cf0226c0..6ac8b76c7 100644 --- a/parameters/test-gnn/cms-lsh-3l.yaml +++ b/parameters/test-gnn/cms-lsh-3l.yaml @@ -62,7 +62,7 @@ setup: num_epochs: 50 num_val_files: 20 dtype: float32 - trainable: + trainable: classification classification_loss_type: categorical_cross_entropy lr_schedule: exponentialdecay # exponentialdecay, onecycle diff --git a/parameters/test-gnn/cms-nolsh-1l.yaml b/parameters/test-gnn/cms-nolsh-1l.yaml index edb43d666..697aac9ed 100644 --- a/parameters/test-gnn/cms-nolsh-1l.yaml +++ b/parameters/test-gnn/cms-nolsh-1l.yaml @@ -62,7 +62,7 @@ setup: num_epochs: 50 num_val_files: 20 dtype: float32 - trainable: + trainable: classification classification_loss_type: categorical_cross_entropy lr_schedule: exponentialdecay # exponentialdecay, onecycle From b40d8ebbbe40829cf6293ee0dbbb80ea7d456357 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 15:19:45 +0300 Subject: [PATCH 10/17] mpnn optimization Former-commit-id: 9a921ebeb394bfa36649896d2e546678c63f5c4a --- parameters/cms-dev.yaml | 4 +- parameters/test-gnn/cms-lsh-mpnn.yaml | 153 ++++++++++++++++++++++++++ 2 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 parameters/test-gnn/cms-lsh-mpnn.yaml diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index 2b7781b75..b9d603c52 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -23,7 +23,7 @@ dataset: padded_num_elem_size: 6400 #(pt, eta, sin phi, cos phi, E) num_momentum_outputs: 5 - classification_loss_coef: 100.0 + classification_loss_coef: 1.0 charge_loss_coef: 0.01 pt_loss_coef: 0.0001 eta_loss_coef: 100.0 @@ -56,7 +56,7 @@ setup: weights: weights_config: lr: 1e-3 - batch_size: 2 + batch_size: 4 num_events_train: 80000 num_events_test: 10000 num_epochs: 100 diff --git a/parameters/test-gnn/cms-lsh-mpnn.yaml b/parameters/test-gnn/cms-lsh-mpnn.yaml new file mode 100644 index 000000000..6b4ccc9ff --- /dev/null +++ b/parameters/test-gnn/cms-lsh-mpnn.yaml @@ -0,0 +1,153 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 4 + num_events_train: 1000 + num_events_test: 1000 + num_epochs: 50 + num_val_files: 20 + dtype: float32 + trainable: classification + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + do_node_encoding: no + hidden_dim: 128 + dropout: 0.0 + activation: gelu + combined_graph_layer: + do_lsh: yes + bin_size: 32 + max_num_bins: 500 + distance_dim: 128 + layernorm: no + dropout: 0.0 + dist_activation: linear + kernel: + type: NodePairTrainableKernel + output_dim: 8 + hidden_dim: 32 + num_layers: 2 + activation: gelu + node_message: + type: NodeMessageLearnable + output_dim: 256 + hidden_dim: 128 + num_layers: 2 + activation: gelu + aggregation_direction: dst + num_node_messages: 1 + hidden_dim: 256 + activation: gelu + num_graph_layers_common: 2 + num_graph_layers_energy: 2 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 256 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 2 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 1000 + decay_rate: 0.98 + staircase: yes From 93fcd50c647ad11e92ed4b634e28f11f54166e5f Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 16:36:47 +0300 Subject: [PATCH 11/17] up Former-commit-id: 557ccf625e54d128db4295502b53fac672ede523 --- mlpf/tfmodel/model.py | 8 +- notebooks/cms-mlpf.ipynb | 231 +++++++++++++++++--------- parameters/cms.yaml | 19 +-- parameters/test-gnn/cms-lsh-mpnn.yaml | 2 +- 4 files changed, 169 insertions(+), 91 deletions(-) diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index 08200745f..609952079 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -241,7 +241,13 @@ def __init__(self, *args, **kwargs): elif self.aggregation_direction == "src": self.agg_dim = -3 - self.ffn = point_wise_feed_forward_network(self.output_dim, self.hidden_dim, num_layers=self.num_layers, activation=self.activation, name=kwargs.get("name")+"_ffn") + self.ffn = point_wise_feed_forward_network( + self.output_dim, + self.hidden_dim, + num_layers=self.num_layers, + activation=self.activation, + name=kwargs.get("name")+"_ffn" + ) super(NodeMessageLearnable, self).__init__(*args, **kwargs) def call(self, inputs): diff --git a/notebooks/cms-mlpf.ipynb b/notebooks/cms-mlpf.ipynb index 4424b3211..bfdc07ec7 100644 --- a/notebooks/cms-mlpf.ipynb +++ b/notebooks/cms-mlpf.ipynb @@ -89,6 +89,7 @@ "outputs": [], "source": [ "pid_names = {\n", + " 0: \"no ptcl\",\n", " 1: \"ch.had\",\n", " 2: \"n.had\",\n", " 3: \"HFEM\",\n", @@ -99,6 +100,7 @@ "}\n", "\n", "pid_names_long = {\n", + " 0: \"no particle\",\n", " 1: \"charged hadrons\",\n", " 2: \"neutral hadrons\",\n", " 3: \"HFEM\",\n", @@ -120,7 +122,7 @@ "x_labels = [\n", " \"track\", \"PS1\", \"PS2\", \"ECAL\", \"HCAL\", \"GSF\", \"BREM\", \"HFEM\", \"HFHAD\", \"SC\", \"HO\"\n", "]\n", - "y_labels = [pid_names[i] for i in range(1,8)]" + "y_labels = [pid_names[i] for i in range(0,8)]" ] }, { @@ -130,7 +132,7 @@ "metadata": {}, "outputs": [], "source": [ - "path = \"../experiments/cms-dev_20210901_112919_500542.gpu0.local/evaluation/\"" + "path = \"../experiments/cms-dev_20210831_225815_541048.gpu0.local/evaluation/\"" ] }, { @@ -199,16 +201,6 @@ "ypred_id_f = ypred_id.flatten()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "corrected-tunisia", - "metadata": {}, - "outputs": [], - "source": [ - "np.unique(ypred_id[X[:, :, 0]==4], return_counts=True)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -236,8 +228,10 @@ "\n", "vals_sig = ypred_raw_f[energy_msk & elem_msk & (ycand_f[:, 0]==icls), icls]\n", "vals_bkg = ypred_raw_f[energy_msk & elem_msk & (ycand_f[:, 0]!=icls), icls]\n", - "hsig = np.histogram(vals_sig, bins=b)[0]\n", - "hbkg = np.histogram(vals_bkg, bins=b)[0]\n", + "\n", + "bins = np.linspace(0,1,100)\n", + "hsig = np.histogram(vals_sig, bins=bins)[0]\n", + "hbkg = np.histogram(vals_bkg, bins=bins)[0]\n", "\n", "a = np.cumsum(hsig)/np.sum(hsig)\n", "b = np.cumsum(hbkg)/np.sum(hbkg)\n", @@ -264,7 +258,9 @@ "cell_type": "code", "execution_count": null, "id": "virgin-nicaragua", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "for icls in range(1,8):\n", @@ -432,6 +428,26 @@ "plot_elem_energy_cls_prob(5)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "utility-beverage", + "metadata": {}, + "outputs": [], + "source": [ + "plot_elem_energy_cls_prob(8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "moderate-india", + "metadata": {}, + "outputs": [], + "source": [ + "plot_elem_energy_cls_prob(9)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -518,7 +534,7 @@ " ax1 = plt.subplot(2,1,1)\n", " mplhep.histplot(hist_cand, label=\"with PF candidate\")\n", " mplhep.histplot(hist_pred, label=\"with MLPF candidate\")\n", - " plt.legend()\n", + " plt.legend(frameon=False)\n", " plt.xlabel(xlabel)\n", " plt.ylabel(\"Number of particles\")\n", "\n", @@ -633,16 +649,21 @@ "metadata": {}, "outputs": [], "source": [ - "def loss_plot(train, test, margin=0.05):\n", + "def loss_plot(train, test, margin=0.05, smoothing=False):\n", " fig = plt.figure(figsize=(8,4))\n", " ax = plt.axes()\n", - " p0 = plt.plot(train, alpha=0.2)\n", - " p1 = plt.plot(test, alpha=0.2)\n", " \n", - " train_smooth = np.convolve(train, np.ones(5)/5, mode='valid')\n", - " plt.plot(train_smooth, color=p0[0].get_color(), lw=2, label=\"train\")\n", - " test_smooth = np.convolve(test, np.ones(5)/5, mode='valid')\n", - " plt.plot(test_smooth, color=p1[0].get_color(), lw=2, label=\"test\")\n", + " alpha = 0.2 if smoothing else 1.0\n", + " l0 = None if smoothing else \"train\"\n", + " l1 = None if smoothing else \"test\"\n", + " p0 = plt.plot(train, alpha=alpha, label=l0)\n", + " p1 = plt.plot(test, alpha=alpha, label=l1)\n", + " \n", + " if smoothing:\n", + " train_smooth = np.convolve(train, np.ones(5)/5, mode='valid')\n", + " plt.plot(train_smooth, color=p0[0].get_color(), lw=2, label=\"train\")\n", + " test_smooth = np.convolve(test, np.ones(5)/5, mode='valid')\n", + " plt.plot(test_smooth, color=p1[0].get_color(), lw=2, label=\"test\")\n", " \n", " plt.ylim(test[-1]*(1.0-margin), test[-1]*(1.0+margin))\n", " plt.legend(loc=\"best\", frameon=False)\n", @@ -658,7 +679,7 @@ "metadata": {}, "outputs": [], "source": [ - "p0 = loss_plot(history[\"loss\"].values, history[\"val_loss\"].values)\n", + "p0 = loss_plot(history[\"loss\"].values, history[\"val_loss\"].values, margin=0.02)\n", "plt.ylabel(\"Total loss\")\n", "plt.savefig(\"loss.pdf\", bbox_inches=\"tight\")" ] @@ -682,7 +703,7 @@ "metadata": {}, "outputs": [], "source": [ - "p0 = loss_plot(history[\"energy_loss\"].values, history[\"val_energy_loss\"].values, margin=0.05)\n", + "p0 = loss_plot(history[\"energy_loss\"].values, history[\"val_energy_loss\"].values, margin=0.01)\n", "plt.ylabel(\"Energy loss\")\n", "plt.savefig(\"energy_loss.pdf\", bbox_inches=\"tight\")" ] @@ -694,7 +715,7 @@ "metadata": {}, "outputs": [], "source": [ - "p0 = loss_plot(history[\"pt_loss\"].values, history[\"val_pt_loss\"].values, margin=0.1)\n", + "p0 = loss_plot(history[\"pt_loss\"].values, history[\"val_pt_loss\"].values, margin=0.02)\n", "plt.ylabel(\"$p_T$ loss\")\n", "plt.savefig(\"pt_loss.pdf\", bbox_inches=\"tight\")" ] @@ -706,7 +727,7 @@ "metadata": {}, "outputs": [], "source": [ - "p0 = loss_plot(history[\"sin_phi_loss\"].values, history[\"val_sin_phi_loss\"].values, margin=0.01)\n", + "p0 = loss_plot(history[\"sin_phi_loss\"].values, history[\"val_sin_phi_loss\"].values, margin=0.02)\n", "plt.ylabel(\"$\\sin \\phi$ loss\")\n", "plt.savefig(\"sin_phi_loss.pdf\", bbox_inches=\"tight\")" ] @@ -730,7 +751,7 @@ "metadata": {}, "outputs": [], "source": [ - "p0 = loss_plot(history[\"eta_loss\"].values, history[\"val_eta_loss\"].values, margin=0.01)\n", + "p0 = loss_plot(history[\"eta_loss\"].values, history[\"val_eta_loss\"].values, margin=0.005)\n", "plt.ylabel(\"$\\eta$ loss\")\n", "plt.savefig(\"eta_loss.pdf\", bbox_inches=\"tight\")" ] @@ -751,7 +772,9 @@ "cell_type": "code", "execution_count": null, "id": "august-feeding", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "for icls in range(1,8):\n", @@ -807,8 +830,8 @@ "\n", "cms_label(x1=0.18, x2=0.52, y=0.82)\n", "#sample_label(ax, x=0.8, y=1.0)\n", - "#plt.xticks(range(len(y_labels)), y_labels);\n", - "#plt.yticks(range(len(y_labels)), y_labels);\n", + "plt.xticks(range(len(y_labels)), y_labels);\n", + "plt.yticks(range(len(y_labels)), y_labels);\n", "plt.xlabel(\"Predicted PFCandidate\")\n", "plt.ylabel(\"True PFCandidate\")\n", "plt.title(\"MLPF trained on PF\", y=1.03)\n", @@ -830,8 +853,8 @@ "\n", "cms_label(x1=0.18, x2=0.52, y=0.82)\n", "#sample_label(ax, x=0.8, y=1.0)\n", - "#plt.xticks(range(len(y_labels)), y_labels);\n", - "#plt.yticks(range(len(y_labels)), y_labels);\n", + "plt.xticks(range(len(y_labels)), y_labels);\n", + "plt.yticks(range(len(y_labels)), y_labels);\n", "plt.xlabel(\"Predicted PFCandidate\")\n", "plt.ylabel(\"True PFCandidate\")\n", "plt.title(\"MLPF trained on PF\", y=1.03)\n", @@ -858,7 +881,9 @@ "cell_type": "code", "execution_count": null, "id": "expressed-samba", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "for icls in range(1,8):\n", @@ -881,64 +906,116 @@ { "cell_type": "code", "execution_count": null, - "id": "minor-beast", + "id": "paperback-timeline", "metadata": {}, "outputs": [], "source": [ - "fig, axes = plt.subplots(7, 6, figsize=(6*6,7*5))\n", - "\n", - "for axs, icls in zip(axes, range(1,8)): \n", - " axes = axs.flatten()\n", + "def plot_particle_regression(ivar=6, icls=2, particle_label=\"Neutral hadrons\", log=True, minval=-1, maxval=3, norm=matplotlib.colors.LogNorm()):\n", + " plt.figure(figsize=(6,5))\n", + " ax = plt.axes()\n", " \n", - " npred = np.sum(ypred_id == icls, axis=1)\n", - " ncand = np.sum(ycand[:, :, 0] == icls, axis=1)\n", - " ngen = np.sum(ygen[:, :, 0] == icls, axis=1)\n", " \n", - " a = 0.5*min(np.min(npred), np.min(ncand))\n", - " b = 1.5*max(np.max(npred), np.max(ncand))\n", + " bins = np.linspace(minval, maxval, 100)\n", + " msk_both = (ypred_id_f == icls) & (ycand_f[:, 0]==icls)\n", " \n", - " axes[0].scatter(ncand, npred, marker=\".\")\n", + " vals_true = ycand_f[msk_both, ivar]\n", + " vals_pred = ypred_f[msk_both, ivar]\n", " \n", - " axes[0].set_xlim(a,b)\n", - " axes[0].set_ylim(a,b)\n", - " axes[0].plot([a,b],[a,b], color=\"black\", ls=\"--\")\n", - " axes[0].set_title(pid_names[icls])\n", - " axes[0].set_xlabel(\"number of PFCandidates\")\n", - " axes[0].set_ylabel(\"number of MLPFCandidates\")\n", + " if log:\n", + " vals_true = np.log10(vals_true)\n", + " vals_pred = np.log10(vals_pred)\n", " \n", - " msk_both = (ycand_f[:, 0]==icls) & (ypred_id_f==icls)\n", - " print(icls, np.sum(msk_both))\n", - "\n", - " for ivar, ax in zip([2,3,4,5,6], axes[1:]):\n", - " \n", - "# hist = np.histogram2d(\n", - "# ycand_f[msk_both, ivar],\n", - "# ypred_f[msk_both, ivar], bins=(bins[ivar], bins[ivar])\n", - "# )\n", - "# norm = matplotlib.colors.Normalize(vmin=0, vmax=max(10, np.max(hist[0])))\n", - "# if ivar == 2 or ivar == 6:\n", - "# norm = matplotlib.colors.LogNorm(vmin=1, vmax=max(10, 10*np.max(hist[0])))\n", - "# hep.hist2dplot(\n", - "# hist, cmap=\"Blues\",\n", - "# norm=norm,\n", - "# ax=ax\n", - "# )\n", - " ax.scatter(ycand_f[msk_both, ivar], ypred_f[msk_both, ivar], marker=\".\", alpha=0.2)\n", - " ax.plot([bins[ivar][0],bins[ivar][-1]], [bins[ivar][0], bins[ivar][-1]], color=\"black\", ls=\"--\")\n", - " ax.set_title(\"pred. {}, {}\".format(pid_names[icls], var_names[ivar]))\n", - " ax.set_xlabel(\"true value (PFCandidate)\")\n", - " ax.set_ylabel(\"reconstructed value (MLPF)\")\n", - "plt.tight_layout()\n", - "plt.savefig(\"full_performance.png\", bbox_inches=\"tight\", dpi=400)" + " plt.hist2d(\n", + " vals_true,\n", + " vals_pred,\n", + " bins=(bins, bins),\n", + " cmap=\"Blues\", norm=norm\n", + " )\n", + " \n", + " plt.colorbar()\n", + " plt.plot([minval, maxval], [minval, maxval], color=\"black\", ls=\"--\", lw=0.5)\n", + " plt.xlim(minval, maxval)\n", + " plt.ylim(minval, maxval)\n", + " cms_label(x1=0.2, x2=0.48)\n", + " plt.text(0.02, 0.95, particle_label, transform=ax.transAxes)\n", + " ax.set_xticks(ax.get_yticks());" ] }, { "cell_type": "code", "execution_count": null, - "id": "scheduled-worst", + "id": "ecological-toner", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "plot_particle_regression(ivar=6, icls=1, particle_label=\"Charged hadrons\")\n", + "plt.xlabel(\"PFCandidate $\\log_{10}$ E/GeV\")\n", + "plt.ylabel(\"MLPFCandidate $\\log_{10}$ E/GeV\")\n", + "plt.savefig(\"energy_corr_cls1.pdf\", bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "transparent-remedy", + "metadata": {}, + "outputs": [], + "source": [ + "plot_particle_regression(ivar=6, icls=2, particle_label=\"Neutral hadrons\")\n", + "plt.xlabel(\"PFCandidate $\\log_{10}$ E/GeV\")\n", + "plt.ylabel(\"MLPFCandidate $\\log_{10}$ E/GeV\")\n", + "plt.savefig(\"energy_corr_cls2.pdf\", bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "promotional-checklist", + "metadata": {}, + "outputs": [], + "source": [ + "plot_particle_regression(ivar=3, icls=1, particle_label=\"Charged hadrons\", log=False, minval=-4, maxval=4, norm=None)\n", + "plt.xlabel(\"PFCandidate $\\eta$\")\n", + "plt.ylabel(\"MLPFCandidate $\\eta$\")\n", + "plt.savefig(\"eta_corr_cls1.pdf\", bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "suitable-kansas", + "metadata": {}, + "outputs": [], + "source": [ + "plot_particle_regression(ivar=3, icls=2, particle_label=\"Neutral hadrons\", log=False, minval=-4, maxval=4, norm=None)\n", + "plt.xlabel(\"PFCandidate $\\eta$\")\n", + "plt.ylabel(\"MLPFCandidate $\\eta$\")\n", + "plt.savefig(\"eta_corr_cls2.pdf\", bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "restricted-million", + "metadata": {}, + "outputs": [], + "source": [ + "plot_particle_regression(ivar=6, icls=3, particle_label=\"HF\", minval=0.0, maxval=4, norm=None)\n", + "plt.xlabel(\"PFCandidate $\\log_{10}$ E/GeV\")\n", + "plt.ylabel(\"MLPFCandidate $\\log_{10}$ E/GeV\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "raising-first", + "metadata": {}, + "outputs": [], + "source": [ + "plot_particle_regression(ivar=6, icls=4, particle_label=\"HF\", minval=0.0, maxval=4, norm=None)\n", + "plt.xlabel(\"PFCandidate $\\log_{10}$ E/GeV\")\n", + "plt.ylabel(\"MLPFCandidate $\\log_{10}$ E/GeV\")" + ] } ], "metadata": { diff --git a/parameters/cms.yaml b/parameters/cms.yaml index 61e0b6ba3..818762c9c 100644 --- a/parameters/cms.yaml +++ b/parameters/cms.yaml @@ -59,7 +59,7 @@ setup: batch_size: 5 num_events_train: 80000 num_events_test: 10000 - num_epochs: 100 + num_epochs: 50 num_val_files: 10 dtype: float32 trainable: @@ -78,18 +78,13 @@ sample_weights: parameters: model: gnn_dense input_encoding: cms - do_node_encoding: no - hidden_dim: 128 - dropout: 0.0 - activation: gelu combined_graph_layer: - do_lsh: yes - bin_size: 160 + bin_size: 640 max_num_bins: 100 distance_dim: 128 layernorm: no dropout: 0.0 - dist_activation: gelu + dist_activation: linear kernel: type: NodePairGaussianKernel dist_mult: 0.1 @@ -97,7 +92,7 @@ parameters: num_node_messages: 1 node_message: type: GHConvDense - output_dim: 128 + output_dim: 256 activation: gelu normalize_degrees: yes hidden_dim: 128 @@ -127,7 +122,7 @@ parameters: phi_hidden_dim: 256 energy_hidden_dim: 256 - id_num_layers: 2 + id_num_layers: 3 charge_num_layers: 2 pt_num_layers: 2 eta_num_layers: 2 @@ -145,5 +140,5 @@ timing: exponentialdecay: decay_steps: 1000 - decay_rate: 0.98 - staircase: yes + decay_rate: 0.99 + staircase: yes \ No newline at end of file diff --git a/parameters/test-gnn/cms-lsh-mpnn.yaml b/parameters/test-gnn/cms-lsh-mpnn.yaml index 6b4ccc9ff..291cd98a5 100644 --- a/parameters/test-gnn/cms-lsh-mpnn.yaml +++ b/parameters/test-gnn/cms-lsh-mpnn.yaml @@ -102,7 +102,7 @@ parameters: hidden_dim: 128 num_layers: 2 activation: gelu - aggregation_direction: dst + aggregation_direction: src num_node_messages: 1 hidden_dim: 256 activation: gelu From e0e8fee5bc3cae996569ccdbba732a03bec2c96f Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 2 Sep 2021 17:20:36 +0300 Subject: [PATCH 12/17] up Former-commit-id: df8fc6fb8f34bb9f13cd0778916c233ab3daaa73 --- parameters/cms-dev.yaml | 6 +++--- parameters/cms.yaml | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/parameters/cms-dev.yaml b/parameters/cms-dev.yaml index b9d603c52..0a1b2f203 100644 --- a/parameters/cms-dev.yaml +++ b/parameters/cms-dev.yaml @@ -56,7 +56,7 @@ setup: weights: weights_config: lr: 1e-3 - batch_size: 4 + batch_size: 2 num_events_train: 80000 num_events_test: 10000 num_epochs: 100 @@ -84,8 +84,8 @@ parameters: activation: gelu combined_graph_layer: do_lsh: yes - bin_size: 32 - max_num_bins: 500 + bin_size: 128 + max_num_bins: 100 distance_dim: 128 layernorm: no dropout: 0.0 diff --git a/parameters/cms.yaml b/parameters/cms.yaml index 818762c9c..fedd887cf 100644 --- a/parameters/cms.yaml +++ b/parameters/cms.yaml @@ -82,7 +82,7 @@ parameters: bin_size: 640 max_num_bins: 100 distance_dim: 128 - layernorm: no + layernorm: yes dropout: 0.0 dist_activation: linear kernel: @@ -92,7 +92,7 @@ parameters: num_node_messages: 1 node_message: type: GHConvDense - output_dim: 256 + output_dim: 128 activation: gelu normalize_degrees: yes hidden_dim: 128 @@ -115,7 +115,7 @@ parameters: phi_dim_decrease: yes energy_dim_decrease: yes - id_hidden_dim: 256 + id_hidden_dim: 512 charge_hidden_dim: 256 pt_hidden_dim: 256 eta_hidden_dim: 256 @@ -128,7 +128,7 @@ parameters: eta_num_layers: 2 phi_num_layers: 2 energy_num_layers: 2 - layernorm: no + layernorm: yes mask_reg_cls0: no skip_connection: yes @@ -139,6 +139,6 @@ timing: num_iter: 3 exponentialdecay: - decay_steps: 1000 + decay_steps: 2000 decay_rate: 0.99 staircase: yes \ No newline at end of file From d2cadf742fc38903e5ddb4ca1a6bd1e959570c27 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Fri, 3 Sep 2021 13:55:30 +0300 Subject: [PATCH 13/17] add charge to clic pf candidate Former-commit-id: a7985ea8c8c6895b728bcc96aaa5970c1540d810 --- clic/dumper.py | 6 +- notebooks/clic.ipynb | 192 ++++++++++++++++++++++++++++++------------- 2 files changed, 139 insertions(+), 59 deletions(-) diff --git a/clic/dumper.py b/clic/dumper.py index 15307f354..6e1b523c4 100644 --- a/clic/dumper.py +++ b/clic/dumper.py @@ -61,7 +61,8 @@ def pfParticleToDict(par): "px": mom[0], "py": mom[1], "pz": mom[2], - "energy": par.getEnergy() + "energy": par.getEnergy(), + "charge": par.getCharge() } return vec @@ -210,6 +211,7 @@ def caloHitToDict(par, calohit_to_cluster, genparticle_dict, calohit_recotosim): nPF=colPF.getNumberOfElements() nCl=colCl.getNumberOfElements() nTr=colTr.getNumberOfElements() + nHit=simTrackHits.getNumberOfElements() nHCB=colHCB.getNumberOfElements() nHCE=colHCE.getNumberOfElements() nECB=colECB.getNumberOfElements() @@ -223,7 +225,7 @@ def caloHitToDict(par, calohit_to_cluster, genparticle_dict, calohit_recotosim): assert(not (recohit in calohit_recotosim)) calohit_recotosim[recohit] = simhit - print "Event %d, nGen=%d, nPF=%d, nClusters=%d, nTracks=%d, nHCAL=%d, nECAL=%d" % (nEvent, nMc, nPF, nCl, nTr, nHCB+nHCE, nECB+nECE) + print "Event %d, nGen=%d, nPF=%d, nClusters=%d, nTracks=%d, nHCAL=%d, nECAL=%d, nHits=%d" % (nEvent, nMc, nPF, nCl, nTr, nHCB+nHCE, nECB+nECE, nHit) genparticles = [] genparticle_dict = {} diff --git a/notebooks/clic.ipynb b/notebooks/clic.ipynb index 8d52cbf8a..3f79ffd8e 100644 --- a/notebooks/clic.ipynb +++ b/notebooks/clic.ipynb @@ -22,7 +22,8 @@ "metadata": {}, "outputs": [], "source": [ - "data = json.load(bz2.BZ2File(\"/home/joosep/Downloads/pythia6_ttbar_0001_pandora.json.bz2\", \"r\"))" + "#data = json.load(bz2.BZ2File(\"/home/joosep/Downloads/pythia6_ttbar_0001_pandora.json.bz2\", \"r\"))\n", + "data = json.load(bz2.BZ2File(\"/home/joosep/particleflow/data/clic/gev380ee_pythia6_ttbar_rfull201/raw/pythia6_ttbar_0001_pandora_0.json.bz2\", \"r\"))" ] }, { @@ -55,7 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "iev = 28\n", + "iev = 0\n", "df_gen = pandas.DataFrame(data[iev][\"genparticles\"])\n", "\n", "df_hit = pandas.DataFrame(data[iev][\"track_hits\"])\n", @@ -71,6 +72,46 @@ "df_tr[\"pz\"] = df_tr[\"tan_lambda\"]*df_tr[\"pt\"]" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6cc1ff5", + "metadata": {}, + "outputs": [], + "source": [ + "df_hit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9089cfae", + "metadata": {}, + "outputs": [], + "source": [ + "df_ecal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e01940", + "metadata": {}, + "outputs": [], + "source": [ + "df_hcal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efc9be54", + "metadata": {}, + "outputs": [], + "source": [ + "df_gen" + ] + }, { "cell_type": "code", "execution_count": null, @@ -150,6 +191,9 @@ " if filter_gp(gp):\n", " dg.add_node((\"gp\", gp))\n", " gps.add(gp)\n", + " \n", + " #the track is added to the genparticle with a very high weight\n", + " #because we always want to associate the genparticle to a track if it's possible\n", " dg.add_edge((\"gp\", gp), (\"tr\", itr), weight=9999.0)\n", "\n", " \n", @@ -157,22 +201,26 @@ "gps = set(gps)\n", "\n", "#now loop over all the genparticles\n", - "#for each genparticle, find the neighboring reco elements (clusters and tracks)\n", - "#sort the neighbors by the edge weight (deposited energy)\n", - "#for each genparticle, choose the closest neighbor as the \"key\" reco element\n", - "#remove the reco element from the list\n", "pairs = {}\n", "for gp in gps:\n", " gp_node = (\"gp\", gp)\n", + "\n", + " #find the neighboring reco elements (clusters and tracks)\n", " neighbors = list(dg.neighbors(gp_node))\n", " weights = [dg.edges[gp_node, n][\"weight\"] for n in neighbors]\n", " nw = zip(neighbors, weights)\n", + " \n", + " #sort the neighbors by the edge weight (deposited energy)\n", " nw = sorted(nw, key=lambda x: x[1], reverse=True)\n", " reco_obj = None\n", " if len(nw)>0:\n", + " #choose the closest neighbor as the \"key\" reco element\n", " reco_obj = nw[0][0]\n", - " dg.remove_node(reco_obj)\n", " \n", + " #remove the reco element from the list, so it can't be associated to anything else\n", + " dg.remove_node(reco_obj)\n", + " \n", + " #this genparticle had a unique reco element\n", " if reco_obj:\n", " pf_obj = None\n", " if reco_obj and reco_obj in reco_to_pf:\n", @@ -180,8 +228,11 @@ "\n", " assert(not (reco_obj in pairs))\n", " pairs[reco_obj] = (gp, pf_obj)\n", + " \n", + " #this is a case where a genparticle did not have a key reco element, but instead was smeared between others\n", " else:\n", - " print(\"genparticle {} is merged and cannot be reconstructed\".format(gp))" + " print(\"genparticle {} is merged and cannot be reconstructed\".format(gp))\n", + " print(df_gen.loc[gp])" ] }, { @@ -201,27 +252,27 @@ "metadata": {}, "outputs": [], "source": [ - "def track_as_array(itr):\n", + "def track_as_array(df_tr, itr):\n", " row = df_tr.loc[itr]\n", " return [0, row[\"px\"], row[\"py\"], row[\"pz\"], row[\"nhits\"], row[\"d0\"], row[\"z0\"]]\n", "\n", - "def cluster_as_array(icl):\n", + "def cluster_as_array(df_cl, icl):\n", " row = df_cl.loc[icl]\n", " return [1, row[\"x\"], row[\"y\"], row[\"z\"], row[\"nhits_ecal\"], row[\"nhits_hcal\"], 0.0]\n", "\n", - "def gen_as_array(igen):\n", + "def gen_as_array(df_gen, igen):\n", " if igen:\n", " row = df_gen.loc[igen]\n", - " return np.array([row[\"pdgid\"], row[\"px\"], row[\"py\"], row[\"pz\"], row[\"energy\"]])\n", + " return np.array([abs(row[\"pdgid\"]), row[\"charge\"], row[\"px\"], row[\"py\"], row[\"pz\"], row[\"energy\"]])\n", " else:\n", - " return np.zeros(5)\n", + " return np.zeros(6)\n", " \n", - "def pf_as_array(igen):\n", + "def pf_as_array(df_pfs, igen):\n", " if igen:\n", " row = df_pfs.loc[igen]\n", - " return np.array([row[\"type\"], row[\"px\"], row[\"py\"], row[\"pz\"], row[\"energy\"]])\n", + " return np.array([abs(row[\"type\"]), row[\"charge\"], row[\"px\"], row[\"py\"], row[\"pz\"], row[\"energy\"]])\n", " else:\n", - " return np.zeros(5)" + " return np.zeros(6)" ] }, { @@ -231,37 +282,42 @@ "metadata": {}, "outputs": [], "source": [ - "Xs = []\n", - "ys_gen = []\n", - "ys_cand = []\n", - "for itr in range(len(df_tr)):\n", - " Xs.append(track_as_array(itr))\n", + "def flatten_event(df_tr, df_cl, df_gen, df_pfs, pairs):\n", + " Xs = []\n", + " ys_gen = []\n", + " ys_cand = []\n", " \n", - " k = (\"tr\", itr)\n", - " gp = None\n", - " rp = None\n", - " if k in pairs:\n", - " gp = pairs[k][0]\n", - " rp = pairs[k][1]\n", - " ys_gen.append(gen_as_array(gp))\n", - " ys_cand.append(pf_as_array(rp))\n", + " #find all track-associated particles\n", + " for itr in range(len(df_tr)):\n", + " Xs.append(track_as_array(df_tr, itr))\n", "\n", + " k = (\"tr\", itr)\n", + " gp = None\n", + " rp = None\n", + " if k in pairs:\n", + " gp = pairs[k][0]\n", + " rp = pairs[k][1]\n", + " ys_gen.append(gen_as_array(df_gen, gp))\n", + " ys_cand.append(pf_as_array(df_pfs, rp))\n", " \n", - "for icl in range(len(df_cl)):\n", - " Xs.append(cluster_as_array(icl))\n", - " \n", - " k = (\"cl\", icl)\n", - " gp = None\n", - " rp = None\n", - " if k in pairs:\n", - " gp = pairs[k][0]\n", - " rp = pairs[k][1]\n", - " ys_gen.append(gen_as_array(gp))\n", - " ys_cand.append(pf_as_array(rp))\n", + " #find all cluster-associated particles\n", + " for icl in range(len(df_cl)):\n", + " Xs.append(cluster_as_array(df_cl, icl))\n", + "\n", + " k = (\"cl\", icl)\n", + " gp = None\n", + " rp = None\n", + " if k in pairs:\n", + " gp = pairs[k][0]\n", + " rp = pairs[k][1]\n", + " ys_gen.append(gen_as_array(df_gen, gp))\n", + " ys_cand.append(pf_as_array(df_pfs, rp))\n", + "\n", + " Xs = np.stack(Xs, axis=-1).T\n", + " ys_gen = np.stack(ys_gen, axis=-1).T\n", + " ys_cand = np.stack(ys_cand, axis=-1).T\n", " \n", - "Xs = np.stack(Xs, axis=-1).T\n", - "ys_gen = np.stack(ys_gen, axis=-1).T\n", - "ys_cand = np.stack(ys_cand, axis=-1).T" + " return Xs, ys_gen, ys_cand" ] }, { @@ -271,58 +327,80 @@ "metadata": {}, "outputs": [], "source": [ - "len(Xs)\n", - "i = 106" + "Xs, ys_gen, ys_cand = flatten_event(df_tr, df_cl, df_gen, df_pfs, pairs)\n", + "len(Xs), len(ys_gen), len(ys_cand)" ] }, { "cell_type": "code", "execution_count": null, - "id": "mexican-immune", + "id": "c022fce0", "metadata": {}, "outputs": [], "source": [ - "Xs[i]" + "import sklearn\n", + "import sklearn.metrics" ] }, { "cell_type": "code", "execution_count": null, - "id": "fossil-cornell", + "id": "16dde9e2", "metadata": {}, "outputs": [], "source": [ - "ys_gen[i]" + "np.unique(ys_gen[:, 0])" ] }, { "cell_type": "code", "execution_count": null, - "id": "medium-armor", + "id": "012ef075", "metadata": {}, "outputs": [], "source": [ - "ys_cand[i]" + "np.unique(ys_cand[:, 0])" ] }, { "cell_type": "code", "execution_count": null, - "id": "confident-publisher", + "id": "e9c5b8cd", "metadata": {}, "outputs": [], "source": [ - "ys_gen[:, 0]" + "labels = [0, 13, 11, 22, 130, 211, 321, 2112, 2212]\n", + "labels_text = {\n", + " 0: \"none\",\n", + " 13: \"mu\",\n", + " 11: \"el\",\n", + " 22: \"$\\gamma$\",\n", + " 130: \"$K^0_L$\",\n", + " 211: \"$\\pi^\\pm$\",\n", + " 321: \"$K^+$\",\n", + " 2112: \"n\",\n", + " 2212: \"p\"\n", + "}\n", + "cm = sklearn.metrics.confusion_matrix(\n", + " ys_gen[:, 0],\n", + " ys_cand[:, 0],\n", + " labels=labels,\n", + " normalize=\"true\"\n", + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "cardiovascular-majority", + "id": "8817f3e5", "metadata": {}, "outputs": [], "source": [ - "ys_cand[:, 0]" + "plt.imshow(cm, cmap=\"Blues\")\n", + "plt.xticks(range(len(labels)), [labels_text[l] for l in labels], rotation=90);\n", + "plt.yticks(range(len(labels)), [labels_text[l] for l in labels]);\n", + "plt.xlabel(\"reco\")\n", + "plt.ylabel(\"gen\")" ] }, { @@ -461,7 +539,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -475,7 +553,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" } }, "nbformat": 4, From 42119faca677861e1e9013cf5205b3d5081a2f5b Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Fri, 3 Sep 2021 13:56:46 +0300 Subject: [PATCH 14/17] added cms gen config Former-commit-id: f2d75e0f602360caad2c0c9cd822005f379983d8 --- mlpf/pipeline.py | 7 +- notebooks/cms-mlpf.ipynb | 106 +++++++++++++++++++++++++++- parameters/cms-gen.yaml | 144 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 253 insertions(+), 4 deletions(-) create mode 100644 parameters/cms-gen.yaml diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py index 5bddd912b..c7859d4cb 100644 --- a/mlpf/pipeline.py +++ b/mlpf/pipeline.py @@ -190,9 +190,10 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq, configure_model_weights(model, config["setup"]["trainable"]) model(tf.cast(X_val[:1], model_dtype)) - print("trainable weights") - for w in model.trainable_weights: - print(w.name) + print("model weights") + tw_names = [m.name for m in model.trainable_weights] + for w in model.weights: + print("layer={} trainable={} shape={} num_weights={}".format(w.name, w.name in tw_names, w.shape, np.prod(w.shape))) loss_dict, loss_weights = get_loss_dict(config) model.compile( diff --git a/notebooks/cms-mlpf.ipynb b/notebooks/cms-mlpf.ipynb index bfdc07ec7..d2b76df1e 100644 --- a/notebooks/cms-mlpf.ipynb +++ b/notebooks/cms-mlpf.ipynb @@ -132,7 +132,8 @@ "metadata": {}, "outputs": [], "source": [ - "path = \"../experiments/cms-dev_20210831_225815_541048.gpu0.local/evaluation/\"" + "#path = \"../experiments/cms-dev_20210831_225815_541048.gpu0.local/evaluation/\"\n", + "path = \"../experiments/cms-gen_20210903_114315_805349.joosep-desktop-work/evaluation/\"" ] }, { @@ -1016,6 +1017,109 @@ "plt.xlabel(\"PFCandidate $\\log_{10}$ E/GeV\")\n", "plt.ylabel(\"MLPFCandidate $\\log_{10}$ E/GeV\")" ] + }, + { + "cell_type": "markdown", + "id": "4a3ab75a", + "metadata": {}, + "source": [ + "## Gen level" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "700c7700", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "for icls in range(1,8):\n", + " npred = np.sum(ypred_id == icls, axis=1)\n", + " ncand = np.sum(ycand[:, :, 0] == icls, axis=1)\n", + " ngen = np.sum(ygen[:, :, 0] == icls, axis=1)\n", + " plt.figure(figsize=(6,6))\n", + " plt.scatter(ngen, ncand, marker=\".\", alpha=0.5, label=\"PF\")\n", + " plt.scatter(ngen, npred, marker=\".\", alpha=0.5, label=\"MLPF\")\n", + " plt.legend(loc=\"best\", frameon=False)\n", + " a = 0.5*min(np.min(ngen), np.min(ngen))\n", + " b = 2*max(np.max(ngen), np.max(ngen))\n", + " plt.xlim(a,b)\n", + " plt.ylim(a,b)\n", + " plt.plot([a,b],[a,b], color=\"black\", ls=\"--\")\n", + " plt.title(pid_names_long[icls],y=1.05)\n", + " plt.xlabel(\"number of gen particles\")\n", + " plt.ylabel(\"number of PFCandidates\")\n", + " cms_label(x2=0.6, y=0.89)\n", + "# plt.savefig(\"num_cls{}.pdf\".format(icls))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5661ff16", + "metadata": {}, + "outputs": [], + "source": [ + "bins = np.linspace(0,500,100)\n", + "mplhep.histplot(np.histogram(ygen_f[ygen_f[:, 0]==2, 6], bins=bins))\n", + "mplhep.histplot(np.histogram(ycand_f[ycand_f[:, 0]==2, 6], bins=bins))\n", + "mplhep.histplot(np.histogram(ypred_f[ypred_f[:, 0]==2, 6], bins=bins))\n", + "plt.yscale(\"log\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82f29ef8", + "metadata": {}, + "outputs": [], + "source": [ + "icls = 4\n", + "bins = np.linspace(-200,200,100)\n", + "particle_label = \"neutral hadrons\"\n", + "\n", + "msk_cand = (ygen_f[:, 0]==icls) & (ycand_f[:, 0]==icls)\n", + "msk_pred = (ygen_f[:, 0]==icls) & (ypred_f[:, 0]==icls)\n", + "\n", + "vals_gen1 = ygen_f[msk_cand, 6]\n", + "vals_gen2 = ygen_f[msk_pred, 6]\n", + "vals_cand = ycand_f[msk_cand, 6]\n", + "vals_pred = ypred_f[msk_pred, 6]\n", + "\n", + "res_cand = vals_gen1 - vals_cand\n", + "res_pred = vals_gen2 - vals_pred\n", + "\n", + "plt.figure(figsize=(5,5))\n", + "ax = plt.axes()\n", + "plt.hist(\n", + " res_cand,\n", + " bins=bins, histtype=\"step\", lw=2,\n", + " label=\"PF, $\\mu={:.2f}, \\sigma={:.2f}$\".format(np.mean(res_cand), np.std(res_cand)));\n", + "\n", + "plt.hist(res_pred,\n", + " bins=bins,\n", + " histtype=\"step\", lw=2,\n", + " label=\"MLPF, $\\mu={:.2f}, \\sigma={:.2f}$\".format(np.mean(res_pred), np.std(res_pred))\n", + ");\n", + "\n", + "plt.yscale(\"log\")\n", + "plt.ylabel(\"Number of particles / bin\")\n", + "cms_label(x1=0.21, x2=0.55)\n", + "plt.ylim(top=10**9)\n", + "plt.text(0.02, 0.95, particle_label, transform=ax.transAxes)\n", + "plt.xlabel(\"particle $E_{\\mathrm{gen}} - E_{\\mathrm{reco}}$ [GeV]\")\n", + "plt.legend(frameon=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "338f50e9", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/parameters/cms-gen.yaml b/parameters/cms-gen.yaml new file mode 100644 index 000000000..f7df7746b --- /dev/null +++ b/parameters/cms-gen.yaml @@ -0,0 +1,144 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: gen + num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + classification_loss_coef: 1.0 + charge_loss_coef: 0.01 + pt_loss_coef: 0.0001 + eta_loss_coef: 100.0 + sin_phi_loss_coef: 10.0 + cos_phi_loss_coef: 10.0 + energy_loss_coef: 0.0001 + raw_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl* + processed_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_gen/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl* + energy_loss: + type: Huber + pt_loss: + type: Huber + sin_phi_loss: + type: Huber + delta: 0.1 + cos_phi_loss: + type: Huber + delta: 0.1 + eta_loss: + type: Huber + delta: 0.1 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: + lr: 1e-3 + batch_size: 4 + num_events_train: 80000 + num_events_test: 10000 + num_epochs: 50 + num_val_files: 10 + dtype: float32 + trainable: classification + classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + input_encoding: cms + combined_graph_layer: + bin_size: 640 + max_num_bins: 100 + distance_dim: 64 + layernorm: no + dropout: 0.0 + dist_activation: linear + kernel: + type: NodePairGaussianKernel + dist_mult: 0.1 + clip_value_low: 0.0 + num_node_messages: 2 + node_message: + type: GHConvDense + output_dim: 512 + activation: gelu + normalize_degrees: yes + hidden_dim: 512 + activation: gelu + num_graph_layers_common: 2 + num_graph_layers_energy: 2 + output_decoding: + activation: gelu + regression_use_classification: yes + dropout: 0.0 + + pt_skip_gate: no + eta_skip_gate: yes + phi_skip_gate: yes + + id_dim_decrease: yes + charge_dim_decrease: yes + pt_dim_decrease: yes + eta_dim_decrease: yes + phi_dim_decrease: yes + energy_dim_decrease: yes + + id_hidden_dim: 512 + charge_hidden_dim: 256 + pt_hidden_dim: 256 + eta_hidden_dim: 256 + phi_hidden_dim: 256 + energy_hidden_dim: 256 + + id_num_layers: 3 + charge_num_layers: 2 + pt_num_layers: 2 + eta_num_layers: 2 + phi_num_layers: 2 + energy_num_layers: 2 + layernorm: no + mask_reg_cls0: no + + skip_connection: yes + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 2000 + decay_rate: 0.99 + staircase: yes From 5370d8d487bc6a29946c80ee66943e8d5df63624 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Fri, 3 Sep 2021 16:00:58 +0300 Subject: [PATCH 15/17] fix bin size Former-commit-id: 1519195be4e246d939f72e2b582faa786e6bc12b --- scripts/test_load_tfmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test_load_tfmodel.py b/scripts/test_load_tfmodel.py index 76020bc3d..2e30b6346 100644 --- a/scripts/test_load_tfmodel.py +++ b/scripts/test_load_tfmodel.py @@ -2,7 +2,7 @@ import sys import numpy as np -bin_size = 160 +bin_size = 640 num_features = 15 def load_graph(frozen_graph_filename): From d744d1311500bc2980789d3ef66531d30925a97b Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Fri, 3 Sep 2021 16:01:03 +0300 Subject: [PATCH 16/17] fixes for gun sample training Former-commit-id: 47ae9a46aef6d16e4b6619b6bc4a859967e6d5ec --- mlpf/pipeline.py | 29 +++++++++++++++++++---------- mlpf/tfmodel/model.py | 5 ++--- mlpf/tfmodel/model_setup.py | 1 - parameters/delphes.yaml | 1 + 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py index 348666f71..2daa8d642 100644 --- a/mlpf/pipeline.py +++ b/mlpf/pipeline.py @@ -149,6 +149,11 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq, prefix += customize + "_" config = customization_functions[customize](config) + if recreate or (weights is None): + outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node()) + else: + outdir = str(Path(weights).parent) + # Decide tf.distribute.strategy depending on number of available GPUs strategy, maybe_global_batch_size = get_strategy(global_batch_size) if "CPU" not in strategy.extended.worker_devices[0]: @@ -166,10 +171,6 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq, X_val, ygen_val, ycand_val = prepare_val_data(config, dataset_def, single_file=False) - if recreate or (weights is None): - outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node()) - else: - outdir = str(Path(weights).parent) if experiment: experiment.set_name(outdir) experiment.log_code("mlpf/tfmodel/model.py") @@ -196,7 +197,11 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq, # Run model once to build the layers print(X_val.shape) - model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"])) + + if config["tensorflow"]["eager"]: + model(X_val[:1]) + else: + model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"])) initial_epoch = 0 if weights: @@ -402,16 +407,19 @@ def find_lr(config, outdir, figname, logscale): def customize_gun_sample(config): - config["dataset"]["padded_num_elem_size"] = 640 + + #FIXME: must be at least 2x bin_size + config["dataset"]["padded_num_elem_size"] = 1280 + config["dataset"]["processed_path"] = "data/SinglePiFlatPt0p7To10_cfi/tfr_cand/*.tfrecords" - config["dataset"]["raw_path"] = "data/SinglePiFlatPt0p7To10_cfi/raw/*.pkl.bz2" + config["dataset"]["raw_path"] = "data/SinglePiFlatPt0p7To10_cfi/raw/*.pkl*" config["dataset"]["classification_loss_coef"] = 0.0 config["dataset"]["charge_loss_coef"] = 0.0 config["dataset"]["eta_loss_coef"] = 0.0 config["dataset"]["sin_phi_loss_coef"] = 0.0 config["dataset"]["cos_phi_loss_coef"] = 0.0 - config["setup"]["trainable"] = "ffn_energy" - config["setup"]["batch_size"] = 10*config["setup"]["batch_size"] + config["setup"]["trainable"] = "regression" + config["setup"]["batch_size"] = 20*config["setup"]["batch_size"] return config customization_functions = { @@ -472,6 +480,7 @@ def hypertune(config, outdir, ntrain, ntest, recreate): config["dataset"]["num_output_classes"], dataset_def, ) + callbacks.append(optim_callbacks) callbacks.append(tf.keras.callbacks.EarlyStopping(patience=20, monitor='val_loss')) @@ -487,7 +496,7 @@ def hypertune(config, outdir, ntrain, ntest, recreate): #callbacks=[tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss')] callbacks=callbacks, ) - print("Hyperparamter search complete.") + print("Hyperparameter search complete.") shutil.copy(config_file_path, outdir + "/config.yaml") # Copy the config file to the train dir for later reference tuner.results_summary() diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index 609952079..7bf0db3ae 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -37,8 +37,6 @@ def pairwise_gaussian_dist(A, B): def pairwise_learnable_dist(A, B, ffn, training=False): shp = tf.shape(A) - # tf.print("shp", shp) - # import pdb;pdb.set_trace() #stack node feature vectors of src[i], dst[j] into a matrix res[i,j] = (src[i], dst[j]) mg = tf.meshgrid(tf.range(shp[0]), tf.range(shp[1]), tf.range(shp[2]), tf.range(shp[2]), indexing="ij") inds1 = tf.stack([mg[0],mg[1],mg[2]], axis=-1) @@ -384,6 +382,7 @@ def call(self, x_msg, x_node, msk, training=False): n_bins = tf.math.floordiv(n_points, self.bin_size) #put each input item into a bin defined by the argmax output across the LSH embedding + #FIXME: this needs n_bins to be at least 2 to work correctly! mul = tf.linalg.matmul(x_msg, self.codebook_random_rotations[:, :n_bins//2]) cmul = tf.concat([mul, -mul], axis=-1) bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk) @@ -661,7 +660,7 @@ def __init__(self, *args, **kwargs): self.hidden_dim = kwargs.pop("hidden_dim") self.do_lsh = kwargs.pop("do_lsh", True) self.activation = getattr(tf.keras.activations, kwargs.pop("activation")) - self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation")) + self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation", "linear")) if self.do_layernorm: self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm") diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index b818406d4..5383587be 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -662,7 +662,6 @@ def configure_model_weights(model, trainable_layers): cg.trainable = False for cg in model.cg_energy: cg.trainable = True - model.output_dec.set_trainable_regression() elif trainable_layers == "classification": for cg in model.cg: diff --git a/parameters/delphes.yaml b/parameters/delphes.yaml index 902cff14b..6d46b54f0 100644 --- a/parameters/delphes.yaml +++ b/parameters/delphes.yaml @@ -98,6 +98,7 @@ parameters: layernorm: no num_node_messages: 1 dropout: 0.0 + dist_activation: linear kernel: type: NodePairGaussianKernel dist_mult: 0.1 From fa7d31f687d495c1e2290c5139b9b5a1a37ccd94 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Fri, 3 Sep 2021 16:06:00 +0300 Subject: [PATCH 17/17] make sure batch size is propagated Former-commit-id: cf7d9743df6228b65b571747ecd5b667d0fda7af --- mlpf/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py index 2daa8d642..0c80dff1b 100644 --- a/mlpf/pipeline.py +++ b/mlpf/pipeline.py @@ -148,6 +148,8 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq, if customize: prefix += customize + "_" config = customization_functions[customize](config) + #FIXME: refactor this + global_batch_size = config["setup"]["batch_size"] if recreate or (weights is None): outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node()) @@ -419,7 +421,7 @@ def customize_gun_sample(config): config["dataset"]["sin_phi_loss_coef"] = 0.0 config["dataset"]["cos_phi_loss_coef"] = 0.0 config["setup"]["trainable"] = "regression" - config["setup"]["batch_size"] = 20*config["setup"]["batch_size"] + config["setup"]["batch_size"] = 10*config["setup"]["batch_size"] return config customization_functions = {