Skip to content

Commit

Permalink
standardize input features, re-enable fp16 [TF], unify plotting [pyto…
Browse files Browse the repository at this point in the history
…rch] (#179)

* re-enable fp16

* add data normalizer

* fix for TF eval

* do pytorch prediction in awk
  • Loading branch information
jpata authored Mar 15, 2023
1 parent b4d8945 commit 83acca3
Show file tree
Hide file tree
Showing 15 changed files with 470 additions and 866 deletions.
2 changes: 1 addition & 1 deletion fcc/check_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#pythia card, start seed, end seed
samples = [
("p8_ee_tt_ecm380", 1, 2011),
("p8_ee_tt_ecm380", 1, 10011),
("p8_ee_qq_ecm380", 100001, 110011),
("p8_ee_ZH_Htautau_ecm380", 200001, 202011),
]
Expand Down
4 changes: 4 additions & 0 deletions mlpf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ def train(

callbacks.append(optim_callbacks)

model.normalizer.adapt(ds_train.tensorflow_dataset.map(lambda X, y, w: X[:, :, 1:]))
print(model.normalizer.mean)
print(model.normalizer.variance)

model.fit(
ds_train.tensorflow_dataset.repeat(),
validation_data=ds_test.tensorflow_dataset.repeat(),
Expand Down
9 changes: 1 addition & 8 deletions mlpf/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@
"cms_pf_ttbar": r"CMS $\mathrm{t}\overline{\mathrm{t}}$+PU events",
"cms_pf_single_neutron": r"CMS single neutron particle gun events",
"clic_edm_ttbar_pf": r"CLIC $ee \rightarrow \mathrm{t}\overline{\mathrm{t}}$",
"clic_edm_qcd_pf": r"CLIC $ee \rightarrow \gamma/\mathrm{Z}^* \rightarrow \mathrm{hadrons}$",
"clic_edm_zz_fullhad_pf": r"CLIC $ee \rightarrow \mathrm{ZZ} \rightarrow \mathrm{hadrons}$",
"clic_edm_qq_pf": r"CLIC $ee \rightarrow \gamma/\mathrm{Z}^* \rightarrow \mathrm{hadrons}$",
}


Expand Down Expand Up @@ -167,12 +166,6 @@ def load_eval_data(path, max_files=None):
for k in data["particles"][typ].fields:
yvals["{}_{}".format(typ, k)] = data["particles"][typ][k]

# Get the classification output as a class ID
if "gen_cls_id" not in yvals.keys():
yvals["gen_cls_id"] = np.argmax(yvals["gen_cls"], axis=-1)
yvals["cand_cls_id"] = np.argmax(yvals["cand_cls"], axis=-1)
yvals["pred_cls_id"] = np.argmax(yvals["pred_cls"], axis=-1)

for typ in ["gen", "cand", "pred"]:

# Compute phi, px, py
Expand Down
303 changes: 117 additions & 186 deletions mlpf/pyg/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,48 @@
import glob
import time

import matplotlib
import numpy as np
import torch
import torch_geometric
from pyg.ssl.utils import combine_PFelements, distinguish_PFelements
import awkward
from jet_utils import build_dummy_array, match_two_jet_collections
import fastjet
import vector
import tqdm

from .utils import CLASS_LABELS, Y_FEATURES

matplotlib.use("Agg")
jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0)
jet_pt = 5.0
jet_match_dr = 0.1


def one_hot_embedding(labels, num_classes):
"""
Embedding labels to one-hot form.
def particle_array_to_awkward(batch_ids, arr_id, arr_p4):
ret = {
"cls_id": arr_id,
"pt": arr_p4[:, 1],
"eta": arr_p4[:, 2],
"sin_phi": arr_p4[:, 3],
"cos_phi": arr_p4[:, 4],
"energy": arr_p4[:, 5],
}
ret["phi"] = np.arctan2(ret["sin_phi"], ret["cos_phi"])
ret = awkward.from_iter([{k: ret[k][batch_ids == b] for k in ret.keys()} for b in np.unique(batch_ids)])
return ret

Args:
labels: (LongTensor) class labels, sized [N,].
num_classes: (int) number of classes.
Returns:
(tensor) encoded labels, sized [N, #classes].
"""
y = torch.eye(num_classes)
return y[labels]


def make_predictions(rank, dataset, mlpf, file_loader, batch_size, PATH, ssl_encoder=None):
"""
Runs inference on the qcd test dataset to evaluate performance.
Saves the predictions as .pt files.
Each .pt file will contain a dict() object with keys X, Y_pid, Y_p4;
contains all the necessary event information to make plots.
Args
rank: int representing the gpu device id, or str=='cpu' (both work, trust me)
model: pytorch model
file_loader: a pytorch Dataloader that loads .pt files for training when you invoke the get() method
"""
num_classes = len(CLASS_LABELS[dataset]) # we have 6 classes for delphes and 9 for cms
def make_predictions_awk(rank, dataset, mlpf, file_loader, batch_size, PATH, ssl_encoder=None):

ti = time.time()

ibatch = 0
tf_0, tf_f = time.time(), 0
for num, file in enumerate(file_loader):
for num, this_loader in enumerate(file_loader):
if "utils" in str(type(file_loader)): # it must be converted to a pyg DataLoader if it's not (only needed for CMS)
print(f"Time to load file {num+1}/{len(file_loader)} on rank {rank} is {round(time.time() - tf_0, 3)}s")
tf_f = tf_f + (time.time() - tf_0)
file = torch_geometric.loader.DataLoader([x for t in file for x in t], batch_size=batch_size)
this_loader = torch_geometric.loader.DataLoader([x for t in this_loader for x in t], batch_size=batch_size)

tf = 0
for i, batch in enumerate(file):
for i, batch in tqdm.tqdm(enumerate(this_loader), total=len(this_loader)):

if ssl_encoder is not None:
# seperate PF-elements
Expand All @@ -71,166 +62,106 @@ def make_predictions(rank, dataset, mlpf, file_loader, batch_size, PATH, ssl_enc
pred_ids_one_hot, pred_momentum, pred_charge = mlpf(event)
tf = tf + (time.time() - t0)

pred_charge = torch.argmax(pred_charge, axis=1, keepdim=True) - 1
pred_p4 = torch.cat([pred_charge, pred_momentum], axis=-1)
pred_ids = torch.argmax(pred_ids_one_hot.detach(), axis=-1)
pred_charge = torch.argmax(pred_charge.detach(), axis=1, keepdim=True) - 1
pred_p4 = torch.cat([pred_charge, pred_momentum.detach()], axis=-1)

target_ids = event.ygen_id
target_p4 = event.ygen.to(dtype=torch.float32)
cand_ids = event.ycand_id
cand_p4 = event.ycand.to(dtype=torch.float32)

# zero pad the events to use the same plotting scripts as the tf pipeline
padded_num_elem_size = 6400

# must zero pad each event individually so must unpack the batches
pred_ids_one_hot_list, pred_p4_list = [], []
for z in range(batch_size):
pred_ids_one_hot_list.append(pred_ids_one_hot[batch.batch == z])
pred_p4_list.append(pred_p4[batch.batch == z])

X, Y_pid, Y_p4 = [], [], []
batch_list = batch.to_data_list()
for j, event in enumerate(batch_list):
vars = {
"X": event.x.detach().to("cpu"),
"ygen": target_p4.detach().to("cpu"),
"ycand": cand_p4.detach().to("cpu"),
"pred_p4": pred_p4_list[j].detach().to("cpu"),
"gen_ids_one_hot": one_hot_embedding(target_ids.detach().to("cpu"), num_classes),
"cand_ids_one_hot": one_hot_embedding(cand_ids.detach().to("cpu"), num_classes),
"pred_ids_one_hot": pred_ids_one_hot_list[j].detach().to("cpu"),
}

vars_padded = {}
for key, var in vars.items():
var = var[:padded_num_elem_size]
var = torch.nn.functional.pad(
var,
(0, 0, 0, padded_num_elem_size - var.shape[0]),
mode="constant",
value=0,
).unsqueeze(0)
vars_padded[key] = var

X.append(vars_padded["X"])
Y_pid.append(
torch.cat(
[
vars_padded["gen_ids_one_hot"],
vars_padded["cand_ids_one_hot"],
vars_padded["pred_ids_one_hot"],
]
).unsqueeze(0)
)
Y_p4.append(
torch.cat(
[
vars_padded["ygen"],
vars_padded["ycand"],
vars_padded["pred_p4"],
]
).unsqueeze(0)
)

outfile = f"{PATH}/predictions/pred_batch{ibatch}_{rank}.pt"
print(f"saving predictions at {outfile}")
torch.save(
{
"X": torch.cat(X), # [batch_size, 6400, 41]
"Y_pid": torch.cat(Y_pid), # [batch_size, 3, 6400, 41]
"Y_p4": torch.cat(Y_p4),
}, # [batch_size, 3, 6400, 41]
outfile,
batch_ids = event.batch.cpu().numpy()
awkvals = {
"gen": particle_array_to_awkward(batch_ids, target_ids.cpu().numpy(), event.ygen.cpu().numpy()),
"cand": particle_array_to_awkward(batch_ids, cand_ids.cpu().numpy(), event.ycand.cpu().numpy()),
"pred": particle_array_to_awkward(batch_ids, pred_ids.cpu().numpy(), pred_p4.cpu().numpy()),
}

gen_p4 = []
gen_cls = []
cand_p4 = []
cand_cls = []
pred_p4 = []
pred_cls = []
Xs = []
for _ibatch in np.unique(event.batch.cpu().numpy()):
msk_batch = event.batch == _ibatch
msk_gen = target_ids[msk_batch] != 0
msk_cand = cand_ids[msk_batch] != 0
msk_pred = pred_ids[msk_batch] != 0

Xs.append(event.x[msk_batch].cpu().numpy())

gen_p4.append(event.ygen[msk_batch, 1:][msk_gen])
gen_cls.append(target_ids[msk_batch][msk_gen])

cand_p4.append(event.ycand[msk_batch, 1:][msk_cand])
cand_cls.append(cand_ids[msk_batch][msk_cand])

pred_p4.append(pred_momentum[msk_batch, :][msk_pred])
pred_cls.append(pred_ids[msk_batch][msk_pred])

Xs = awkward.from_iter(Xs)
gen_p4 = awkward.from_iter(gen_p4)
gen_cls = awkward.from_iter(gen_cls)
gen_p4 = vector.awk(
awkward.zip({"pt": gen_p4[:, :, 0], "eta": gen_p4[:, :, 1], "phi": gen_p4[:, :, 2], "e": gen_p4[:, :, 3]})
)

ibatch += 1
cand_p4 = awkward.from_iter(cand_p4)
cand_cls = awkward.from_iter(cand_cls)
cand_p4 = vector.awk(
awkward.zip(
{"pt": cand_p4[:, :, 0], "eta": cand_p4[:, :, 1], "phi": cand_p4[:, :, 2], "e": cand_p4[:, :, 3]}
)
)

# if i == 2:
# break
# if num == 2:
# break
# in case of no predicted particles in the batch
if torch.sum(pred_ids != 0) == 0:
pt = build_dummy_array(len(pred_p4), np.float64)
eta = build_dummy_array(len(pred_p4), np.float64)
phi = build_dummy_array(len(pred_p4), np.float64)
pred_cls = build_dummy_array(len(pred_p4), np.float64)
energy = build_dummy_array(len(pred_p4), np.float64)
pred_p4 = vector.awk(awkward.zip({"pt": pt, "eta": eta, "phi": phi, "e": energy}))
else:
pred_p4 = awkward.from_iter(pred_p4)
pred_cls = awkward.from_iter(pred_cls)
pred_p4 = vector.awk(
awkward.zip(
{
"pt": pred_p4[:, :, 0],
"eta": pred_p4[:, :, 1],
"phi": pred_p4[:, :, 2],
"e": pred_p4[:, :, 3],
}
)
)

print(f"Average inference time per batch on rank {rank} is {(tf / len(file)):.3f}s")
jets_coll = {}

cluster1 = fastjet.ClusterSequence(awkward.Array(gen_p4.to_xyzt()), jetdef)
jets_coll["gen"] = cluster1.inclusive_jets(min_pt=jet_pt)
cluster2 = fastjet.ClusterSequence(awkward.Array(cand_p4.to_xyzt()), jetdef)
jets_coll["cand"] = cluster2.inclusive_jets(min_pt=jet_pt)
cluster3 = fastjet.ClusterSequence(awkward.Array(pred_p4.to_xyzt()), jetdef)
jets_coll["pred"] = cluster3.inclusive_jets(min_pt=jet_pt)

gen_to_pred = match_two_jet_collections(jets_coll, "gen", "pred", jet_match_dr)
gen_to_cand = match_two_jet_collections(jets_coll, "gen", "cand", jet_match_dr)
matched_jets = awkward.Array({"gen_to_pred": gen_to_pred, "gen_to_cand": gen_to_cand})

awkward.to_parquet(
awkward.Array(
{
"inputs": Xs,
"particles": awkvals,
"jets": jets_coll,
"matched_jets": matched_jets,
}
),
f"{PATH}/pred_{i}.parquet",
)

print(f"Average inference time per batch on rank {rank} is {(tf / len(this_loader)):.3f}s")
t0 = time.time()

print(f"Time taken to make predictions on rank {rank} is: {((time.time() - ti) / 60):.2f} min")


def postprocess_predictions(dataset, pred_path):
"""
Loads all the predictions .pt files and combines them after some necessary processing to make plots.
Saves the processed predictions.
"""

print("--> Concatenating all predictions...")
t0 = time.time()

Xs = []
Y_pids = []
Y_p4s = []

PATH = list(glob.glob(f"{pred_path}/pred_batch*.pt"))
for i, fi in enumerate(PATH):
print(f"loading prediction # {i+1}/{len(PATH)}")
dd = torch.load(fi)
Xs.append(dd["X"])
Y_pids.append(dd["Y_pid"])
Y_p4s.append(dd["Y_p4"])

Xs = torch.cat(Xs).numpy()
Y_pids = torch.cat(Y_pids)
Y_p4s = torch.cat(Y_p4s)

# reformat the loaded files for convenient plotting
yvals = {}
yvals["gen_cls"] = Y_pids[:, 0, :, :].numpy()
yvals["cand_cls"] = Y_pids[:, 1, :, :].numpy()
yvals["pred_cls"] = Y_pids[:, 2, :, :].numpy()

for feat, key in enumerate(Y_FEATURES[dataset][1:]): # skip the PDG
yvals[f"gen_{key}"] = Y_p4s[:, 0, :, feat].unsqueeze(-1).numpy()
yvals[f"cand_{key}"] = Y_p4s[:, 1, :, feat].unsqueeze(-1).numpy()
yvals[f"pred_{key}"] = Y_p4s[:, 2, :, feat].unsqueeze(-1).numpy()

print(f"Time taken to concatenate all predictions is: {round(((time.time() - t0) / 60), 2)} min")

print("--> Further processing for convenient plotting")
t0 = time.time()

def flatten(arr):
return arr.reshape(-1, arr.shape[-1])

X_f = flatten(Xs)

msk_X_f = X_f[:, 0] != 0

for val in ["gen", "cand", "pred"]:
if dataset != "CLIC": # TODO: remove
yvals[f"{val}_phi"] = np.arctan2(yvals[f"{val}_sin_phi"], yvals[f"{val}_cos_phi"])
yvals[f"{val}_cls_id"] = np.argmax(yvals[f"{val}_cls"], axis=-1).reshape(
yvals[f"{val}_cls"].shape[0], yvals[f"{val}_cls"].shape[1], 1
) # cz for some reason keepdims doesn't work

yvals[f"{val}_px"] = np.sin(yvals[f"{val}_phi"]) * yvals[f"{val}_pt"]
yvals[f"{val}_py"] = np.cos(yvals[f"{val}_phi"]) * yvals[f"{val}_pt"]

yvals_f = {k: flatten(v) for k, v in yvals.items()}

# remove the last dim
for k in yvals_f.keys():
if yvals_f[k].shape[-1] == 1:
yvals_f[k] = yvals_f[k][..., -1]

print(f"Time taken to process the predictions is: {round(((time.time() - t0) / 60), 2)} min")

print("-->Saving the processed events")
t0 = time.time()
torch.save(Xs, f"{pred_path}/post_processed_Xs.pt", pickle_protocol=4)
torch.save(X_f, f"{pred_path}/post_processed_X_f.pt", pickle_protocol=4)
torch.save(msk_X_f, f"{pred_path}/post_processed_msk_X_f.pt", pickle_protocol=4)
torch.save(yvals, f"{pred_path}/post_processed_yvals.pt", pickle_protocol=4)
torch.save(yvals_f, f"{pred_path}/post_processed_yvals_f.pt", pickle_protocol=4)
print(f"Time taken to save the predictions is: {round(((time.time() - t0) / 60), 2)} min")

return Xs, X_f, msk_X_f, yvals, yvals_f
print(f"Time taken to make predictions on rank {rank} is: {((time.time() - ti) / 60):.2f} min")
Loading

0 comments on commit 83acca3

Please sign in to comment.