Comparison job for different event losses (#132)
* bin in sin_phi, not phi, to avoid numerical instability (see the sketch after this list)

* use better metrics

* bugfix in hep_tfds, percentiles in notebook

* added logcosh option, cls pt weight, better logging of isr

* fix met in evaluation

* remove history file

* save badly-reconstructed particles
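A minimal sketch (editor's illustration, not part of this commit) of the instability behind the first bullet: phi wraps around at +/-pi, so two nearly identical directions can differ by ~2pi, while sin_phi is continuous across the boundary.

import numpy as np

# two nearly identical directions, on opposite sides of the +/-pi boundary
phi_a, phi_b = 3.14, -3.14
print(abs(phi_a - phi_b))                   # ~6.28: apparently maximal difference
print(abs(np.sin(phi_a) - np.sin(phi_b)))   # ~0.003: correctly small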

Former-commit-id: 3b972dc
jpata authored Aug 29, 2022
1 parent 84e5c23 commit cb66adf
Showing 23 changed files with 1,521 additions and 212 deletions.
2 changes: 1 addition & 1 deletion hep_tfds
Submodule hep_tfds updated 1 file
+3 −2 heptfds/cms_utils.py
4 changes: 2 additions & 2 deletions mlpf/pipeline.py
@@ -1,3 +1,5 @@
from comet_ml import OfflineExperiment, Experiment # isort:skip

try:
import horovod.tensorflow.keras as hvd
except ModuleNotFoundError:
@@ -114,7 +116,6 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
try:
if comet_offline:
print("Using comet-ml OfflineExperiment, saving logs locally.")
from comet_ml import OfflineExperiment

experiment = OfflineExperiment(
project_name="particleflow-tf",
@@ -127,7 +128,6 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
)
else:
print("Using comet-ml Experiment, streaming logs to www.comet.ml.")
from comet_ml import Experiment

experiment = Experiment(
project_name="particleflow-tf",
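A note on the hoisted import: comet_ml generally has to be imported before any ML framework (here TensorFlow/Keras) so that its automatic logging hooks can attach, and the `# isort:skip` marker keeps isort from reordering it below the framework imports. A minimal sketch of the pattern:

# comet_ml must come first; isort is told not to move it
from comet_ml import Experiment, OfflineExperiment  # isort:skip

import tensorflow as tf  # imported only after comet_ml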
6 changes: 4 additions & 2 deletions mlpf/tallinn/cms-mlpf-test.sh
@@ -1,13 +1,15 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gpus 2
#SBATCH --gpus 1
#SBATCH --mem-per-gpu=8G

IMG=/home/software/singularity/tf-2.9.0.simg
cd ~/particleflow

env

#TF training
singularity exec -B /scratch-persistent --nv \
--env PYTHONPATH=hep_tfds \
--env TFDS_DATA_DIR=/scratch-persistent/joosep/tensorflow_datasets \
$IMG python mlpf/pipeline.py train -c $1 --plot-freq 1 --ntrain 1000 --ntest 1000
$IMG python mlpf/pipeline.py train -c $1 --plot-freq 1 --ntrain 5000 --ntest 1000
5 changes: 5 additions & 0 deletions mlpf/tallinn/submit-test-eventloss.sh
@@ -1,3 +1,8 @@
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/baseline.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/baseline-mask_reg_cls0.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/genjet_logcosh_mask_reg_cls0.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/baseline-clspt.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/swd.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/h2d.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/genjet_mse.yaml
sbatch mlpf/tallinn/cms-mlpf-test.sh parameters/test-eventloss/genjet_logcosh.yaml
6 changes: 5 additions & 1 deletion mlpf/tfmodel/datasets/BaseDatasetFactory.py
@@ -51,14 +51,18 @@ def func(data_item):

target = unpack_target(y, num_output_classes, self.cfg)

cls_weights = msk_elems
if self.cfg["dataset"]["cls_weight_by_pt"]:
cls_weights *= target["pt"]

# inputs: X
# targets: dict by classification (cls) and regression feature columns
# weights: dict of weights for each target
return (
X,
target,
{
"cls": msk_elems,
"cls": cls_weights,
"charge": msk_elems * msk_signal,
"pt": msk_elems * msk_signal,
"eta": msk_elems * msk_signal,
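A toy sketch (shapes invented for illustration) of what the new cls_weight_by_pt option does: the per-element classification weight starts from the padding mask and is optionally scaled by the target pt, so misclassifying high-pt particles is penalized more.

import numpy as np

msk_elems = np.array([[1.0, 1.0, 0.0]])   # 1 for real input elements, 0 for padding
target_pt = np.array([[50.0, 2.0, 0.0]])  # per-element target pt

cls_weight_by_pt = True                   # mirrors config["dataset"]["cls_weight_by_pt"]
cls_weights = msk_elems * target_pt if cls_weight_by_pt else msk_elems
print(cls_weights)                        # [[50.  2.  0.]]: the 50 GeV particle dominates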
25 changes: 9 additions & 16 deletions mlpf/tfmodel/model.py
@@ -526,7 +526,6 @@ def __init__(
energy_num_layers=3,
layernorm=False,
mask_reg_cls0=True,
energy_multimodal=True,
event_set_output=False,
**kwargs
):
@@ -539,8 +538,6 @@ def __init__(

self.mask_reg_cls0 = mask_reg_cls0

self.energy_multimodal = energy_multimodal

self.do_layernorm = layernorm
if self.do_layernorm:
self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, name="output_layernorm")
@@ -598,7 +595,7 @@ def __init__(
)

self.ffn_energy = point_wise_feed_forward_network(
num_output_classes if self.energy_multimodal else 1,
1,
energy_hidden_dim,
"ffn_energy",
num_layers=energy_num_layers,
@@ -625,7 +622,6 @@ def call(self, args, training=False):
msk_input_outtype = tf.cast(msk_input, out_id_logits.dtype)

out_id_softmax = tf.nn.softmax(out_id_logits, axis=-1)
out_id_hard_softmax = tf.stop_gradient(tf.nn.softmax(100 * out_id_logits, axis=-1))

out_charge = self.ffn_charge(X_encoded, training=training)
out_charge = out_charge * msk_input_outtype
@@ -665,10 +661,7 @@ def call(self, args, training=False):
pred_energy_corr = pred_energy_corr * msk_input_outtype

# In case of a multimodal prediction, weight the per-class energy predictions by the approximately one-hot vector
if self.energy_multimodal:
pred_energy = orig_energy + tf.reduce_sum(out_id_hard_softmax * pred_energy_corr, axis=-1, keepdims=True)
else:
pred_energy = orig_energy + pred_energy_corr
pred_energy = orig_energy + pred_energy_corr
pred_energy = tf.abs(pred_energy)

# compute pt=E/cosh(eta)
@@ -682,15 +675,15 @@ def call(self, args, training=False):
pred_pt = tf.abs(pred_pt)

# mask the regression outputs for the nodes with a class prediction 0
msk_output = tf.expand_dims(tf.cast(tf.argmax(out_id_hard_softmax, axis=-1) != 0, tf.float32), axis=-1)
sigmoid_turnon = tf.sigmoid(-out_id_logits[..., 0:1])

if self.mask_reg_cls0:
out_charge = out_charge * msk_output
pred_pt = pred_pt * msk_output
pred_eta = pred_eta * msk_output
pred_sin_phi = pred_sin_phi * msk_output
pred_cos_phi = pred_cos_phi * msk_output
pred_energy = pred_energy * msk_output
out_charge = out_charge * sigmoid_turnon
pred_pt = pred_pt * sigmoid_turnon
pred_eta = pred_eta * sigmoid_turnon
pred_sin_phi = pred_sin_phi * sigmoid_turnon
pred_cos_phi = pred_cos_phi * sigmoid_turnon
pred_energy = pred_energy * sigmoid_turnon

ret = {
"cls": out_id_softmax,
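The main change in this file: the hard mask built from argmax over a sharpened softmax (non-differentiable because of the argmax and the stop_gradient) is replaced by a smooth turn-on, sigmoid(-logit_0), which approaches 1 when the "no particle" logit is strongly negative and 0 when it dominates, so the masking of the regression outputs stays differentiable. A standalone sketch with hypothetical logits:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# column 0 is the "no particle" class
out_id_logits = np.array([[-8.0, 3.0],
                          [ 6.0, -1.0]])

hard_mask = (np.argmax(out_id_logits, axis=-1) != 0).astype(np.float32)
soft_mask = sigmoid(-out_id_logits[:, 0])

print(hard_mask)  # [1. 0.]       -- step function, no gradient
print(soft_mask)  # [~1.0 ~0.002] -- smooth, gradients flow into the masked outputs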
131 changes: 98 additions & 33 deletions mlpf/tfmodel/model_setup.py
@@ -13,6 +13,7 @@
import fastjet
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
import tensorflow_addons as tfa
import tf2onnx
@@ -53,22 +54,23 @@ def on_epoch_end(self, epoch, logs=None):


class CustomCallback(tf.keras.callbacks.Callback):
def __init__(self, outpath, dataset, config, plot_freq=1, horovod_enabled=False):
def __init__(self, outpath, dataset, config, plot_freq=1, horovod_enabled=False, comet_experiment=None):
super(CustomCallback, self).__init__()
self.plot_freq = plot_freq
self.dataset = dataset
self.outpath = outpath
self.config = config
self.horovod_enabled = horovod_enabled
self.comet_experiment = comet_experiment

self.writer = tf.summary.create_file_writer(outpath)

def on_epoch_end(self, epoch, logs=None):
if not self.horovod_enabled or hvd.rank() == 0:
epoch_end(self, epoch, logs)
epoch_end(self, epoch, logs, comet_experiment=self.comet_experiment)


def epoch_end(self, epoch, logs):
def epoch_end(self, epoch, logs, comet_experiment=None):
# first epoch is 1, not 0
epoch = epoch + 1

@@ -92,50 +94,106 @@ def epoch_end(self, epoch, logs):
yvals = {}
for fi in glob.glob(str(cp_dir / "*.npz")):
dd = np.load(fi)
os.remove(fi)
keys_in_file = list(dd.keys())
for k in keys_in_file:
if k == "X":
continue
if not (k in yvals):
yvals[k] = []
yvals[k].append(dd[k])
yvals = {k: np.concatenate(v) for k, v in yvals.items()}

gen_px = yvals["gen_pt"] * yvals["gen_cos_phi"]
gen_py = yvals["gen_pt"] * yvals["gen_sin_phi"]
pred_px = yvals["pred_pt"] * yvals["pred_cos_phi"]
pred_py = yvals["pred_pt"] * yvals["pred_sin_phi"]
cand_px = yvals["cand_pt"] * yvals["cand_cos_phi"]
cand_py = yvals["cand_pt"] * yvals["cand_sin_phi"]
# compute the mask of badly-predicted particles and save them to bad.npz
denom = np.maximum(yvals["gen_pt"], yvals["pred_pt"])
ratio = np.abs(yvals["gen_pt"] - yvals["pred_pt"]) / denom
ratio[np.isnan(ratio)] = 0
msk_bad = (ratio > 0.8)[:, :, 0]
yvals_bad = {
k: yvals[k][msk_bad]
for k in yvals.keys()
if (k.startswith("gen_") or k.startswith("pred_") or k.startswith("cand_"))
}
print("Number of bad particles: {}".format(len(yvals_bad["gen_cls"])))
with open("{}/bad.npz".format(str(cp_dir)), "wb") as fi:
np.savez(fi, **yvals_bad)

msk_gen = (np.argmax(yvals["gen_cls"], axis=-1, keepdims=True) != 0).astype(np.float32)
gen_px = yvals["gen_pt"] * yvals["gen_cos_phi"] * msk_gen
gen_py = yvals["gen_pt"] * yvals["gen_sin_phi"] * msk_gen

msk_pred = (np.argmax(yvals["pred_cls"], axis=-1, keepdims=True) != 0).astype(np.float32)
pred_px = yvals["pred_pt"] * yvals["pred_cos_phi"] * msk_pred
pred_py = yvals["pred_pt"] * yvals["pred_sin_phi"] * msk_pred

msk_cand = (np.argmax(yvals["cand_cls"], axis=-1, keepdims=True) != 0).astype(np.float32)
cand_px = yvals["cand_pt"] * yvals["cand_cos_phi"] * msk_cand
cand_py = yvals["cand_pt"] * yvals["cand_sin_phi"] * msk_cand

gen_met = np.sqrt(np.sum(gen_px**2 + gen_py**2, axis=1))
pred_met = np.sqrt(np.sum(pred_px**2 + pred_py**2, axis=1))
cand_met = np.sqrt(np.sum(cand_px**2 + cand_py**2, axis=1))

with self.writer.as_default():
jet_ratio = yvals["jets_pt_gen_to_pred"][:, 1] / yvals["jets_pt_gen_to_pred"][:, 0]
jet_ratio_pred = (yvals["jets_pt_gen_to_pred"][:, 1] - yvals["jets_pt_gen_to_pred"][:, 0]) / yvals[
"jets_pt_gen_to_pred"
][:, 0]
jet_ratio_cand = (yvals["jets_pt_gen_to_cand"][:, 1] - yvals["jets_pt_gen_to_cand"][:, 0]) / yvals[
"jets_pt_gen_to_cand"
][:, 0]
met_ratio_pred = (pred_met[:, 0] - gen_met[:, 0]) / gen_met[:, 0]
met_ratio_cand = (cand_met[:, 0] - gen_met[:, 0]) / gen_met[:, 0]

plt.figure()
b = np.linspace(0, 5, 100)
plt.hist(yvals["jets_pt_gen_to_cand"][:, 1] / yvals["jets_pt_gen_to_cand"][:, 0], bins=b, histtype="step", lw=2)
plt.hist(yvals["jets_pt_gen_to_pred"][:, 1] / yvals["jets_pt_gen_to_pred"][:, 0], bins=b, histtype="step", lw=2)
plt.savefig(str(cp_dir / "jet_res.png"), bbox_inches="tight", dpi=100)
b = np.linspace(-2, 5, 100)
plt.hist(jet_ratio_cand, bins=b, histtype="step", lw=2, label="PF")
plt.hist(jet_ratio_pred, bins=b, histtype="step", lw=2, label="MLPF")
plt.xlabel("jet pT (reco-gen)/gen")
plt.ylabel("number of matched jets")
plt.legend(loc="best")
image_path = str(cp_dir / "jet_res.png")
plt.savefig(image_path, bbox_inches="tight", dpi=100)
plt.clf()
if comet_experiment:
comet_experiment.log_image(image_path, step=epoch - 1)

plt.figure()
b = np.linspace(0, 5, 100)
plt.hist(cand_met / gen_met, bins=b, histtype="step", lw=2)
plt.hist(pred_met / gen_met, bins=b, histtype="step", lw=2)
plt.savefig(str(cp_dir / "met_res.png"), bbox_inches="tight", dpi=100)
b = np.linspace(-1, 1, 100)
plt.hist(met_ratio_cand, bins=b, histtype="step", lw=2, label="PF")
plt.hist(met_ratio_pred, bins=b, histtype="step", lw=2, label="MLPF")
plt.xlabel("MET (reco-gen)/gen")
plt.ylabel("number of events")
plt.legend(loc="best")
image_path = str(cp_dir / "met_res.png")
plt.savefig(image_path, bbox_inches="tight", dpi=100)
plt.clf()

tf.summary.histogram("jet_pt_pred_over_gen", jet_ratio, step=epoch - 1, buckets=None, description=None)
tf.summary.scalar("jet_pt_pred_over_gen_mean", np.mean(jet_ratio), step=epoch - 1, description=None)
tf.summary.scalar("jet_pt_pred_over_gen_std", np.std(jet_ratio), step=epoch - 1, description=None)

tf.summary.histogram("met_pred_over_gen", pred_met / gen_met, step=epoch - 1, buckets=None, description=None)
tf.summary.scalar("met_pred_over_gen_mean", np.mean(pred_met / gen_met), step=epoch - 1, description=None)
tf.summary.scalar("met_pred_over_gen_std", np.std(pred_met / gen_met), step=epoch - 1, description=None)
if comet_experiment:
comet_experiment.log_image(image_path, step=epoch - 1)

jet_pred_wd = scipy.stats.wasserstein_distance(
yvals["jets_pt_gen_to_pred"][:, 0], yvals["jets_pt_gen_to_pred"][:, 1]
)
jet_pred_p25 = np.percentile(jet_ratio_pred, 25)
jet_pred_p50 = np.percentile(jet_ratio_pred, 50)
jet_pred_p75 = np.percentile(jet_ratio_pred, 75)
jet_pred_iqr = jet_pred_p75 - jet_pred_p25

met_pred_wd = scipy.stats.wasserstein_distance(gen_met[:, 0], pred_met[:, 0])
met_pred_p25 = np.percentile(met_ratio_pred, 25)
met_pred_p50 = np.percentile(met_ratio_pred, 50)
met_pred_p75 = np.percentile(met_ratio_pred, 75)
met_pred_iqr = met_pred_p75 - met_pred_p25

for name, val in [
("jet_wd", jet_pred_wd),
("jet_iqr", jet_pred_iqr),
("jet_med", jet_pred_p50),
("met_wd", met_pred_wd),
("met_iqr", met_pred_iqr),
("met_med", met_pred_p50),
]:
tf.summary.scalar(name, val, step=epoch - 1, description=None)

if comet_experiment:
comet_experiment.log_metric(name, val, step=epoch - 1)
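For the badly-reconstructed-particle dump above, a toy sketch (arrays invented for illustration) of the symmetric ratio cut: |gen - pred| / max(gen, pred) is bounded in [0, 1], so a single threshold flags both badly over- and under-predicted pt, and padded entries (0/0 -> NaN) are treated as good.

import numpy as np

gen_pt  = np.array([50.0, 10.0, 0.0])
pred_pt = np.array([48.0, 60.0, 0.0])

with np.errstate(invalid="ignore"):
    ratio = np.abs(gen_pt - pred_pt) / np.maximum(gen_pt, pred_pt)
ratio[np.isnan(ratio)] = 0
msk_bad = ratio > 0.8
print(msk_bad)  # [False  True False]: only the 10 vs 60 GeV particle is flagged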


def prepare_callbacks(
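The scalars logged in epoch_end above (jet_wd/jet_iqr/jet_med and the MET equivalents) summarize the response distributions robustly: the Wasserstein distance compares the gen and predicted spectra directly, while the median and interquartile range of the relative residual (reco - gen)/gen stand in for mean/std, which the long tails would otherwise dominate. A minimal sketch of the same computation on toy data:

import numpy as np
import scipy.stats

rng = np.random.default_rng(42)
gen = rng.gamma(2.0, 20.0, size=1000)         # toy gen-level pt spectrum
pred = gen * rng.normal(1.0, 0.1, size=1000)  # 10% smearing around unity response

wd = scipy.stats.wasserstein_distance(gen, pred)
residual = (pred - gen) / gen
med = np.percentile(residual, 50)
iqr = np.percentile(residual, 75) - np.percentile(residual, 25)
print(wd, med, iqr)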
@@ -179,6 +237,7 @@ def get_checkpoint_history_callback(outdir, config, dataset, comet_experiment, h
config,
plot_freq=config["callbacks"]["plot_freq"],
horovod_enabled=horovod_enabled,
comet_experiment=comet_experiment,
)

callbacks += [cb]
@@ -282,7 +341,7 @@ def deltar(a, b):

# Given a model, evaluates it on each batch of the validation dataset
# For each batch, save the inputs, the generator-level target, the candidate-level target, and the prediction
def eval_model(model, dataset, config, outdir):
def eval_model(model, dataset, config, outdir, jet_ptcut=5.0, jet_match_dr=0.1):

ibatch = 0

@@ -291,6 +350,12 @@ def eval_model(model, dataset, config, outdir):
for elem in tqdm(dataset, desc="Evaluating model"):
y_pred = model.predict(elem["X"], verbose=False)

# mask the predictions where no particle was predicted
msk = (np.argmax(y_pred["cls"], axis=-1, keepdims=True) != 0).astype(np.float32)
for k in y_pred.keys():
if k != "cls":
y_pred[k] = y_pred[k] * msk

np_outfile = "{}/pred_batch{}.npz".format(outdir, ibatch)

ygen = unpack_target(elem["ygen"], config["dataset"]["num_output_classes"], config)
@@ -321,8 +386,8 @@

jets = cluster.inclusive_jets()
jet_constituents = cluster.constituent_index()
jets_coll[typ] = jets[jets.pt > 5.0]
jets_const[typ] = jet_constituents[jets.pt > 5.0]
jets_coll[typ] = jets[jets.pt > jet_ptcut]
jets_const[typ] = jet_constituents[jets.pt > jet_ptcut]

for key in ["pt", "eta", "phi", "energy"]:
outs["jets_gen_{}".format(key)] = awkward.to_numpy(awkward.flatten(getattr(jets_coll["gen"], key)))
@@ -333,7 +398,7 @@ def eval_model(model, dataset, config, outdir):
cart = awkward.cartesian([jets_coll["gen"], jets_coll["pred"]], nested=True)
jets_a, jets_b = awkward.unzip(cart)
drs = deltar(jets_a, jets_b)
match_gen_to_pred = [awkward.where(d < 0.1) for d in drs]
match_gen_to_pred = [awkward.where(d < jet_match_dr) for d in drs]
m0 = awkward.from_iter([m[0] for m in match_gen_to_pred])
m1 = awkward.from_iter([m[1] for m in match_gen_to_pred])
j1s = jets_coll["gen"][m0]
@@ -346,7 +411,7 @@ def eval_model(model, dataset, config, outdir):
cart = awkward.cartesian([jets_coll["gen"], jets_coll["cand"]], nested=True)
jets_a, jets_b = awkward.unzip(cart)
drs = deltar(jets_a, jets_b)
match_gen_to_pred = [awkward.where(d < 0.1) for d in drs]
match_gen_to_pred = [awkward.where(d < jet_match_dr) for d in drs]
m0 = awkward.from_iter([m[0] for m in match_gen_to_pred])
m1 = awkward.from_iter([m[1] for m in match_gen_to_pred])
j1s = jets_coll["gen"][m0]
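The matching above pairs each gen jet with reconstructed jets at deltaR < jet_match_dr (previously hard-coded to 0.1, now an argument of eval_model). The repository's deltar operates on awkward-array jet collections; a plain-numpy sketch of the same quantity, including the phi wrap-around, might look like this:

import numpy as np

def deltar(eta1, phi1, eta2, phi2):
    deta = eta1 - eta2
    dphi = np.mod(phi1 - phi2 + np.pi, 2 * np.pi) - np.pi  # wrap into [-pi, pi)
    return np.sqrt(deta**2 + dphi**2)

# a gen jet at (eta, phi) = (0.5, 3.10) and a reco jet at (0.52, -3.12)
# are nearly collinear despite the apparent ~6.2 rad phi difference
print(deltar(0.5, 3.10, 0.52, -3.12) < 0.1)  # True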