diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4a2c8090..44fbf0c6c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,9 +29,13 @@ repos: - repo: https://github.com/psf/black rev: 22.12.0 hooks: - - id: black-jupyter - language_version: python3 - args: [--line-length=125] + - id: black + # It is recommended to specify the latest version of Python + # supported by your project here, or alternatively use + # pre-commit's default_language_version, see + # https://pre-commit.com/#top_level-default_language_version + language_version: python3 + args: [--line-length=125] - repo: https://github.com/PyCQA/flake8 rev: 6.0.0 @@ -42,4 +46,4 @@ repos: # E203 is not PEP8 compliant # E402 due to logging.basicConfig in pipeline.py args: ['--max-line-length=125', # github viewer width - '--extend-ignore=E203,W605,E402'] + '--extend-ignore=E203,E402'] diff --git a/delphes/ntuplizer.py b/delphes/ntuplizer.py index d91c3c719..a84949011 100644 --- a/delphes/ntuplizer.py +++ b/delphes/ntuplizer.py @@ -225,7 +225,13 @@ def make_triplets(g, tracks, towers, particles, pfparticles): # determine the GenParticle to reconstruct from this tower if len(lvs) > 0: lv = sum(lvs[1:], lvs[0]) - gen_ptcl = {"pid": pid, "pt": lv.pt, "eta": lv.eta, "phi": lv.phi, "energy": lv.energy} + gen_ptcl = { + "pid": pid, + "pt": lv.pt, + "eta": lv.eta, + "phi": lv.phi, + "energy": lv.energy, + } # charged gen particles outside the tracker acceptance should be reconstructed as neutrals if gen_ptcl["pid"] == 211 and abs(gen_ptcl["eta"]) > 2.5: @@ -250,7 +256,11 @@ def make_triplets(g, tracks, towers, particles, pfparticles): pf_ptcl = None triplets.append((t, gen_ptcl, pf_ptcl)) - return triplets, list(remaining_particles), list(remaining_pfcandidates) + return ( + triplets, + list(remaining_particles), + list(remaining_pfcandidates), + ) def process_chunk(infile, ev_start, ev_stop, outfile): @@ -380,7 +390,10 @@ def process_chunk(infile, ev_start, ev_stop, outfile): # write the full graph, mainly for study purposes if iev < 10 and save_full_graphs: - nx.readwrite.write_gpickle(graph, outfile.replace(".pkl.bz2", "_graph_{}.pkl".format(iev))) + nx.readwrite.write_gpickle( + graph, + outfile.replace(".pkl.bz2", "_graph_{}.pkl".format(iev)), + ) # now clean up the graph, keeping only reconstructable genparticles # we also merge neutral genparticles within towers, as they are otherwise not reconstructable @@ -390,7 +403,11 @@ def process_chunk(infile, ev_start, ev_stop, outfile): tracks = [n for n in graph.nodes if n[0] == "track"] towers = [n for n in graph.nodes if n[0] == "tower"] - triplets, remaining_particles, remaining_pfcandidates = make_triplets(graph, tracks, towers, particles, pfcand) + ( + triplets, + remaining_particles, + remaining_pfcandidates, + ) = make_triplets(graph, tracks, towers, particles, pfcand) print("remaining PF", len(remaining_pfcandidates)) for pf in remaining_pfcandidates: print(pf, graph.nodes[pf]) @@ -433,7 +450,16 @@ def process_chunk(infile, ev_start, ev_stop, outfile): ygen = np.stack(ygen) ygen_remaining = np.stack(ygen_remaining) ycand = np.stack(ycand) - print("X", X.shape, "ygen", ygen.shape, "ygen_remaining", ygen_remaining.shape, "ycand", ycand.shape) + print( + "X", + X.shape, + "ygen", + ygen.shape, + "ygen_remaining", + ygen_remaining.shape, + "ycand", + ycand.shape, + ) X_all.append(X) ygen_all.append(ygen) diff --git a/mlpf/adv_training.py b/mlpf/adv_training.py index 7a2741a52..48f371d03 100644 --- a/mlpf/adv_training.py +++ 
b/mlpf/adv_training.py @@ -9,7 +9,10 @@ # A deep sets conditional discriminator def make_disc_model(config, reco_features): input_elems = tf.keras.layers.Input( - shape=(config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]) + shape=( + config["dataset"]["padded_num_elem_size"], + config["dataset"]["num_input_features"], + ) ) input_reco = tf.keras.layers.Input(shape=(config["dataset"]["padded_num_elem_size"], reco_features)) @@ -79,14 +82,21 @@ def main(config): tb.set_model(model_pf) cp_callback = tf.keras.callbacks.ModelCheckpoint( - filepath="logs/weights-{epoch:02d}.hdf5", save_weights_only=True, verbose=0 + filepath="logs/weights-{epoch:02d}.hdf5", + save_weights_only=True, + verbose=0, ) cp_callback.set_model(model_pf) - x = np.random.randn(1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]) + x = np.random.randn( + 1, + config["dataset"]["padded_num_elem_size"], + config["dataset"]["num_input_features"], + ) ypred = concat_pf([model_pf(x), x]) model_pf.load_weights( - "experiments/cms_20210909_132136_111774.gpu0.local/weights/weights-100-1.280379.hdf5", by_name=True + "experiments/cms_20210909_132136_111774.gpu0.local/weights/weights-100-1.280379.hdf5", + by_name=True, ) # model_pf.load_weights("./logs/weights-02.hdf5", by_name=True) @@ -105,18 +115,26 @@ def main(config): cb.set_model(model_pf) input_elems = tf.keras.layers.Input( - shape=(config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]), + shape=( + config["dataset"]["padded_num_elem_size"], + config["dataset"]["num_input_features"], + ), batch_size=2 * batch_size, name="input_detector_elements", ) input_reco = tf.keras.layers.Input( - shape=(config["dataset"]["padded_num_elem_size"], ypred.shape[-1]), name="input_reco_particles" + shape=(config["dataset"]["padded_num_elem_size"], ypred.shape[-1]), + name="input_reco_particles", ) pf_out = tf.keras.layers.Lambda(concat_pf)([model_pf(input_elems), input_elems]) disc_out1 = model_disc([input_elems, pf_out]) disc_out2 = model_disc([input_elems, input_reco]) m1 = tf.keras.models.Model(inputs=[input_elems], outputs=[disc_out1], name="model_mlpf_disc") - m2 = tf.keras.models.Model(inputs=[input_elems, input_reco], outputs=[disc_out2], name="model_reco_disc") + m2 = tf.keras.models.Model( + inputs=[input_elems, input_reco], + outputs=[disc_out2], + name="model_reco_disc", + ) def loss(x, y): return tf.keras.losses.binary_crossentropy(x, y, from_logits=True) @@ -159,7 +177,10 @@ def loss(x, y): mlpf_train_outputs = tf.concat([yb, yp], axis=0) mlpf_train_disc_targets = tf.concat([batch_size * [0.99], batch_size * [0.01]], axis=0) - loss2 = m2.train_on_batch([mlpf_train_inputs, mlpf_train_outputs], mlpf_train_disc_targets) + loss2 = m2.train_on_batch( + [mlpf_train_inputs, mlpf_train_outputs], + mlpf_train_disc_targets, + ) # Train the MLPF reconstruction (generative) model with an inverted target disc_train_disc_targets = tf.concat([batch_size * [1.0]], axis=0) @@ -189,7 +210,10 @@ def loss(x, y): mlpf_train_inputs = tf.concat([xb, xb], axis=0) mlpf_train_outputs = tf.concat([yb, yp], axis=0) mlpf_train_disc_targets = tf.concat([batch_size * [0.99], batch_size * [0.01]], axis=0) - loss2 = m2.test_on_batch([mlpf_train_inputs, mlpf_train_outputs], mlpf_train_disc_targets) + loss2 = m2.test_on_batch( + [mlpf_train_inputs, mlpf_train_outputs], + mlpf_train_disc_targets, + ) # Train the MLPF reconstruction (generative) model with an inverted target disc_train_disc_targets = tf.concat([batch_size 
* [1.0]], axis=0) diff --git a/mlpf/data_clic/postprocessing.py b/mlpf/data_clic/postprocessing.py index f58678625..a37403017 100644 --- a/mlpf/data_clic/postprocessing.py +++ b/mlpf/data_clic/postprocessing.py @@ -54,7 +54,15 @@ def track_as_array(df_tr, itr): def cluster_as_array(df_cl, icl): row = df_cl[icl] return np.array( - [2, row["x"], row["y"], row["z"], row["nhits_ecal"], row["nhits_hcal"], row["energy"]] # clusters are type 2 + [ + 2, + row["x"], + row["y"], + row["z"], + row["nhits_ecal"], + row["nhits_hcal"], + row["energy"], + ] # clusters are type 2 ) @@ -62,7 +70,16 @@ def cluster_as_array(df_cl, icl): def gen_as_array(df_gen, igen): if igen: row = df_gen[igen] - return np.array([abs(row["pdgid"]), row["charge"], row["px"], row["py"], row["pz"], row["energy"]]) + return np.array( + [ + abs(row["pdgid"]), + row["charge"], + row["px"], + row["py"], + row["pz"], + row["energy"], + ] + ) else: return np.zeros(6) @@ -71,7 +88,16 @@ def gen_as_array(df_gen, igen): def pf_as_array(df_pfs, igen): if igen: row = df_pfs[igen] - return np.array([abs(row["type"]), row["charge"], row["px"], row["py"], row["pz"], row["energy"]]) + return np.array( + [ + abs(row["type"]), + row["charge"], + row["px"], + row["py"], + row["pz"], + row["energy"], + ] + ) else: return np.zeros(6) @@ -145,9 +171,15 @@ def flatten_event(df_tr, df_cl, df_gen, df_pfs, pairs): # Here we pad the tracks and clusters to the same shape along the feature dimension if Xs_tracks.shape[1] > Xs_clusters.shape[-1]: - Xs_clusters = np.pad(Xs_clusters, [(0, 0), (0, Xs_tracks.shape[1] - Xs_clusters.shape[-1])]) + Xs_clusters = np.pad( + Xs_clusters, + [(0, 0), (0, Xs_tracks.shape[1] - Xs_clusters.shape[-1])], + ) elif Xs_tracks.shape[1] < Xs_clusters.shape[-1]: - Xs_clusters = np.pad(Xs_clusters, [(0, 0), (0, Xs_clusters.shape[-1] - Xs_tracks.shape[1])]) + Xs_clusters = np.pad( + Xs_clusters, + [(0, 0), (0, Xs_clusters.shape[-1] - Xs_tracks.shape[1])], + ) Xs = np.concatenate([Xs_tracks, Xs_clusters], axis=0) # [Ntracks+Nclusters, max(Nfeat_cluster, Nfeat_track)] ys_gen = np.stack(ys_gen, axis=-1).T diff --git a/mlpf/data_cms/multicrab.py b/mlpf/data_cms/multicrab.py index 23de25cc2..9e4dafad7 100644 --- a/mlpf/data_cms/multicrab.py +++ b/mlpf/data_cms/multicrab.py @@ -7,15 +7,27 @@ def submit(config): crabCommand("submit", config=config) # save crab config for the future - with open(config.General.workArea + "/crab_" + config.General.requestName + "/crab_config.py", "w") as fi: + with open( + config.General.workArea + "/crab_" + config.General.requestName + "/crab_config.py", + "w", + ) as fi: fi.write(config.pythonise_()) # https://cmsweb.cern.ch/das/request?view=plain&limit=50&instance=prod%2Fglobal&input=%2FRelVal*%2FCMSSW_11_0_0_pre4*%2FGEN-SIM-DIGI-RAW samples = [ - ("/RelValQCD_FlatPt_15_3000HS_14/CMSSW_11_0_0_pre12-PU_110X_mcRun3_2021_realistic_v5-v1/GEN-SIM-DIGI-RAW", "QCD_run3"), - ("/RelValNuGun/CMSSW_11_0_0_pre12-PU_110X_mcRun3_2021_realistic_v5-v1/GEN-SIM-DIGI-RAW", "NuGun_run3"), - ("/RelValTTbar_14TeV/CMSSW_11_0_0_pre12-PU_110X_mcRun3_2021_realistic_v5-v1/GEN-SIM-DIGI-RAW", "TTbar_run3"), + ( + "/RelValQCD_FlatPt_15_3000HS_14/CMSSW_11_0_0_pre12-PU_110X_mcRun3_2021_realistic_v5-v1/GEN-SIM-DIGI-RAW", + "QCD_run3", + ), + ( + "/RelValNuGun/CMSSW_11_0_0_pre12-PU_110X_mcRun3_2021_realistic_v5-v1/GEN-SIM-DIGI-RAW", + "NuGun_run3", + ), + ( + "/RelValTTbar_14TeV/CMSSW_11_0_0_pre12-PU_110X_mcRun3_2021_realistic_v5-v1/GEN-SIM-DIGI-RAW", + "TTbar_run3", + ), # 
("/RelValTTbar_14TeV/CMSSW_11_0_0_pre12-PU25ns_110X_mcRun4_realistic_v2_2026D41PU140-v1/GEN-SIM-DIGI-RAW", # "TTbar_run4_pu140"), # ("/RelValTTbar_14TeV/CMSSW_11_0_0_pre12-PU25ns_110X_mcRun4_realistic_v2_2026D41PU200-v1/GEN-SIM-DIGI-RAW", @@ -37,7 +49,10 @@ def submit(config): conf.JobType.psetName = "step3_dump.py" conf.JobType.maxJobRuntimeMin = 8 * 60 conf.JobType.allowUndistributedCMSSW = True - conf.JobType.outputFiles = ["step3_inMINIAODSIM.root", "step3_AOD.root"] + conf.JobType.outputFiles = [ + "step3_inMINIAODSIM.root", + "step3_AOD.root", + ] conf.JobType.maxMemoryMB = 6000 conf.JobType.numCores = 2 diff --git a/mlpf/data_cms/postprocessing2.py b/mlpf/data_cms/postprocessing2.py index 59d604a8e..faf4a312d 100644 --- a/mlpf/data_cms/postprocessing2.py +++ b/mlpf/data_cms/postprocessing2.py @@ -169,7 +169,12 @@ def merge_closeby_particles(g, pid=22, deltar_cut=0.001): if pair[0] in g.nodes and pair[1] in g.nodes: lv = vector.obj(pt=0, eta=0, phi=0, E=0) for gp in pair: - lv += vector.obj(pt=g.nodes[gp]["pt"], eta=g.nodes[gp]["eta"], phi=g.nodes[gp]["phi"], E=g.nodes[gp]["e"]) + lv += vector.obj( + pt=g.nodes[gp]["pt"], + eta=g.nodes[gp]["eta"], + phi=g.nodes[gp]["phi"], + E=g.nodes[gp]["e"], + ) g.nodes[pair[0]]["pt"] = lv.pt g.nodes[pair[0]]["eta"] = lv.eta @@ -346,7 +351,11 @@ def prepare_normalized_table(g, genparticle_energy_threshold=0.2): elems = [e for e in g.successors(gp)] # sort elements by energy deposit from genparticle - elems_sorted = sorted([(g.edges[gp, e]["weight"], e) for e in elems], key=lambda x: x[0], reverse=True) + elems_sorted = sorted( + [(g.edges[gp, e]["weight"], e) for e in elems], + key=lambda x: x[0], + reverse=True, + ) chosen_elem = None for weight, elem in elems_sorted: @@ -363,7 +372,11 @@ def prepare_normalized_table(g, genparticle_energy_threshold=0.2): # assign unmatched genparticles to best element, allowing for overlaps for gp in sorted(unmatched_gp, key=lambda x: g.nodes[x]["e"], reverse=True): elems = [e for e in g.successors(gp)] - elems_sorted = sorted([(g.edges[gp, e]["weight"], e) for e in elems], key=lambda x: x[0], reverse=True) + elems_sorted = sorted( + [(g.edges[gp, e]["weight"], e) for e in elems], + key=lambda x: x[0], + reverse=True, + ) _, elem = elems_sorted[0] elem_to_gp[elem] += [gp] @@ -393,7 +406,11 @@ def prepare_normalized_table(g, genparticle_energy_threshold=0.2): else: # neighbors = [n for n in neighbors if g.nodes[n]["typ"] in [4,5,8,9,10]] # sorted_neighbors = sorted(neighbors, key=lambda x: g.nodes[x]["e"], reverse=True) - sorted_neighbors = sorted(neighbors, key=lambda x: g.edges[(x, cand)]["weight"], reverse=True) + sorted_neighbors = sorted( + neighbors, + key=lambda x: g.edges[(x, cand)]["weight"], + reverse=True, + ) for elem in sorted_neighbors: if not (elem in elem_to_cand): chosen_elem = elem @@ -404,16 +421,29 @@ def prepare_normalized_table(g, genparticle_energy_threshold=0.2): # print("unmatched candidate {}, {}".format(cand, g.nodes[cand])) unmatched_cand += [cand] - Xelem = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in elem_branches]) + Xelem = np.recarray( + (len(all_elements),), + dtype=[(name, np.float32) for name in elem_branches], + ) Xelem.fill(0.0) - ygen = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in target_branches]) + ygen = np.recarray( + (len(all_elements),), + dtype=[(name, np.float32) for name in target_branches], + ) ygen.fill(0.0) - ycand = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in target_branches]) + 
ycand = np.recarray( + (len(all_elements),), + dtype=[(name, np.float32) for name in target_branches], + ) ycand.fill(0.0) for ielem, elem in enumerate(all_elements): elem_type = g.nodes[elem]["typ"] - genparticles = sorted(elem_to_gp.get(elem, []), key=lambda x: g.edges[(x, elem)]["weight"], reverse=True) + genparticles = sorted( + elem_to_gp.get(elem, []), + key=lambda x: g.edges[(x, elem)]["weight"], + reverse=True, + ) genparticles = [gp for gp in genparticles if g.nodes[gp]["e"] > genparticle_energy_threshold] candidate = elem_to_cand.get(elem, None) @@ -452,7 +482,12 @@ def prepare_normalized_table(g, genparticle_energy_threshold=0.2): charge = g.nodes[genparticles[0]]["charge"] for gp in genparticles: - lv += vector.obj(pt=g.nodes[gp]["pt"], eta=g.nodes[gp]["eta"], phi=g.nodes[gp]["phi"], e=g.nodes[gp]["e"]) + lv += vector.obj( + pt=g.nodes[gp]["pt"], + eta=g.nodes[gp]["eta"], + phi=g.nodes[gp]["phi"], + e=g.nodes[gp]["e"], + ) # remap PID in case of HCAL cluster to neutral if elem_type == 5 and (pid == 22 or pid == 11): @@ -664,7 +699,9 @@ def make_graph(ev, iev): trackingparticle_to_element_cmp = ev["trackingparticle_to_element_cmp"][iev] # for trackingparticles associated to elements, set a very high edge weight for tp, elem, c in zip( - trackingparticle_to_element_first, trackingparticle_to_element_second, trackingparticle_to_element_cmp + trackingparticle_to_element_first, + trackingparticle_to_element_second, + trackingparticle_to_element_cmp, ): if not (g.nodes[("elem", elem)]["typ"] in [7]): g.add_edge(("tp", tp), ("elem", elem), weight=float("inf")) @@ -672,7 +709,11 @@ def make_graph(ev, iev): caloparticle_to_element_first = ev["caloparticle_to_element.first"][iev] caloparticle_to_element_second = ev["caloparticle_to_element.second"][iev] caloparticle_to_element_cmp = ev["caloparticle_to_element_cmp"][iev] - for sc, elem, c in zip(caloparticle_to_element_first, caloparticle_to_element_second, caloparticle_to_element_cmp): + for sc, elem, c in zip( + caloparticle_to_element_first, + caloparticle_to_element_second, + caloparticle_to_element_cmp, + ): if not (g.nodes[("elem", elem)]["typ"] in [7]): g.add_edge(("sc", sc), ("elem", elem), weight=c) @@ -681,7 +722,11 @@ def make_graph(ev, iev): for idx_sc, idx_tp in enumerate(caloparticle_idx_trackingparticle): if idx_tp != -1: for elem in g.neighbors(("sc", idx_sc)): - g.add_edge(("tp", idx_tp), elem, weight=g.edges[("sc", idx_sc), elem]["weight"]) + g.add_edge( + ("tp", idx_tp), + elem, + weight=g.edges[("sc", idx_sc), elem]["weight"], + ) g.nodes[("tp", idx_tp)]["idx_sc"] = idx_sc nodes_to_remove += [("sc", idx_sc)] g.remove_nodes_from(nodes_to_remove) @@ -745,7 +790,12 @@ def process(args): arr_ptcls_pythia = np.array([[g.nodes[n][f] for f in feats] for n in ptcls_pythia]) if args.save_normalized_table: - data = {"Xelem": Xelem, "ycand": ycand, "ygen": ygen, "pythia": arr_ptcls_pythia} + data = { + "Xelem": Xelem, + "ycand": ycand, + "ygen": ygen, + "pythia": arr_ptcls_pythia, + } if args.save_full_graph: data["full_graph"] = g @@ -760,11 +810,29 @@ def parse_args(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("--input", type=str, help="Input file from PFAnalysis", required=True) + parser.add_argument( + "--input", + type=str, + help="Input file from PFAnalysis", + required=True, + ) parser.add_argument("--outpath", type=str, default="raw", help="output path") - parser.add_argument("--save-full-graph", action="store_true", help="save the full event graph") - 
parser.add_argument("--save-normalized-table", action="store_true", help="save the uniquely identified table") - parser.add_argument("--num-events", type=int, help="number of events to process", default=-1) + parser.add_argument( + "--save-full-graph", + action="store_true", + help="save the full event graph", + ) + parser.add_argument( + "--save-normalized-table", + action="store_true", + help="save the uniquely identified table", + ) + parser.add_argument( + "--num-events", + type=int, + help="number of events to process", + default=-1, + ) args = parser.parse_args() return args diff --git a/mlpf/heptfds/clic_pf/clic_utils.py b/mlpf/heptfds/clic_pf/clic_utils.py index ea54ee057..67ae80262 100644 --- a/mlpf/heptfds/clic_pf/clic_utils.py +++ b/mlpf/heptfds/clic_pf/clic_utils.py @@ -21,7 +21,15 @@ ] # these labels are for clusters from cluster_as_array -X_FEATURES_CL = ["type", "x", "y", "z", "nhits_ecal", "nhits_hcal", "energy"] +X_FEATURES_CL = [ + "type", + "x", + "y", + "z", + "nhits_ecal", + "nhits_hcal", + "energy", +] Y_FEATURES = ["type", "charge", "px", "py", "pz", "energy", "jet_idx"] @@ -35,7 +43,10 @@ def split_sample(path, test_frac=0.8): files_test = files[idx_split:] assert len(files_train) > 0 assert len(files_test) > 0 - return {"train": generate_examples(files_train), "test": generate_examples(files_test)} + return { + "train": generate_examples(files_train), + "test": generate_examples(files_test), + } def generate_examples(files): @@ -47,8 +58,20 @@ def generate_examples(files): for iev, (X, ycand, ygen) in enumerate(ret): # add jet_idx column - ygen = np.concatenate([ygen.astype(np.float32), np.zeros((len(ygen), 1), dtype=np.float32)], axis=-1) - ycand = np.concatenate([ycand.astype(np.float32), np.zeros((len(ycand), 1), dtype=np.float32)], axis=-1) + ygen = np.concatenate( + [ + ygen.astype(np.float32), + np.zeros((len(ygen), 1), dtype=np.float32), + ], + axis=-1, + ) + ycand = np.concatenate( + [ + ycand.astype(np.float32), + np.zeros((len(ycand), 1), dtype=np.float32), + ], + axis=-1, + ) # prepare gen candidates for clustering cls_id = ygen[..., 0] diff --git a/mlpf/heptfds/clic_pf/higgsbb.py b/mlpf/heptfds/clic_pf/higgsbb.py index 69417dbe6..7fc93212e 100644 --- a/mlpf/heptfds/clic_pf/higgsbb.py +++ b/mlpf/heptfds/clic_pf/higgsbb.py @@ -34,7 +34,13 @@ def _info(self) -> tfds.core.DatasetInfo: description=_DESCRIPTION, features=tfds.features.FeaturesDict( { - "X": tfds.features.Tensor(shape=(None, max(len(X_FEATURES_TRK), len(X_FEATURES_CL))), dtype=tf.float32), + "X": tfds.features.Tensor( + shape=( + None, + max(len(X_FEATURES_TRK), len(X_FEATURES_CL)), + ), + dtype=tf.float32, + ), "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), } @@ -43,7 +49,9 @@ def _info(self) -> tfds.core.DatasetInfo: homepage="", citation=_CITATION, metadata=tfds.core.MetadataDict( - x_features_track=X_FEATURES_TRK, x_features_cluster=X_FEATURES_CL, y_features=Y_FEATURES + x_features_track=X_FEATURES_TRK, + x_features_cluster=X_FEATURES_CL, + y_features=Y_FEATURES, ), ) diff --git a/mlpf/heptfds/clic_pf/higgsgg.py b/mlpf/heptfds/clic_pf/higgsgg.py index a1070031c..08fa20f2b 100644 --- a/mlpf/heptfds/clic_pf/higgsgg.py +++ b/mlpf/heptfds/clic_pf/higgsgg.py @@ -34,7 +34,13 @@ def _info(self) -> tfds.core.DatasetInfo: description=_DESCRIPTION, features=tfds.features.FeaturesDict( { - "X": tfds.features.Tensor(shape=(None, max(len(X_FEATURES_TRK), len(X_FEATURES_CL))), 
dtype=tf.float32), + "X": tfds.features.Tensor( + shape=( + None, + max(len(X_FEATURES_TRK), len(X_FEATURES_CL)), + ), + dtype=tf.float32, + ), "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), } @@ -43,7 +49,9 @@ def _info(self) -> tfds.core.DatasetInfo: homepage="", citation=_CITATION, metadata=tfds.core.MetadataDict( - x_features_track=X_FEATURES_TRK, x_features_cluster=X_FEATURES_CL, y_features=Y_FEATURES + x_features_track=X_FEATURES_TRK, + x_features_cluster=X_FEATURES_CL, + y_features=Y_FEATURES, ), ) diff --git a/mlpf/heptfds/clic_pf/higgszz4l.py b/mlpf/heptfds/clic_pf/higgszz4l.py index c8130feee..d82ebe45a 100644 --- a/mlpf/heptfds/clic_pf/higgszz4l.py +++ b/mlpf/heptfds/clic_pf/higgszz4l.py @@ -34,7 +34,13 @@ def _info(self) -> tfds.core.DatasetInfo: description=_DESCRIPTION, features=tfds.features.FeaturesDict( { - "X": tfds.features.Tensor(shape=(None, max(len(X_FEATURES_TRK), len(X_FEATURES_CL))), dtype=tf.float32), + "X": tfds.features.Tensor( + shape=( + None, + max(len(X_FEATURES_TRK), len(X_FEATURES_CL)), + ), + dtype=tf.float32, + ), "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), } @@ -43,7 +49,9 @@ def _info(self) -> tfds.core.DatasetInfo: homepage="", citation=_CITATION, metadata=tfds.core.MetadataDict( - x_features_track=X_FEATURES_TRK, x_features_cluster=X_FEATURES_CL, y_features=Y_FEATURES + x_features_track=X_FEATURES_TRK, + x_features_cluster=X_FEATURES_CL, + y_features=Y_FEATURES, ), ) diff --git a/mlpf/heptfds/clic_pf/qcd.py b/mlpf/heptfds/clic_pf/qcd.py index 8978da3c2..78fd71c55 100644 --- a/mlpf/heptfds/clic_pf/qcd.py +++ b/mlpf/heptfds/clic_pf/qcd.py @@ -34,7 +34,13 @@ def _info(self) -> tfds.core.DatasetInfo: description=_DESCRIPTION, features=tfds.features.FeaturesDict( { - "X": tfds.features.Tensor(shape=(None, max(len(X_FEATURES_TRK), len(X_FEATURES_CL))), dtype=tf.float32), + "X": tfds.features.Tensor( + shape=( + None, + max(len(X_FEATURES_TRK), len(X_FEATURES_CL)), + ), + dtype=tf.float32, + ), "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), } @@ -43,7 +49,9 @@ def _info(self) -> tfds.core.DatasetInfo: homepage="", citation=_CITATION, metadata=tfds.core.MetadataDict( - x_features_track=X_FEATURES_TRK, x_features_cluster=X_FEATURES_CL, y_features=Y_FEATURES + x_features_track=X_FEATURES_TRK, + x_features_cluster=X_FEATURES_CL, + y_features=Y_FEATURES, ), ) diff --git a/mlpf/heptfds/clic_pf/ttbar.py b/mlpf/heptfds/clic_pf/ttbar.py index af69d2243..2e6182e49 100644 --- a/mlpf/heptfds/clic_pf/ttbar.py +++ b/mlpf/heptfds/clic_pf/ttbar.py @@ -34,7 +34,13 @@ def _info(self) -> tfds.core.DatasetInfo: description=_DESCRIPTION, features=tfds.features.FeaturesDict( { - "X": tfds.features.Tensor(shape=(None, max(len(X_FEATURES_TRK), len(X_FEATURES_CL))), dtype=tf.float32), + "X": tfds.features.Tensor( + shape=( + None, + max(len(X_FEATURES_TRK), len(X_FEATURES_CL)), + ), + dtype=tf.float32, + ), "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), } @@ -43,7 +49,9 @@ def _info(self) -> tfds.core.DatasetInfo: homepage="", citation=_CITATION, metadata=tfds.core.MetadataDict( - x_features_track=X_FEATURES_TRK, 
x_features_cluster=X_FEATURES_CL, y_features=Y_FEATURES + x_features_track=X_FEATURES_TRK, + x_features_cluster=X_FEATURES_CL, + y_features=Y_FEATURES, ), ) diff --git a/mlpf/heptfds/clic_pf/zpoleee.py b/mlpf/heptfds/clic_pf/zpoleee.py index 220b7cd5d..38be9eb3c 100644 --- a/mlpf/heptfds/clic_pf/zpoleee.py +++ b/mlpf/heptfds/clic_pf/zpoleee.py @@ -34,7 +34,13 @@ def _info(self) -> tfds.core.DatasetInfo: description=_DESCRIPTION, features=tfds.features.FeaturesDict( { - "X": tfds.features.Tensor(shape=(None, max(len(X_FEATURES_TRK), len(X_FEATURES_CL))), dtype=tf.float32), + "X": tfds.features.Tensor( + shape=( + None, + max(len(X_FEATURES_TRK), len(X_FEATURES_CL)), + ), + dtype=tf.float32, + ), "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), } @@ -43,7 +49,9 @@ def _info(self) -> tfds.core.DatasetInfo: homepage="", citation=_CITATION, metadata=tfds.core.MetadataDict( - x_features_track=X_FEATURES_TRK, x_features_cluster=X_FEATURES_CL, y_features=Y_FEATURES + x_features_track=X_FEATURES_TRK, + x_features_cluster=X_FEATURES_CL, + y_features=Y_FEATURES, ), ) diff --git a/mlpf/heptfds/cms_pf/cms_utils.py b/mlpf/heptfds/cms_pf/cms_utils.py index 953c54e4b..cb7229795 100644 --- a/mlpf/heptfds/cms_pf/cms_utils.py +++ b/mlpf/heptfds/cms_pf/cms_utils.py @@ -10,12 +10,42 @@ # https://github.com/ahlinist/cmssw/blob/1df62491f48ef964d198f574cdfcccfd17c70425/DataFormats/ParticleFlowReco/interface/PFBlockElement.h#L33 ELEM_LABELS_CMS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] -ELEM_NAMES_CMS = ["NONE", "TRACK", "PS1", "PS2", "ECAL", "HCAL", "GSF", "BREM", "HFEM", "HFHAD", "SC", "HO"] +ELEM_NAMES_CMS = [ + "NONE", + "TRACK", + "PS1", + "PS2", + "ECAL", + "HCAL", + "GSF", + "BREM", + "HFEM", + "HFHAD", + "SC", + "HO", +] # https://github.com/cms-sw/cmssw/blob/master/DataFormats/ParticleFlowCandidate/src/PFCandidate.cc#L254 CLASS_LABELS_CMS = [0, 211, 130, 1, 2, 22, 11, 13] -CLASS_NAMES_CMS = ["none", "ch.had", "n.had", "HFHAD", "HFEM", "gamma", "ele", "mu"] -CLASS_NAMES_LONG_CMS = ["none" "charged hadron", "neutral hadron", "hfem", "hfhad", "photon", "electron", "muon"] +CLASS_NAMES_CMS = [ + "none", + "ch.had", + "n.had", + "HFHAD", + "HFEM", + "gamma", + "ele", + "mu", +] +CLASS_NAMES_LONG_CMS = [ + "none" "charged hadron", + "neutral hadron", + "hfem", + "hfhad", + "photon", + "electron", + "muon", +] X_FEATURES = [ "typ_idx", @@ -61,7 +91,16 @@ "thetaerror", ] -Y_FEATURES = ["typ_idx", "charge", "pt", "eta", "sin_phi", "cos_phi", "e", "jet_idx"] +Y_FEATURES = [ + "typ_idx", + "charge", + "pt", + "eta", + "sin_phi", + "cos_phi", + "e", + "jet_idx", +] def prepare_data_cms(fn): @@ -91,18 +130,35 @@ def prepare_data_cms(fn): ycand = ycand[~msk_ps] Xelem = append_fields( - Xelem, "typ_idx", np.array([ELEM_LABELS_CMS.index(int(i)) for i in Xelem["typ"]], dtype=np.float32) + Xelem, + "typ_idx", + np.array( + [ELEM_LABELS_CMS.index(int(i)) for i in Xelem["typ"]], + dtype=np.float32, + ), ) ygen = append_fields( - ygen, "typ_idx", np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32) + ygen, + "typ_idx", + np.array( + [CLASS_LABELS_CMS.index(abs(int(i))) for i in ygen["typ"]], + dtype=np.float32, + ), ) ygen = append_fields(ygen, "jet_idx", np.zeros(ygen["typ"].shape, dtype=np.float32)) ycand = append_fields( ycand, "typ_idx", - np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32), + np.array( + 
[CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], + dtype=np.float32, + ), + ) + ycand = append_fields( + ycand, + "jet_idx", + np.zeros(ycand["typ"].shape, dtype=np.float32), ) - ycand = append_fields(ycand, "jet_idx", np.zeros(ycand["typ"].shape, dtype=np.float32)) Xelem_flat = np.stack( [Xelem[k].view(np.float32).data for k in X_FEATURES], @@ -140,7 +196,10 @@ def prepare_data_cms(fn): pt = ygen[valid, Y_FEATURES.index("pt")] eta = ygen[valid, Y_FEATURES.index("eta")] - phi = np.arctan2(ygen[valid, Y_FEATURES.index("sin_phi")], ygen[valid, Y_FEATURES.index("cos_phi")]) + phi = np.arctan2( + ygen[valid, Y_FEATURES.index("sin_phi")], + ygen[valid, Y_FEATURES.index("cos_phi")], + ) e = ygen[valid, Y_FEATURES.index("e")] vec = vector.awk(ak.zip({"pt": pt, "eta": eta, "phi": phi, "e": e})) @@ -179,7 +238,10 @@ def split_sample(path, test_frac=0.8): files_test = files[idx_split:] assert len(files_train) > 0 assert len(files_test) > 0 - return {"train": generate_examples(files_train), "test": generate_examples(files_test)} + return { + "train": generate_examples(files_train), + "test": generate_examples(files_test), + } def generate_examples(files): diff --git a/mlpf/heptfds/delphes_pf/delphes_pf.py b/mlpf/heptfds/delphes_pf/delphes_pf.py index 65a8aaa4d..28c651094 100644 --- a/mlpf/heptfds/delphes_pf/delphes_pf.py +++ b/mlpf/heptfds/delphes_pf/delphes_pf.py @@ -30,7 +30,16 @@ _CITATION = """ """ -DELPHES_CLASS_NAMES = ["none", "charged hadron", "neutral hadron", "hfem", "hfhad", "photon", "electron", "muon"] +DELPHES_CLASS_NAMES = [ + "none", + "charged hadron", + "neutral hadron", + "hfem", + "hfhad", + "photon", + "electron", + "muon", +] # based on delphes/ntuplizer.py X_FEATURES = [ @@ -48,7 +57,16 @@ "is_gen_electron", ] -Y_FEATURES = ["type", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"] +Y_FEATURES = [ + "type", + "charge", + "pt", + "eta", + "sin_phi", + "cos_phi", + "energy", + "jet_idx", +] class DelphesPf(tfds.core.GeneratorBasedBuilder): @@ -140,8 +158,20 @@ def prepare_data_delphes(self, fname): ycand = data["ycand"][i].astype(np.float32) # add jet_idx column - ygen = np.concatenate([ygen.astype(np.float32), np.zeros((len(ygen), 1), dtype=np.float32)], axis=-1) - ycand = np.concatenate([ycand.astype(np.float32), np.zeros((len(ycand), 1), dtype=np.float32)], axis=-1) + ygen = np.concatenate( + [ + ygen.astype(np.float32), + np.zeros((len(ygen), 1), dtype=np.float32), + ], + axis=-1, + ) + ycand = np.concatenate( + [ + ycand.astype(np.float32), + np.zeros((len(ycand), 1), dtype=np.float32), + ], + axis=-1, + ) # prepare gen candidates for clustering cls_id = ygen[..., 0] @@ -154,7 +184,10 @@ def prepare_data_delphes(self, fname): pt = ygen[valid, Y_FEATURES.index("pt")] eta = ygen[valid, Y_FEATURES.index("eta")] - phi = np.arctan2(ygen[valid, Y_FEATURES.index("sin_phi")], ygen[valid, Y_FEATURES.index("cos_phi")]) + phi = np.arctan2( + ygen[valid, Y_FEATURES.index("sin_phi")], + ygen[valid, Y_FEATURES.index("cos_phi")], + ) e = ygen[valid, Y_FEATURES.index("energy")] vec = vector.awk(ak.zip({"pt": pt, "eta": eta, "phi": phi, "e": e})) diff --git a/mlpf/lrp/lrp_mlpf.py b/mlpf/lrp/lrp_mlpf.py index 91183a46c..b2ba8a0d7 100644 --- a/mlpf/lrp/lrp_mlpf.py +++ b/mlpf/lrp/lrp_mlpf.py @@ -131,7 +131,15 @@ def explain_single_layer(self, R_tensor_old, layer_index, neuron_to_explain): # run lrp if "Linear" in str(layer): - R_tensor_new = self.eps_rule(self, layer, layer_name, input, R_tensor_old, neuron_to_explain, msg_passing_layer) + R_tensor_new = 
self.eps_rule( + self, + layer, + layer_name, + input, + R_tensor_old, + neuron_to_explain, + msg_passing_layer, + ) print("- Finished computing Rscores") return R_tensor_new else: @@ -147,7 +155,15 @@ def explain_single_layer(self, R_tensor_old, layer_index, neuron_to_explain): """ @staticmethod - def eps_rule(self, layer, layer_name, x, R_tensor_old, neuron_to_explain, msg_passing_layer): + def eps_rule( + self, + layer, + layer_name, + x, + R_tensor_old, + neuron_to_explain, + msg_passing_layer, + ): """ Implements the lrp-epsilon rule presented in the following reference: https://doi.org/10.1007/978-3-030-28954-6_10. diff --git a/mlpf/lrp/make_rmaps.py b/mlpf/lrp/make_rmaps.py index 0a4aab912..28d6e7be0 100644 --- a/mlpf/lrp/make_rmaps.py +++ b/mlpf/lrp/make_rmaps.py @@ -73,7 +73,15 @@ def process_Rtensor(node, Rtensor, neighbors): return Rtensor[: neighbors + 1] -def make_Rmaps(outpath, Rtensors, inputs, preds, pid="chhadron", neighbors=2, out_neuron=0): # noqa C901 +def make_Rmaps( + outpath, + Rtensors, + inputs, + preds, + pid="chhadron", + neighbors=2, + out_neuron=0, +): # noqa C901 """ Recall each event has a corresponding Rmap per node in the event. This function process the Rmaps for a given pid. @@ -122,7 +130,10 @@ def make_Rmaps(outpath, Rtensors, inputs, preds, pid="chhadron", neighbors=2, ou node_types = indexing_by_relevance(neighbors + 1, pid) # only plot 6 rows/neighbors in Rmap - for status, var in {"correct": Rtensor_correct, "incorrect": Rtensor_incorrect}.items(): + for status, var in { + "correct": Rtensor_correct, + "incorrect": Rtensor_incorrect, + }.items(): print(f"Making Rmaps for {status}ly classified {pid}") if status == "correct": num = num_Rtensors_correct @@ -154,10 +165,21 @@ def make_Rmaps(outpath, Rtensors, inputs, preds, pid="chhadron", neighbors=2, ou ax.set_yticklabels(node_types, fontsize=20) for col in range(len(features)): for row in range(len(node_types)): - ax.text(col, row, round(var[row, col].item(), 5), ha="center", va="center", color="w", fontsize=14) + ax.text( + col, + row, + round(var[row, col].item(), 5), + ha="center", + va="center", + color="w", + fontsize=14, + ) plt.imshow( - (var[: neighbors + 1] + 1e-12).numpy(), cmap="copper", aspect="auto", norm=matplotlib.colors.LogNorm(vmin=1e-3) + (var[: neighbors + 1] + 1e-12).numpy(), + cmap="copper", + aspect="auto", + norm=matplotlib.colors.LogNorm(vmin=1e-3), ) # create directory to hold Rmaps diff --git a/mlpf/lrp/model.py b/mlpf/lrp/model.py index d73ba5da6..e9eddac1e 100644 --- a/mlpf/lrp/model.py +++ b/mlpf/lrp/model.py @@ -57,7 +57,15 @@ def __init__( self.conv = nn.ModuleList() for i in range(num_convs): - self.conv.append(GravNetConv_LRP(embedding_dim, embedding_dim, space_dim, propagate_dim, k)) + self.conv.append( + GravNetConv_LRP( + embedding_dim, + embedding_dim, + space_dim, + propagate_dim, + k, + ) + ) # (3) DNN layer: classifiying pid self.nn2 = nn.Sequential( @@ -92,7 +100,11 @@ def forward(self, batch): A = {} msg_activations = {} for num, conv in enumerate(self.conv): - embedding, A[f"conv.{num}"], msg_activations[f"conv.{num}"] = conv(embedding) + ( + embedding, + A[f"conv.{num}"], + msg_activations[f"conv.{num}"], + ) = conv(embedding) # predict the pid's preds_id = self.nn2(torch.cat([x0, embedding], axis=-1)) @@ -145,7 +157,11 @@ def reset_parameters(self): self.lin_p.reset_parameters() self.lin_out.reset_parameters() - def forward(self, x: Union[Tensor, PairTensor], batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor: + def forward( + self, + x: 
Union[Tensor, PairTensor], + batch: Union[OptTensor, Optional[PairTensor]] = None, + ) -> Tensor: """""" is_bipartite: bool = True @@ -180,7 +196,12 @@ def forward(self, x: Union[Tensor, PairTensor], batch: Union[OptTensor, Optional A = to_dense_adj(edge_index.to("cpu"), edge_attr=edge_weight.to("cpu"))[0] # adjacency matrix # message passing - out = self.propagate(edge_index, x=(msg_activations, None), edge_weight=edge_weight, size=(s_l.size(0), s_r.size(0))) + out = self.propagate( + edge_index, + x=(msg_activations, None), + edge_weight=edge_weight, + size=(s_l.size(0), s_r.size(0)), + ) return self.lin_out(out), A, msg_activations @@ -188,7 +209,13 @@ def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return x_j * edge_weight.unsqueeze(1) def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor: - out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum") + out_mean = scatter( + inputs, + index, + dim=self.node_dim, + dim_size=dim_size, + reduce="sum", + ) return out_mean def __repr__(self) -> str: diff --git a/mlpf/lrp_mlpf_pipeline.py b/mlpf/lrp_mlpf_pipeline.py index 792602725..e6eba7811 100644 --- a/mlpf/lrp_mlpf_pipeline.py +++ b/mlpf/lrp_mlpf_pipeline.py @@ -10,17 +10,45 @@ parser = argparse.ArgumentParser() # for saving the model -parser.add_argument("--dataset_qcd", type=str, default="../data/delphes/pythia8_qcd", help="testing dataset path") -parser.add_argument("--outpath", type=str, default="../experiments/", help="path to the trained model directory") +parser.add_argument( + "--dataset_qcd", + type=str, + default="../data/delphes/pythia8_qcd", + help="testing dataset path", +) +parser.add_argument( + "--outpath", + type=str, + default="../experiments/", + help="path to the trained model directory", +) parser.add_argument("--load_model", type=str, default="", help="Which model to load") -parser.add_argument("--load_epoch", type=int, default=0, help="Which epoch of the model to load") -parser.add_argument("--out_neuron", type=int, default=0, help="the output neuron you wish to explain") +parser.add_argument( + "--load_epoch", + type=int, + default=0, + help="Which epoch of the model to load", +) +parser.add_argument( + "--out_neuron", + type=int, + default=0, + help="the output neuron you wish to explain", +) parser.add_argument("--pid", type=str, default="chhadron", help="Which model to load") parser.add_argument( - "--n_test", type=int, default=50, help="number of data files to use for testing.. each file contains 100 events" + "--n_test", + type=int, + default=50, + help="number of data files to use for testing.. 
each file contains 100 events", ) parser.add_argument("--run_lrp", dest="run_lrp", action="store_true", help="runs lrp") -parser.add_argument("--make_rmaps", dest="make_rmaps", action="store_true", help="makes rmaps") +parser.add_argument( + "--make_rmaps", + dest="make_rmaps", + action="store_true", + help="makes rmaps", +) args = parser.parse_args() @@ -40,7 +68,12 @@ # get sample dataset print("Fetching the data..") full_dataset_qcd = PFGraphDataset(args.dataset_qcd) - loader = dataloader_qcd(full_dataset_qcd, multi_gpu=False, n_test=args.n_test, batch_size=1) + loader = dataloader_qcd( + full_dataset_qcd, + multi_gpu=False, + n_test=args.n_test, + batch_size=1, + ) # load a pretrained model and update the outpath outpath = args.outpath + args.load_model @@ -86,5 +119,11 @@ print("Making Rmaps..") make_Rmaps( - args.outpath, Rtensors_list, inputs_list, preds_list, pid=args.pid, neighbors=3, out_neuron=args.out_neuron + args.outpath, + Rtensors_list, + inputs_list, + preds_list, + pid=args.pid, + neighbors=3, + out_neuron=args.out_neuron, ) diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py index 759d4e037..e94a788bc 100644 --- a/mlpf/pipeline.py +++ b/mlpf/pipeline.py @@ -25,7 +25,10 @@ import tqdm from customizations import customization_functions from tfmodel import hypertuning -from tfmodel.datasets.BaseDatasetFactory import mlpf_dataset_from_config, unpack_target +from tfmodel.datasets.BaseDatasetFactory import ( + mlpf_dataset_from_config, + unpack_target, +) from tfmodel.lr_finder import LRFinder from tfmodel.model_setup import eval_model, freeze_model, prepare_callbacks from tfmodel.utils import ( @@ -60,18 +63,71 @@ def main(): @main.command() @click.help_option("-h", "--help") @click.option("-c", "--config", help="configuration file", type=click.Path()) -@click.option("-w", "--weights", default=None, help="trained weights to load", type=click.Path()) -@click.option("--ntrain", default=None, help="override the number of training steps", type=int) -@click.option("--ntest", default=None, help="override the number of testing steps", type=int) -@click.option("--nepochs", default=None, help="override the number of training epochs", type=int) -@click.option("-r", "--recreate", help="force creation of new experiment dir", is_flag=True) -@click.option("-p", "--prefix", default="", help="prefix to put at beginning of training dir name", type=str) -@click.option("--plot-freq", default=None, help="plot detailed validation every N epochs", type=int) +@click.option( + "-w", + "--weights", + default=None, + help="trained weights to load", + type=click.Path(), +) +@click.option( + "--ntrain", + default=None, + help="override the number of training steps", + type=int, +) +@click.option( + "--ntest", + default=None, + help="override the number of testing steps", + type=int, +) +@click.option( + "--nepochs", + default=None, + help="override the number of training epochs", + type=int, +) +@click.option( + "-r", + "--recreate", + help="force creation of new experiment dir", + is_flag=True, +) +@click.option( + "-p", + "--prefix", + default="", + help="prefix to put at beginning of training dir name", + type=str, +) +@click.option( + "--plot-freq", + default=None, + help="plot detailed validation every N epochs", + type=int, +) @click.option("--customize", help="customization function", type=str, default=None) @click.option("--comet-offline", help="log comet-ml experiment locally", is_flag=True) -@click.option("-j", "--jobid", help="log the Slurm job ID in experiments dir", type=str, default=None) 
-@click.option("-m", "--horovod-enabled", help="Enable multi-node training using Horovod", is_flag=True) -@click.option("-g", "--habana-enabled", help="Enable training on Habana Gaudi", is_flag=True) +@click.option( + "-j", + "--jobid", + help="log the Slurm job ID in experiments dir", + type=str, + default=None, +) +@click.option( + "-m", + "--horovod-enabled", + help="Enable multi-node training using Horovod", + is_flag=True, +) +@click.option( + "-g", + "--habana-enabled", + help="Enable training on Habana Gaudi", + is_flag=True, +) @click.option( "-b", "--benchmark_dir", @@ -80,7 +136,12 @@ def main(): type=str, default=None, ) -@click.option("--batch-multiplier", help="batch size per device", type=int, default=None) +@click.option( + "--batch-multiplier", + help="batch size per device", + type=int, + default=None, +) @click.option("--num-cpus", help="number of CPU threads to use", type=int, default=1) @click.option("--seeds", help="set the random seeds", is_flag=True, default=True) def train( @@ -122,7 +183,8 @@ def train( if config["batching"]["bucket_by_sequence_length"]: logging.info( "Dynamic batching is enabled, changing batch size multiplier from {} to {}".format( - config["batching"]["batch_multiplier"], config["batching"]["batch_multiplier"] * batch_multiplier + config["batching"]["batch_multiplier"], + config["batching"]["batch_multiplier"] * batch_multiplier, ) ) config["batching"]["batch_multiplier"] *= batch_multiplier @@ -259,9 +321,19 @@ def train( @main.command() @click.help_option("-h", "--help") -@click.option("--train-dir", required=True, help="directory containing a completed training", type=click.Path()) +@click.option( + "--train-dir", + required=True, + help="directory containing a completed training", + type=click.Path(), +) @click.option("--config", help="configuration file", type=click.Path()) -@click.option("--weights", default=None, help="trained weights to load", type=click.Path()) +@click.option( + "--weights", + default=None, + help="trained weights to load", + type=click.Path(), +) @click.option("--customize", help="customization function", type=str, default=None) @click.option("--nevents", help="maximum number of events", type=int, default=-1) def evaluate(config, train_dir, weights, customize, nevents): @@ -286,7 +358,12 @@ def evaluate(config, train_dir, weights, customize, nevents): for dsname in config["evaluation_datasets"]: val_ds = config["evaluation_datasets"][dsname] - ds_test = mlpf_dataset_from_config(dsname, config, "test", nevents if nevents >= 0 else val_ds["num_events"]) + ds_test = mlpf_dataset_from_config( + dsname, + config, + "test", + nevents if nevents >= 0 else val_ds["num_events"], + ) ds_test_tfds = ds_test.tensorflow_dataset.padded_batch(val_ds["batch_size"]) eval_dir = str(Path(train_dir) / "evaluation" / "epoch_{}".format(initial_epoch) / dsname) Path(eval_dir).mkdir(parents=True, exist_ok=True) @@ -298,9 +375,27 @@ def evaluate(config, train_dir, weights, customize, nevents): @main.command() @click.help_option("-h", "--help") @click.option("-c", "--config", help="configuration file", type=click.Path()) -@click.option("-o", "--outdir", help="output directory", type=click.Path(), default=".") -@click.option("-n", "--figname", help="name of saved figure", type=click.Path(), default="lr_finder.jpg") -@click.option("-l", "--logscale", help="use log scale on y-axis in figure", default=False, is_flag=True) +@click.option( + "-o", + "--outdir", + help="output directory", + type=click.Path(), + default=".", +) +@click.option( + "-n", 
+ "--figname", + help="name of saved figure", + type=click.Path(), + default="lr_finder.jpg", +) +@click.option( + "-l", + "--logscale", + help="use log scale on y-axis in figure", + default=False, + is_flag=True, +) def find_lr(config, outdir, figname, logscale): """Run the Learning Rate Finder to produce a batch loss vs. LR plot from which an appropriate LR-range can be determined""" @@ -309,7 +404,12 @@ def find_lr(config, outdir, figname, logscale): # Decide tf.distribute.strategy depending on number of available GPUs strategy, num_gpus, num_batches_multiplier = get_strategy() - ds_train, _, _ = get_datasets(config["train_test_datasets"], config, num_batches_multiplier, "train") + ds_train, _, _ = get_datasets( + config["train_test_datasets"], + config, + num_batches_multiplier, + "train", + ) with strategy.scope(): model, _, _ = model_scope(config, 1) @@ -330,7 +430,13 @@ def find_lr(config, outdir, figname, logscale): @main.command() @click.help_option("-h", "--help") @click.option("-t", "--train_dir", help="training directory", type=click.Path()) -@click.option("-d", "--dry_run", help="do not delete anything", is_flag=True, default=False) +@click.option( + "-d", + "--dry_run", + help="do not delete anything", + is_flag=True, + default=False, +) def delete_all_but_best_ckpt(train_dir, dry_run): """Delete all checkpoint weights in /weights/ except the one with lowest loss in its filename.""" delete_all_but_best_checkpoint(train_dir, dry_run) @@ -338,11 +444,33 @@ def delete_all_but_best_ckpt(train_dir, dry_run): @main.command() @click.help_option("-h", "--help") -@click.option("-c", "--config", help="configuration file", type=click.Path(), required=True) +@click.option( + "-c", + "--config", + help="configuration file", + type=click.Path(), + required=True, +) @click.option("-o", "--outdir", help="output dir", type=click.Path(), required=True) -@click.option("--ntrain", default=None, help="override the number of training events", type=int) -@click.option("--ntest", default=None, help="override the number of testing events", type=int) -@click.option("-r", "--recreate", help="overwrite old hypertune results", is_flag=True, default=False) +@click.option( + "--ntrain", + default=None, + help="override the number of training events", + type=int, +) +@click.option( + "--ntest", + default=None, + help="override the number of testing events", + type=int, +) +@click.option( + "-r", + "--recreate", + help="overwrite old hypertune results", + is_flag=True, + default=False, +) @click.option("--num-cpus", help="number of CPU threads to use", type=int, default=1) def hypertune(config, outdir, ntrain, ntest, recreate, num_cpus): config_file_path = config @@ -393,7 +521,13 @@ def hypertune(config, outdir, ntrain, ntest, recreate, num_cpus): def raytune_build_model_and_train( - config, checkpoint_dir=None, full_config=None, ntrain=None, ntest=None, name=None, seeds=False + config, + checkpoint_dir=None, + full_config=None, + ntrain=None, + ntest=None, + name=None, + seeds=False, ): from collections import Counter @@ -512,10 +646,31 @@ def raytune_build_model_and_train( @click.option("--gpus", help="number of gpus per worker", type=int, default=0) @click.option("--tune_result_dir", help="Tune result dir", type=str, default=None) @click.option("-r", "--resume", help="resume run from local_dir", is_flag=True) -@click.option("--ntrain", default=None, help="override the number of training steps", type=int) -@click.option("--ntest", default=None, help="override the number of testing steps", type=int) 
+@click.option( + "--ntrain", + default=None, + help="override the number of training steps", + type=int, +) +@click.option( + "--ntest", + default=None, + help="override the number of testing steps", + type=int, +) @click.option("-s", "--seeds", help="set the random seeds", is_flag=True) -def raytune(config, name, local, cpus, gpus, tune_result_dir, resume, ntrain, ntest, seeds): +def raytune( + config, + name, + local, + cpus, + gpus, + tune_result_dir, + resume, + ntrain, + ntest, + seeds, +): import ray from ray import tune from ray.tune.logger import TBXLoggerCallback @@ -542,10 +697,12 @@ def raytune(config, name, local, cpus, gpus, tune_result_dir, resume, ntrain, nt expdir = Path(cfg["raytune"]["local_dir"]) / name expdir.mkdir(parents=True, exist_ok=True) shutil.copy( - "mlpf/raytune/search_space.py", str(Path(cfg["raytune"]["local_dir"]) / name / "search_space.py") + "mlpf/raytune/search_space.py", + str(Path(cfg["raytune"]["local_dir"]) / name / "search_space.py"), ) # Copy the config file to the train dir for later reference shutil.copy( - config_file_path, str(Path(cfg["raytune"]["local_dir"]) / name / "config.yaml") + config_file_path, + str(Path(cfg["raytune"]["local_dir"]) / name / "config.yaml"), ) # Copy the config file to the train dir for later reference ray.tune.ray_trial_executor.DEFAULT_GET_TIMEOUT = 1 * 60 * 60 # Avoid timeout errors @@ -560,7 +717,12 @@ def raytune(config, name, local, cpus, gpus, tune_result_dir, resume, ntrain, nt start = datetime.now() analysis = tune.run( partial( - raytune_build_model_and_train, full_config=config_file_path, ntrain=ntrain, ntest=ntest, name=name, seeds=seeds + raytune_build_model_and_train, + full_config=config_file_path, + ntrain=ntrain, + ntest=ntest, + name=name, + seeds=seeds, ), config=search_space, resources_per_trial={"cpu": cpus, "gpu": gpus}, @@ -581,7 +743,10 @@ def raytune(config, name, local, cpus, gpus, tune_result_dir, resume, ntrain, nt logging.info( "Best hyperparameters found according to {} were: ".format(cfg["raytune"]["default_metric"]), - analysis.get_best_config(cfg["raytune"]["default_metric"], cfg["raytune"]["default_mode"]), + analysis.get_best_config( + cfg["raytune"]["default_metric"], + cfg["raytune"]["default_mode"], + ), ) skip = 20 @@ -594,7 +759,10 @@ def raytune(config, name, local, cpus, gpus, tune_result_dir, resume, ntrain, nt summarize_top_k(analysis, k=5, save_dir=Path(analysis.get_best_logdir()).parent) best_params = analysis.get_best_config(cfg["raytune"]["default_metric"], cfg["raytune"]["default_mode"]) - with open(Path(analysis.get_best_logdir()).parent / "best_parameters.txt", "a") as best_params_file: + with open( + Path(analysis.get_best_logdir()).parent / "best_parameters.txt", + "a", + ) as best_params_file: best_params_file.write("Best hyperparameters according to {}\n".format(cfg["raytune"]["default_metric"])) for key, val in best_params.items(): best_params_file.write(("{}: {}\n".format(key, val))) @@ -618,7 +786,12 @@ def count_skipped(exp_dir): @click.help_option("-h", "--help") @click.option("-d", "--exp_dir", help="experiment dir", type=click.Path()) @click.option("-s", "--save", help="save plots in trial dirs", is_flag=True) -@click.option("-k", "--skip", help="skip first values to avoid large losses at start of training", type=int) +@click.option( + "-k", + "--skip", + help="skip first values to avoid large losses at start of training", + type=int, +) @click.option("--metric", help="experiment dir", type=str, default="val_loss") @click.option("--mode", help="experiment 
dir", type=str, default="min") def raytune_analysis(exp_dir, save, skip, mode, metric): @@ -652,7 +825,11 @@ def test_datasets(config): continue confusion_matrix_Xelem_to_ygen = np.zeros( - (config["dataset"]["num_input_classes"], config["dataset"]["num_output_classes"]), dtype=np.int64 + ( + config["dataset"]["num_input_classes"], + config["dataset"]["num_output_classes"], + ), + dtype=np.int64, ) histograms[dataset] = {} @@ -670,14 +847,16 @@ def test_datasets(config): histograms[dataset]["cand_pt_log"] = bh.Histogram(bh.axis.Regular(100, -1, 5)) histograms[dataset]["sum_gen_cand_energy"] = bh.Histogram( - bh.axis.Regular(100, 0, 100000), bh.axis.Regular(100, 0, 100000) + bh.axis.Regular(100, 0, 100000), + bh.axis.Regular(100, 0, 100000), ) histograms[dataset]["sum_gen_cand_energy_log"] = bh.Histogram( bh.axis.Regular(100, 2, 6), bh.axis.Regular(100, 2, 6) ) histograms[dataset]["sum_gen_cand_pt"] = bh.Histogram( - bh.axis.Regular(100, 0, 100000), bh.axis.Regular(100, 0, 100000) + bh.axis.Regular(100, 0, 100000), + bh.axis.Regular(100, 0, 100000), ) histograms[dataset]["sum_gen_cand_pt_log"] = bh.Histogram(bh.axis.Regular(100, 2, 6), bh.axis.Regular(100, 2, 6)) @@ -698,12 +877,25 @@ def test_datasets(config): # assert ycand.shape[1] == config["dataset"]["num_output_features"] + 1 histograms[dataset]["confusion_matrix_Xelem_to_ygen"] += coo_matrix( - (np.ones(len(X), dtype=np.int64), (np.array(X[:, 0], np.int32), np.array(ygen[:, 0], np.int32))), - shape=(config["dataset"]["num_input_classes"], config["dataset"]["num_output_classes"]), + ( + np.ones(len(X), dtype=np.int64), + ( + np.array(X[:, 0], np.int32), + np.array(ygen[:, 0], np.int32), + ), + ), + shape=( + config["dataset"]["num_input_classes"], + config["dataset"]["num_output_classes"], + ), ).todense() vals_ygen = ygen[ygen[:, 0] != 0] - vals_ygen = unpack_target(vals_ygen, config["dataset"]["num_output_classes"], config) + vals_ygen = unpack_target( + vals_ygen, + config["dataset"]["num_output_classes"], + config, + ) # assert np.all(vals_ygen["energy"] > 0) # assert np.all(vals_ygen["pt"] > 0) # assert not np.any(np.isinf(ygen)) @@ -713,10 +905,17 @@ def test_datasets(config): histograms[dataset]["gen_energy_log"].fill(np.log10(vals_ygen["energy"][:, 0])) histograms[dataset]["gen_pt"].fill(vals_ygen["pt"][:, 0]) histograms[dataset]["gen_pt_log"].fill(np.log10(vals_ygen["pt"][:, 0])) - histograms[dataset]["gen_eta_energy"].fill(vals_ygen["eta"][:, 0], weight=vals_ygen["energy"][:, 0]) + histograms[dataset]["gen_eta_energy"].fill( + vals_ygen["eta"][:, 0], + weight=vals_ygen["energy"][:, 0], + ) vals_ycand = ycand[ycand[:, 0] != 0] - vals_ycand = unpack_target(vals_ycand, config["dataset"]["num_output_classes"], config) + vals_ycand = unpack_target( + vals_ycand, + config["dataset"]["num_output_classes"], + config, + ) # assert(np.all(vals_ycand["energy"]>0)) # assert(np.all(vals_ycand["pt"]>0)) # assert not np.any(np.isinf(ycand)) @@ -726,15 +925,23 @@ def test_datasets(config): histograms[dataset]["cand_energy_log"].fill(np.log10(vals_ycand["energy"][:, 0])) histograms[dataset]["cand_pt"].fill(vals_ycand["pt"][:, 0]) histograms[dataset]["cand_pt_log"].fill(np.log10(vals_ycand["pt"][:, 0])) - histograms[dataset]["cand_eta_energy"].fill(vals_ycand["eta"][:, 0], weight=vals_ycand["energy"][:, 0]) + histograms[dataset]["cand_eta_energy"].fill( + vals_ycand["eta"][:, 0], + weight=vals_ycand["energy"][:, 0], + ) - histograms[dataset]["sum_gen_cand_energy"].fill(np.sum(vals_ygen["energy"]), np.sum(vals_ycand["energy"])) + 
histograms[dataset]["sum_gen_cand_energy"].fill( + np.sum(vals_ygen["energy"]), + np.sum(vals_ycand["energy"]), + ) histograms[dataset]["sum_gen_cand_energy_log"].fill( - np.log10(np.sum(vals_ygen["energy"])), np.log10(np.sum(vals_ycand["energy"])) + np.log10(np.sum(vals_ygen["energy"])), + np.log10(np.sum(vals_ycand["energy"])), ) histograms[dataset]["sum_gen_cand_pt"].fill(np.sum(vals_ygen["pt"]), np.sum(vals_ycand["pt"])) histograms[dataset]["sum_gen_cand_pt_log"].fill( - np.log10(np.sum(vals_ygen["pt"])), np.log10(np.sum(vals_ycand["pt"])) + np.log10(np.sum(vals_ygen["pt"])), + np.log10(np.sum(vals_ycand["pt"])), ) print(confusion_matrix_Xelem_to_ygen) @@ -748,8 +955,19 @@ def test_datasets(config): @main.command() @click.help_option("-h", "--help") -@click.option("--train-dir", required=True, help="directory containing a completed training", type=click.Path()) -@click.option("--max-files", required=False, help="maximum number of files per dataset to load", type=int, default=None) +@click.option( + "--train-dir", + required=True, + help="directory containing a completed training", + type=click.Path(), +) +@click.option( + "--max-files", + required=False, + help="maximum number of files per dataset to load", + type=int, + default=None, +) def plots(train_dir, max_files): import mplhep from plotting.plot_utils import ( diff --git a/mlpf/plotting/draw_graphs.py b/mlpf/plotting/draw_graphs.py index 9be64e780..ec5014a2c 100644 --- a/mlpf/plotting/draw_graphs.py +++ b/mlpf/plotting/draw_graphs.py @@ -41,7 +41,12 @@ def main(args): hidden_dim = 32 edge_dim = 1 n_iters = 1 - model = EdgeNet(input_dim=input_dim, hidden_dim=hidden_dim, edge_dim=edge_dim, n_iters=n_iters).to(device) + model = EdgeNet( + input_dim=input_dim, + hidden_dim=hidden_dim, + edge_dim=edge_dim, + n_iters=n_iters, + ).to(device) modpath = "data/EdgeNet_14001_ca9bbfb3bb_jduarte.best.pth" model.load_state_dict(torch.load(modpath)) data = data.to(device) @@ -79,7 +84,11 @@ def main(args): seg_args = dict(c="r", alpha=0.5, zorder=2) plt.plot([df[x][i], df[x][j]], [df[y][i], df[y][j]], "-", **seg_args) if "output" in plot_type: - seg_args = dict(c="g", alpha=output[k].item() * (output[k].item() > 0.9), zorder=3) + seg_args = dict( + c="g", + alpha=output[k].item() * (output[k].item() > 0.9), + zorder=3, + ) plt.plot([df[x][i], df[x][j]], [df[y][i], df[y][j]], "-", **seg_args) k += 1 @@ -88,15 +97,54 @@ def main(args): ) cluster_mask = cut_mask & ~df["isTrack"] track_mask = cut_mask & df["isTrack"] - plt.scatter(df[x][cluster_mask], df[y][cluster_mask], c="g", marker="o", s=50, zorder=4, alpha=1) - plt.scatter(df[x][track_mask], df[y][track_mask], c="b", marker="p", s=50, zorder=5, alpha=1) - plt.xlabel("Track or Cluster $\eta$", fontsize=14) - plt.ylabel("Track or Cluster $\phi$", fontsize=14) + plt.scatter( + df[x][cluster_mask], + df[y][cluster_mask], + c="g", + marker="o", + s=50, + zorder=4, + alpha=1, + ) + plt.scatter( + df[x][track_mask], + df[y][track_mask], + c="b", + marker="p", + s=50, + zorder=5, + alpha=1, + ) + plt.xlabel(r"Track or Cluster $\eta$", fontsize=14) + plt.ylabel(r"Track or Cluster $\phi$", fontsize=14) plt.xlim(min_eta, max_eta) plt.ylim(min_phi, max_phi) - plt.figtext(0.12, 0.90, "CMS", fontweight="bold", wrap=True, horizontalalignment="left", fontsize=16) - plt.figtext(0.22, 0.90, "Simulation Preliminary", style="italic", wrap=True, horizontalalignment="left", fontsize=14) - plt.figtext(0.67, 0.90, "Run 3 (14 TeV)", wrap=True, horizontalalignment="left", fontsize=14) + plt.figtext( + 0.12, + 0.90, + 
"CMS", + fontweight="bold", + wrap=True, + horizontalalignment="left", + fontsize=16, + ) + plt.figtext( + 0.22, + 0.90, + "Simulation Preliminary", + style="italic", + wrap=True, + horizontalalignment="left", + fontsize=14, + ) + plt.figtext( + 0.67, + 0.90, + "Run 3 (14 TeV)", + wrap=True, + horizontalalignment="left", + fontsize=14, + ) plt.savefig("graph_%s_%s_%s.pdf" % (x, y, "_".join(plot_type))) diff --git a/mlpf/plotting/plot_utils.py b/mlpf/plotting/plot_utils.py index 744225ebc..2b5853a88 100644 --- a/mlpf/plotting/plot_utils.py +++ b/mlpf/plotting/plot_utils.py @@ -36,10 +36,32 @@ } ELEM_LABELS_CMS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] -ELEM_NAMES_CMS = ["NONE", "TRACK", "PS1", "PS2", "ECAL", "HCAL", "GSF", "BREM", "HFEM", "HFHAD", "SC", "HO"] +ELEM_NAMES_CMS = [ + "NONE", + "TRACK", + "PS1", + "PS2", + "ECAL", + "HCAL", + "GSF", + "BREM", + "HFEM", + "HFHAD", + "SC", + "HO", +] CLASS_LABELS_CMS = [0, 211, 130, 1, 2, 22, 11, 13] -CLASS_NAMES_CMS = ["none", "ch.had", "n.had", "HFHAD", "HFEM", "$\gamma$", "$e^\pm$", "$\mu^\pm$"] +CLASS_NAMES_CMS = [ + r"none", + r"ch.had", + r"n.had", + r"HFHAD", + r"HFEM", + r"$\gamma$", + r"$e^\pm$", + r"$\mu^\pm$", +] EVALUATION_DATASET_NAMES = { "clic_ttbar_pf": r"CLIC $ee \rightarrow \mathrm{t}\overline{\mathrm{t}}$", @@ -77,11 +99,32 @@ def get_fake(df, pid): def cms_label(ax, x0=0.01, x1=0.15, x2=0.98, y=0.94): - plt.figtext(x0, y, "CMS", fontweight="bold", wrap=True, horizontalalignment="left", transform=ax.transAxes) plt.figtext( - x1, y, "Simulation Preliminary", style="italic", wrap=True, horizontalalignment="left", transform=ax.transAxes + x0, + y, + "CMS", + fontweight="bold", + wrap=True, + horizontalalignment="left", + transform=ax.transAxes, + ) + plt.figtext( + x1, + y, + "Simulation Preliminary", + style="italic", + wrap=True, + horizontalalignment="left", + transform=ax.transAxes, + ) + plt.figtext( + x2, + y, + "Run 3 (14 TeV)", + wrap=False, + horizontalalignment="right", + transform=ax.transAxes, ) - plt.figtext(x2, y, "Run 3 (14 TeV)", wrap=False, horizontalalignment="right", transform=ax.transAxes) def sample_label(ax, sample, additional_text="", x=0.01, y=0.87): @@ -90,7 +133,15 @@ def sample_label(ax, sample, additional_text="", x=0.01, y=0.87): def particle_label(ax, pid): - plt.text(0.03, 0.92, pid_to_text[pid], va="top", ha="left", size=10, transform=ax.transAxes) + plt.text( + 0.03, + 0.92, + pid_to_text[pid], + va="top", + ha="left", + size=10, + transform=ax.transAxes, + ) def load_eval_data(path, max_files=None): @@ -143,16 +194,28 @@ def compute_jet_ratio(data, yvals): ret = {} # flatten across event dimension ret["jet_gen_to_pred_genpt"] = awkward.to_numpy( - awkward.flatten(vector.awk(data["jets"]["gen"][data["matched_jets"]["gen_to_pred"]["gen"]]).pt, axis=1) + awkward.flatten( + vector.awk(data["jets"]["gen"][data["matched_jets"]["gen_to_pred"]["gen"]]).pt, + axis=1, + ) ) ret["jet_gen_to_pred_predpt"] = awkward.to_numpy( - awkward.flatten(vector.awk(data["jets"]["pred"][data["matched_jets"]["gen_to_pred"]["pred"]]).pt, axis=1) + awkward.flatten( + vector.awk(data["jets"]["pred"][data["matched_jets"]["gen_to_pred"]["pred"]]).pt, + axis=1, + ) ) ret["jet_gen_to_cand_genpt"] = awkward.to_numpy( - awkward.flatten(vector.awk(data["jets"]["gen"][data["matched_jets"]["gen_to_cand"]["gen"]]).pt, axis=1) + awkward.flatten( + vector.awk(data["jets"]["gen"][data["matched_jets"]["gen_to_cand"]["gen"]]).pt, + axis=1, + ) ) ret["jet_gen_to_cand_candpt"] = awkward.to_numpy( - 
awkward.flatten(vector.awk(data["jets"]["cand"][data["matched_jets"]["gen_to_cand"]["cand"]]).pt, axis=1) + awkward.flatten( + vector.awk(data["jets"]["cand"][data["matched_jets"]["gen_to_cand"]["cand"]]).pt, + axis=1, + ) ) ret["jet_ratio_pred"] = ret["jet_gen_to_pred_predpt"] / ret["jet_gen_to_pred_genpt"] @@ -205,17 +268,35 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None) pt = awkward.to_numpy(awkward.flatten(yvals["jets_cand_pt"])) p = med_iqr(pt) n = len(pt) - plt.hist(pt, bins=b, histtype="step", lw=2, label="PF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n)) + plt.hist( + pt, + bins=b, + histtype="step", + lw=2, + label="PF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n), + ) pt = awkward.to_numpy(awkward.flatten(yvals["jets_pred_pt"])) p = med_iqr(pt) n = len(pt) - plt.hist(pt, bins=b, histtype="step", lw=2, label="MLPF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n)) + plt.hist( + pt, + bins=b, + histtype="step", + lw=2, + label="MLPF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n), + ) pt = awkward.to_numpy(awkward.flatten(yvals["jets_gen_pt"])) p = med_iqr(pt) n = len(pt) - plt.hist(pt, bins=b, histtype="step", lw=2, label="Gen $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n)) + plt.hist( + pt, + bins=b, + histtype="step", + lw=2, + label="Gen $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n), + ) plt.xscale("log") plt.xlabel("jet $p_T$") @@ -223,7 +304,12 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None) plt.legend(loc="best") if title: plt.title(title) - save_img("jet_pt.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "jet_pt.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) def plot_jet_ratio(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None): @@ -253,25 +339,60 @@ def plot_jet_ratio(yvals, epoch=None, cp_dir=None, comet_experiment=None, title= plt.legend(loc="best") if title: plt.title(title) - save_img("jet_res.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "jet_res.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) def plot_met_and_ratio(met_ratio, epoch=None, cp_dir=None, comet_experiment=None, title=None): # MET plt.figure() - maxval = max([np.max(met_ratio["gen_met"]), np.max(met_ratio["cand_met"]), np.max(met_ratio["pred_met"])]) - minval = min([np.min(met_ratio["gen_met"]), np.min(met_ratio["cand_met"]), np.min(met_ratio["pred_met"])]) + maxval = max( + [ + np.max(met_ratio["gen_met"]), + np.max(met_ratio["cand_met"]), + np.max(met_ratio["pred_met"]), + ] + ) + minval = min( + [ + np.min(met_ratio["gen_met"]), + np.min(met_ratio["cand_met"]), + np.min(met_ratio["pred_met"]), + ] + ) maxval = math.ceil(np.log10(maxval)) minval = math.floor(np.log10(max(minval, 1e-2))) b = np.logspace(minval, maxval, 100) p = med_iqr(met_ratio["cand_met"]) - plt.hist(met_ratio["cand_met"], bins=b, histtype="step", lw=2, label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + met_ratio["cand_met"], + bins=b, + histtype="step", + lw=2, + label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = med_iqr(met_ratio["pred_met"]) - plt.hist(met_ratio["pred_met"], bins=b, histtype="step", lw=2, label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + met_ratio["pred_met"], + bins=b, + histtype="step", + lw=2, + label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = med_iqr(met_ratio["gen_met"]) - plt.hist(met_ratio["gen_met"], 
bins=b, histtype="step", lw=2, label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + met_ratio["gen_met"], + bins=b, + histtype="step", + lw=2, + label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) plt.xlabel("MET [GeV]") plt.ylabel("Number of events / bin") plt.legend(loc="best") @@ -285,17 +406,32 @@ def plot_met_and_ratio(met_ratio, epoch=None, cp_dir=None, comet_experiment=None b = np.linspace(0, 20, 100) p = med_iqr(met_ratio["ratio_cand"]) - plt.hist(met_ratio["ratio_cand"], bins=b, histtype="step", lw=2, label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + met_ratio["ratio_cand"], + bins=b, + histtype="step", + lw=2, + label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = med_iqr(met_ratio["ratio_pred"]) plt.hist( - met_ratio["ratio_pred"], bins=b, histtype="step", lw=2, label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]) + met_ratio["ratio_pred"], + bins=b, + histtype="step", + lw=2, + label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), ) plt.xlabel("MET reco/gen") plt.ylabel("number of events") plt.legend(loc="best") if title: plt.title(title) - save_img("met_res.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "met_res.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) def compute_distances(distribution_1, distribution_2, ratio): @@ -325,7 +461,12 @@ def plot_num_elements(X, epoch=None, cp_dir=None, comet_experiment=None, title=N plt.ylabel("Number of events / bin") if title: plt.title(title) - save_img("num_elements.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "num_elements.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) def plot_sum_energy(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None): @@ -334,8 +475,20 @@ def plot_sum_energy(yvals, epoch=None, cp_dir=None, comet_experiment=None, title sum_cand_energy = awkward.to_numpy(awkward.sum(yvals["cand_energy"], axis=1)) sum_pred_energy = awkward.to_numpy(awkward.sum(yvals["pred_energy"], axis=1)) - max_e = max([np.max(sum_gen_energy), np.max(sum_cand_energy), np.max(sum_pred_energy)]) - min_e = min([np.min(sum_gen_energy), np.min(sum_cand_energy), np.min(sum_pred_energy)]) + max_e = max( + [ + np.max(sum_gen_energy), + np.max(sum_cand_energy), + np.max(sum_pred_energy), + ] + ) + min_e = min( + [ + np.min(sum_gen_energy), + np.min(sum_cand_energy), + np.min(sum_pred_energy), + ] + ) max_e = int(1.2 * max_e) min_e = int(0.8 * min_e) @@ -348,7 +501,12 @@ def plot_sum_energy(yvals, epoch=None, cp_dir=None, comet_experiment=None, title plt.ylabel("total PF energy / event [GeV]") if title: plt.title(title) - save_img("sum_gen_cand_energy.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "sum_gen_cand_energy.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) plt.figure() plt.hist2d(sum_gen_energy, sum_pred_energy, bins=(b, b), cmap="hot_r") @@ -357,10 +515,27 @@ def plot_sum_energy(yvals, epoch=None, cp_dir=None, comet_experiment=None, title plt.ylabel("total MLPF energy / event [GeV]") if title: plt.title(title) - save_img("sum_gen_pred_energy.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "sum_gen_pred_energy.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) - max_e = max([np.max(sum_gen_energy), np.max(sum_cand_energy), np.max(sum_pred_energy)]) - min_e = min([np.min(sum_gen_energy), np.min(sum_cand_energy), np.min(sum_pred_energy)]) + max_e = 
max( + [ + np.max(sum_gen_energy), + np.max(sum_cand_energy), + np.max(sum_pred_energy), + ] + ) + min_e = min( + [ + np.min(sum_gen_energy), + np.min(sum_cand_energy), + np.min(sum_pred_energy), + ] + ) max_e = math.ceil(np.log10(max_e)) min_e = math.floor(np.log10(max(min_e, 1e-2))) @@ -369,24 +544,44 @@ def plot_sum_energy(yvals, epoch=None, cp_dir=None, comet_experiment=None, title plt.hist2d(sum_gen_energy, sum_cand_energy, bins=(b, b), cmap="hot_r") plt.xscale("log") plt.yscale("log") - plt.plot([10**min_e, 10**max_e], [10**min_e, 10**max_e], color="black", ls="--") + plt.plot( + [10**min_e, 10**max_e], + [10**min_e, 10**max_e], + color="black", + ls="--", + ) plt.xlabel("total true energy / event [GeV]") plt.ylabel("total reconstructed energy / event [GeV]") if title: plt.title(title + ", PF") - save_img("sum_gen_cand_energy_log.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "sum_gen_cand_energy_log.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) b = np.logspace(min_e, max_e, 100) plt.figure() plt.hist2d(sum_gen_energy, sum_pred_energy, bins=(b, b), cmap="hot_r") plt.xscale("log") plt.yscale("log") - plt.plot([10**min_e, 10**max_e], [10**min_e, 10**max_e], color="black", ls="--") + plt.plot( + [10**min_e, 10**max_e], + [10**min_e, 10**max_e], + color="black", + ls="--", + ) plt.xlabel("total true energy / event [GeV]") plt.ylabel("total reconstructed energy / event [GeV]") if title: plt.title(title + ", MLPF") - save_img("sum_gen_pred_energy_log.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "sum_gen_pred_energy_log.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None): @@ -402,18 +597,41 @@ def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title= b = np.logspace(-1, 4, 100) plt.figure() p = med_iqr(cand_pt) - plt.hist(cand_pt, bins=b, histtype="step", lw=2, label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + cand_pt, + bins=b, + histtype="step", + lw=2, + label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = med_iqr(pred_pt) - plt.hist(pred_pt, bins=b, histtype="step", lw=2, label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + pred_pt, + bins=b, + histtype="step", + lw=2, + label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = med_iqr(gen_pt) - plt.hist(gen_pt, bins=b, histtype="step", lw=2, label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + gen_pt, + bins=b, + histtype="step", + lw=2, + label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) plt.xscale("log") plt.xlabel("Particle $p_T$ [GeV]") plt.ylabel("Number of particles / bin") if title: plt.title(title) plt.legend(loc="best") - save_img("particle_pt.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "particle_pt.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) msk_cand = yvals["cand_cls_id"] != 0 cand_pt = awkward.to_numpy(awkward.flatten(yvals["cand_eta"][msk_cand], axis=1)) @@ -427,17 +645,40 @@ def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title= b = np.linspace(-8, 8, 100) plt.figure() p = med_iqr(cand_pt) - plt.hist(cand_pt, bins=b, histtype="step", lw=2, label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + cand_pt, + bins=b, + histtype="step", + lw=2, + label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = 
med_iqr(pred_pt) - plt.hist(pred_pt, bins=b, histtype="step", lw=2, label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + pred_pt, + bins=b, + histtype="step", + lw=2, + label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) p = med_iqr(gen_pt) - plt.hist(gen_pt, bins=b, histtype="step", lw=2, label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1])) + plt.hist( + gen_pt, + bins=b, + histtype="step", + lw=2, + label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]), + ) plt.xlabel(r"Particle $\eta$") plt.ylabel("Number of particles / bin") if title: plt.title(title) plt.legend(loc="best") - save_img("particle_eta.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "particle_eta.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) msk_cand = yvals["cand_cls_id"] != 0 msk_pred = yvals["pred_cls_id"] != 0 @@ -455,7 +696,12 @@ def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title= plt.plot([10**-1, 10**4], [10**-1, 10**4], color="black", ls="--") if title: plt.title(title + ", PF") - save_img("particle_pt_gen_vs_pf.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "particle_pt_gen_vs_pf.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) pred_pt = awkward.to_numpy(awkward.flatten(yvals["pred_pt"][msk_pred & msk_gen], axis=1)) gen_pt = awkward.to_numpy(awkward.flatten(yvals["gen_pt"][msk_pred & msk_gen], axis=1)) @@ -469,4 +715,9 @@ def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title= plt.plot([10**-1, 10**4], [10**-1, 10**4], color="black", ls="--") if title: plt.title(title + ", MLPF") - save_img("particle_pt_gen_vs_mlpf.png", epoch, cp_dir=cp_dir, comet_experiment=comet_experiment) + save_img( + "particle_pt_gen_vs_mlpf.png", + epoch, + cp_dir=cp_dir, + comet_experiment=comet_experiment, + ) diff --git a/mlpf/plotting/plots_cms.py b/mlpf/plotting/plots_cms.py index d89c757b5..a6c229fdd 100644 --- a/mlpf/plotting/plots_cms.py +++ b/mlpf/plotting/plots_cms.py @@ -3,7 +3,12 @@ import matplotlib.pyplot as plt import numpy as np import sklearn.metrics -from plot_utils import plot_confusion_matrix, plot_E_reso, plot_eta_reso, plot_phi_reso +from plot_utils import ( + plot_confusion_matrix, + plot_E_reso, + plot_eta_reso, + plot_phi_reso, +) class_labels = list(range(8)) @@ -20,9 +25,39 @@ def prepare_resolution_plots(big_df, pid, bins, target="cand", outpath="./"): v1 = big_df[["{}_eta".format(target), "pred_eta"]].values v2 = big_df[["{}_phi".format(target), "pred_phi"]].values - plot_E_reso(big_df, pid, v0, msk_true, msk_pred, msk_both, bins, target=target, outpath=outpath) - plot_eta_reso(big_df, pid, v1, msk_true, msk_pred, msk_both, bins, target=target, outpath=outpath) - plot_phi_reso(big_df, pid, v2, msk_true, msk_pred, msk_both, bins, target=target, outpath=outpath) + plot_E_reso( + big_df, + pid, + v0, + msk_true, + msk_pred, + msk_both, + bins, + target=target, + outpath=outpath, + ) + plot_eta_reso( + big_df, + pid, + v1, + msk_true, + msk_pred, + msk_both, + bins, + target=target, + outpath=outpath, + ) + plot_phi_reso( + big_df, + pid, + v2, + msk_true, + msk_pred, + msk_both, + bins, + target=target, + outpath=outpath, + ) def load_np(npfile): @@ -41,7 +76,11 @@ def flatten(arr): parser = argparse.ArgumentParser() parser.add_argument( - "--target", type=str, choices=["cand", "gen"], help="Regress to PFCandidates or GenParticles", default="cand" + "--target", + type=str, + choices=["cand", "gen"], + help="Regress 
to PFCandidates or GenParticles", + default="cand", ) parser.add_argument("--input", type=str, required=True) args = parser.parse_args() @@ -55,9 +94,16 @@ def flatten(arr): confusion = sklearn.metrics.confusion_matrix(ycand_flat[msk, 0], ypred_flat[msk, 0], labels=range(8)) - fig, ax = plot_confusion_matrix(cm=confusion, target_names=[int(x) for x in class_labels], normalize=True) + fig, ax = plot_confusion_matrix( + cm=confusion, + target_names=[int(x) for x in class_labels], + normalize=True, + ) - plt.savefig(osp.join(osp.dirname(args.input), "confusion_mlpf.pdf"), bbox_inches="tight") + plt.savefig( + osp.join(osp.dirname(args.input), "confusion_mlpf.pdf"), + bbox_inches="tight", + ) # prepare_resolution_plots(big_df, 211, bins[211], target=args.target, outpath=osp.dirname(args.input)) # prepare_resolution_plots(big_df, 130, bins[130], target=args.target, outpath=osp.dirname(args.input)) diff --git a/mlpf/pyg/PFGraphDataset.py b/mlpf/pyg/PFGraphDataset.py index f9463d9f1..7f0aede6d 100644 --- a/mlpf/pyg/PFGraphDataset.py +++ b/mlpf/pyg/PFGraphDataset.py @@ -140,8 +140,19 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, required=True, help="'cms' or 'delphes'?") parser.add_argument("--dataset", type=str, required=True, help="Input data path") - parser.add_argument("--processed_dir", type=str, help="processed", required=False, default=None) - parser.add_argument("--num-files-merge", type=int, default=10, help="number of files to merge") + parser.add_argument( + "--processed_dir", + type=str, + help="processed", + required=False, + default=None, + ) + parser.add_argument( + "--num-files-merge", + type=int, + default=10, + help="number of files to merge", + ) parser.add_argument("--num-proc", type=int, default=24, help="number of processes") args = parser.parse_args() return args diff --git a/mlpf/pyg/args.py b/mlpf/pyg/args.py index 47b5ffe1c..bf8c19af2 100644 --- a/mlpf/pyg/args.py +++ b/mlpf/pyg/args.py @@ -5,11 +5,26 @@ def parse_args(): parser = argparse.ArgumentParser() # for data loading - parser.add_argument("--num_workers", type=int, default=2, help="number of subprocesses used for data loading") - parser.add_argument("--prefetch_factor", type=int, default=4, help="number of samples loaded in advance by each worker") + parser.add_argument( + "--num_workers", + type=int, + default=2, + help="number of subprocesses used for data loading", + ) + parser.add_argument( + "--prefetch_factor", + type=int, + default=4, + help="number of samples loaded in advance by each worker", + ) # for saving the model - parser.add_argument("--outpath", type=str, default="../experiments/", help="output folder") + parser.add_argument( + "--outpath", + type=str, + default="../experiments/", + help="output folder", + ) parser.add_argument( "--model_prefix", type=str, @@ -19,31 +34,73 @@ def parse_args(): # for loading the data parser.add_argument("--data", type=str, required=True, help="cms or delphes?") - parser.add_argument("--dataset", type=str, default="../data/delphes/pythia8_ttbar", help="training dataset path") - parser.add_argument("--dataset_test", type=str, default="../data/delphes/pythia8_qcd", help="testing dataset path") + parser.add_argument( + "--dataset", + type=str, + default="../data/delphes/pythia8_ttbar", + help="training dataset path", + ) + parser.add_argument( + "--dataset_test", + type=str, + default="../data/delphes/pythia8_qcd", + help="testing dataset path", + ) parser.add_argument("--sample", type=str, default="QCD", help="sample 
to test on") parser.add_argument( - "--n_train", type=int, default=1, help="number of data files to use for training.. each file contains 100 events" + "--n_train", + type=int, + default=1, + help="number of data files to use for training.. each file contains 100 events", ) parser.add_argument( - "--n_valid", type=int, default=1, help="number of data files to use for validation.. each file contains 100 events" + "--n_valid", + type=int, + default=1, + help="number of data files to use for validation.. each file contains 100 events", ) parser.add_argument( - "--n_test", type=int, default=1, help="number of data files to use for testing.. each file contains 100 events" + "--n_test", + type=int, + default=1, + help="number of data files to use for testing.. each file contains 100 events", ) - parser.add_argument("--overwrite", dest="overwrite", action="store_true", help="Overwrites the model if True") + parser.add_argument( + "--overwrite", + dest="overwrite", + action="store_true", + help="Overwrites the model if True", + ) # for loading a pre-trained model - parser.add_argument("--load", dest="load", action="store_true", help="Load the model (no training)") - parser.add_argument("--load_epoch", type=int, default=-1, help="Which epoch of the model to load for evaluation") + parser.add_argument( + "--load", + dest="load", + action="store_true", + help="Load the model (no training)", + ) + parser.add_argument( + "--load_epoch", + type=int, + default=-1, + help="Which epoch of the model to load for evaluation", + ) # for training hyperparameters parser.add_argument("--n_epochs", type=int, default=3, help="number of training epochs") parser.add_argument( - "--batch_size", type=int, default=1, help="number of events to run inference on before updating the loss" + "--batch_size", + type=int, + default=1, + help="number of events to run inference on before updating the loss", + ) + parser.add_argument( + "--patience", + type=int, + default=30, + help="patience before early stopping", ) - parser.add_argument("--patience", type=int, default=30, help="patience before early stopping") parser.add_argument( "--target", type=str, @@ -67,24 +124,61 @@ def parse_args(): # for model architecture parser.add_argument( - "--hidden_dim1", type=int, default=126, help="hidden dimension of layers before the graph convolutions" + "--hidden_dim1", + type=int, + default=126, + help="hidden dimension of layers before the graph convolutions", + ) + parser.add_argument( + "--hidden_dim2", + type=int, + default=256, + help="hidden dimension of layers after the graph convolutions", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=32, + help="encoded element dimension", ) parser.add_argument( - "--hidden_dim2", type=int, default=256, help="hidden dimension of layers after the graph convolutions" + "--num_convs", + type=int, + default=3, + help="number of graph convolutions", ) - parser.add_argument("--embedding_dim", type=int, default=32, help="encoded element dimension") - parser.add_argument("--num_convs", type=int, default=3, help="number of graph convolutions") - parser.add_argument("--space_dim", type=int, default=4, help="Spatial dimension for clustering in gravnet layer") parser.add_argument( - "--propagate_dim", type=int, default=8, help="The number of features to be propagated between the vertices" + "--space_dim", + type=int, + default=4, + help="Spatial dimension for clustering in gravnet layer", + ) + parser.add_argument( + "--propagate_dim", + type=int, + default=8, + help="The number of 
features to be propagated between the vertices", + ) + parser.add_argument( + "--nearest", + type=int, + default=4, + help="k nearest neighbors in gravnet layer", ) - parser.add_argument("--nearest", type=int, default=4, help="k nearest neighbors in gravnet layer") # for testing the model parser.add_argument( - "--make_predictions", dest="make_predictions", action="store_true", help="run inference on the test data" + "--make_predictions", + dest="make_predictions", + action="store_true", + help="run inference on the test data", + ) + parser.add_argument( + "--make_plots", + dest="make_plots", + action="store_true", + help="makes plots of the test predictions", ) - parser.add_argument("--make_plots", dest="make_plots", action="store_true", help="makes plots of the test predictions") args = parser.parse_args() diff --git a/mlpf/pyg/cms_plots.py b/mlpf/pyg/cms_plots.py index 426342113..f9b511d4c 100644 --- a/mlpf/pyg/cms_plots.py +++ b/mlpf/pyg/cms_plots.py @@ -5,17 +5,42 @@ import mplhep import numpy as np import sklearn.metrics -from pyg.cms_utils import CLASS_LABELS_CMS, CLASS_NAMES_CMS, CLASS_NAMES_CMS_LATEX +from pyg.cms_utils import ( + CLASS_LABELS_CMS, + CLASS_NAMES_CMS, + CLASS_NAMES_CMS_LATEX, +) mplhep.style.use(mplhep.styles.CMS) def cms_label(ax, x0=0.01, x1=0.15, x2=0.98, y=0.94): - plt.figtext(x0, y, "CMS", fontweight="bold", wrap=True, horizontalalignment="left", transform=ax.transAxes) plt.figtext( - x1, y, "Simulation Preliminary", style="italic", wrap=True, horizontalalignment="left", transform=ax.transAxes + x0, + y, + "CMS", + fontweight="bold", + wrap=True, + horizontalalignment="left", + transform=ax.transAxes, + ) + plt.figtext( + x1, + y, + "Simulation Preliminary", + style="italic", + wrap=True, + horizontalalignment="left", + transform=ax.transAxes, + ) + plt.figtext( + x2, + y, + "Run 3 (14 TeV)", + wrap=False, + horizontalalignment="right", + transform=ax.transAxes, ) - plt.figtext(x2, y, "Run 3 (14 TeV)", wrap=False, horizontalalignment="right", transform=ax.transAxes) # def cms_label_sample_label(x0=0.12, x1=0.23, x2=0.67, y=0.90): @@ -26,9 +51,21 @@ def cms_label(ax, x0=0.01, x1=0.15, x2=0.98, y=0.94): def sample_label(sample, ax, additional_text="", x=0.01, y=0.87): if sample == "QCD": - plt.text(x, y, "QCD events" + additional_text, ha="left", transform=ax.transAxes) + plt.text( + x, + y, + "QCD events" + additional_text, + ha="left", + transform=ax.transAxes, + ) else: - plt.text(x, y, "$\mathrm{t}\overline{\mathrm{t}}$ events" + additional_text, ha="left", transform=ax.transAxes) + plt.text( + x, + y, + r"$\mathrm{t}\overline{\mathrm{t}}$ events" + additional_text, + ha="left", + transform=ax.transAxes, + ) def apply_thresholds_f(ypred_raw_f, thresholds): @@ -96,14 +133,14 @@ def plot_met(X, yvals, outpath, sample): bins=b, histtype="step", lw=2, - label="PF, $\mu={:.2f}$, $\sigma={:.2f}$".format(np.mean(vals_a), np.std(vals_a)), + label=r"PF, $\mu={:.2f}$, $\sigma={:.2f}$".format(np.mean(vals_a), np.std(vals_a)), ) plt.hist( vals_b, bins=b, histtype="step", lw=2, - label="MLPF, $\mu={:.2f}$, $\sigma={:.2f}$".format(np.mean(vals_b), np.std(vals_b)), + label=r"MLPF, $\mu={:.2f}$, $\sigma={:.2f}$".format(np.mean(vals_b), np.std(vals_b)), ) plt.yscale("log") cms_label(ax) @@ -120,14 +157,24 @@ def plot_sum_energy(X, yvals, outpath, sample): plt.figure() ax = plt.axes() - plt.scatter(np.sum(yvals["gen_energy"], axis=1), np.sum(yvals["cand_energy"], axis=1), alpha=0.5, label="PF") - plt.scatter(np.sum(yvals["gen_energy"], axis=1), np.sum(yvals["pred_energy"], 
axis=1), alpha=0.5, label="MLPF") + plt.scatter( + np.sum(yvals["gen_energy"], axis=1), + np.sum(yvals["cand_energy"], axis=1), + alpha=0.5, + label="PF", + ) + plt.scatter( + np.sum(yvals["gen_energy"], axis=1), + np.sum(yvals["pred_energy"], axis=1), + alpha=0.5, + label="MLPF", + ) plt.plot([10000, 80000], [10000, 80000], color="black") plt.legend(loc=4) cms_label(ax) sample_label(sample, ax) - plt.xlabel("Gen $\sum E$ [GeV]") - plt.ylabel("Reconstructed $\sum E$ [GeV]") + plt.xlabel(r"Gen $\sum E$ [GeV]") + plt.ylabel(r"Reconstructed $\sum E$ [GeV]") plt.savefig(f"{outpath}/sum_energy.pdf", bbox_inches="tight") plt.close() @@ -137,14 +184,24 @@ def plot_sum_pt(X, yvals, outpath, sample): plt.figure() ax = plt.axes() - plt.scatter(np.sum(yvals["gen_pt"], axis=1), np.sum(yvals["cand_pt"], axis=1), alpha=0.5, label="PF") - plt.scatter(np.sum(yvals["gen_pt"], axis=1), np.sum(yvals["pred_pt"], axis=1), alpha=0.5, label="PF") + plt.scatter( + np.sum(yvals["gen_pt"], axis=1), + np.sum(yvals["cand_pt"], axis=1), + alpha=0.5, + label="PF", + ) + plt.scatter( + np.sum(yvals["gen_pt"], axis=1), + np.sum(yvals["pred_pt"], axis=1), + alpha=0.5, + label="PF", + ) plt.plot([1000, 6000], [1000, 6000], color="black") plt.legend(loc=4) cms_label(ax) sample_label(sample, ax) - plt.xlabel("Gen $\sum p_T$ [GeV]") - plt.ylabel("Reconstructed $\sum p_T$ [GeV]") + plt.xlabel(r"Gen $\sum p_T$ [GeV]") + plt.ylabel(r"Reconstructed $\sum p_T$ [GeV]") plt.savefig(f"{outpath}/sum_pt.pdf", bbox_inches="tight") plt.close() @@ -166,14 +223,14 @@ def plot_energy_res(X, yvals_f, pid, b, ylim, outpath, sample): bins=b, histtype="step", lw=2, - label="PF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_1), np.std(reso_1)), + label=r"PF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_1), np.std(reso_1)), ) plt.hist( reso_2, bins=b, histtype="step", lw=2, - label="MLPF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_2), np.std(reso_2)), + label=r"MLPF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_2), np.std(reso_2)), ) plt.yscale("log") plt.xlabel(r"$\frac{E_\mathrm{reco} - E_\mathrm{gen}}{E_\mathrm{gen}}$") @@ -182,7 +239,10 @@ def plot_energy_res(X, yvals_f, pid, b, ylim, outpath, sample): sample_label(sample, ax, f", {CLASS_NAMES_CMS_LATEX[pid]}") plt.legend(loc=(0.4, 0.7)) plt.ylim(1, ylim) - plt.savefig(f"{outpath}/energy_res_{CLASS_NAMES_CMS[pid]}.pdf", bbox_inches="tight") + plt.savefig( + f"{outpath}/energy_res_{CLASS_NAMES_CMS[pid]}.pdf", + bbox_inches="tight", + ) plt.close() @@ -205,14 +265,14 @@ def plot_eta_res(X, yvals_f, pid, ylim, outpath, sample): bins=b, histtype="step", lw=2, - label="PF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_1), np.std(reso_1)), + label=r"PF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_1), np.std(reso_1)), ) plt.hist( reso_2, bins=b, histtype="step", lw=2, - label="MLPF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_2), np.std(reso_2)), + label=r"MLPF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_2), np.std(reso_2)), ) plt.yscale("log") plt.xlabel(r"$\eta_\mathrm{reco} - \eta_\mathrm{gen}$") @@ -221,7 +281,10 @@ def plot_eta_res(X, yvals_f, pid, ylim, outpath, sample): sample_label(sample, ax, f", {CLASS_NAMES_CMS_LATEX[pid]}") plt.legend(loc=(0.0, 0.7)) plt.ylim(1, ylim) - plt.savefig(f"{outpath}/eta_res_{CLASS_NAMES_CMS[pid]}.pdf", bbox_inches="tight") + plt.savefig( + f"{outpath}/eta_res_{CLASS_NAMES_CMS[pid]}.pdf", + bbox_inches="tight", + ) plt.close() @@ -264,8 +327,8 @@ def plot_multiplicity(X, yvals, outpath, sample): plt.plot([minval, maxval], 
[minval, maxval], color="black") plt.xlim(minval, maxval) plt.ylim(minval, maxval) - plt.xlabel("true $\sum E$ [GeV]") - plt.xlabel("reconstructed $\sum E$ [GeV]") + plt.xlabel(r"true $\sum E$ [GeV]") + plt.xlabel(r"reconstructed $\sum E$ [GeV]") plt.legend(loc=4) cms_label(ax) sample_label(sample, ax, f", {CLASS_NAMES_CMS_LATEX[icls]}") @@ -295,7 +358,10 @@ def plot_dist(yvals_f, var, bin, label, outpath, sample): plt.figure() ax = plt.axes() v1 = mplhep.histplot( - [h[bh.rebin(2)] for h in hists_gen], stack=True, label=[class_names[k] for k in [13, 11, 22, 1, 2, 130, 211]], lw=1 + [h[bh.rebin(2)] for h in hists_gen], + stack=True, + label=[class_names[k] for k in [13, 11, 22, 1, 2, 130, 211]], + lw=1, ) mplhep.histplot( [h[bh.rebin(2)] for h in hists_pred], @@ -305,7 +371,12 @@ def plot_dist(yvals_f, var, bin, label, outpath, sample): histtype="errorbar", ) - legend1 = plt.legend(v1, [x.legend_artist.get_label() for x in v1], loc=(0.60, 0.44), title="true") + legend1 = plt.legend( + v1, + [x.legend_artist.get_label() for x in v1], + loc=(0.60, 0.44), + title="true", + ) # legend2 = plt.legend(v2, [x.legend_artist.get_label() for x in v1], loc=(0.8, 0.44), title="pred") plt.gca().add_artist(legend1) plt.ylabel("Total number of particles / bin") @@ -334,21 +405,34 @@ def plot_eff_and_fake_rate( values = X_f[:, ivar] - hist_gen = np.histogram(values[(yvals_f["gen_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + hist_gen = np.histogram( + values[(yvals_f["gen_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, + ) hist_gen_pred = np.histogram( - values[(yvals_f["gen_cls_id"] == icls) & (yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins + values[(yvals_f["gen_cls_id"] == icls) & (yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, ) hist_gen_cand = np.histogram( - values[(yvals_f["gen_cls_id"] == icls) & (yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins + values[(yvals_f["gen_cls_id"] == icls) & (yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, ) - hist_pred = np.histogram(values[(yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) - hist_cand = np.histogram(values[(yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + hist_pred = np.histogram( + values[(yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, + ) + hist_cand = np.histogram( + values[(yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, + ) hist_pred_fake = np.histogram( - values[(yvals_f["gen_cls_id"] != icls) & (yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins + values[(yvals_f["gen_cls_id"] != icls) & (yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, ) hist_cand_fake = np.histogram( - values[(yvals_f["gen_cls_id"] != icls) & (yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins + values[(yvals_f["gen_cls_id"] != icls) & (yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], + bins=bins, ) eff_mlpf = hist_gen_pred[0] / hist_gen[0] @@ -423,7 +507,10 @@ def plot_cm(yvals_f, msk_X_f, label, outpath): Y = yvals_f["cand_cls_id"][msk_X_f] cm_norm = sklearn.metrics.confusion_matrix( - yvals_f["gen_cls_id"][msk_X_f], Y, labels=range(0, len(CLASS_LABELS_CMS)), normalize="true" + yvals_f["gen_cls_id"][msk_X_f], + Y, + labels=range(0, len(CLASS_LABELS_CMS)), + normalize="true", ) plt.imshow(cm_norm, cmap="Blues", origin="lower") @@ -442,7 +529,11 @@ def plot_cm(yvals_f, msk_X_f, label, outpath): cms_label(ax, y=1.01) # 
cms_label_sample_label(x1=0.18, x2=0.52, y=0.82) - plt.xticks(range(len(CLASS_NAMES_CMS_LATEX)), CLASS_NAMES_CMS_LATEX, rotation=45) + plt.xticks( + range(len(CLASS_NAMES_CMS_LATEX)), + CLASS_NAMES_CMS_LATEX, + rotation=45, + ) plt.yticks(range(len(CLASS_NAMES_CMS_LATEX)), CLASS_NAMES_CMS_LATEX) plt.xlabel(f"{label} candidate ID") @@ -456,7 +547,12 @@ def plot_cm(yvals_f, msk_X_f, label, outpath): def distribution_icls(yvals_f, outpath): for icls in range(0, 8): fig, axs = plt.subplots( - 2, 2, figsize=(2 * mplhep.styles.CMS["figure.figsize"][0], 2 * mplhep.styles.CMS["figure.figsize"][1]) + 2, + 2, + figsize=( + 2 * mplhep.styles.CMS["figure.figsize"][0], + 2 * mplhep.styles.CMS["figure.figsize"][1], + ), ) for ax, ivar in zip(axs.flatten(), ["pt", "energy", "eta", "phi"]): @@ -479,7 +575,14 @@ def distribution_icls(yvals_f, outpath): b = np.linspace(np.min(vals_true), np.max(vals_true), 41) log = False - plt.hist(vals_true, bins=b, histtype="step", lw=2, label="gen", color="black") + plt.hist( + vals_true, + bins=b, + histtype="step", + lw=2, + label="gen", + color="black", + ) plt.hist(vals_pf, bins=b, histtype="step", lw=2, label="PF") plt.hist(vals_pred, bins=b, histtype="step", lw=2, label="MLPF") plt.legend(loc=(0.75, 0.75)) diff --git a/mlpf/pyg/cms_utils.py b/mlpf/pyg/cms_utils.py index a8f350347..3fdd263b3 100644 --- a/mlpf/pyg/cms_utils.py +++ b/mlpf/pyg/cms_utils.py @@ -10,16 +10,66 @@ # https://github.com/ahlinist/cmssw/blob/1df62491f48ef964d198f574cdfcccfd17c70425/DataFormats/ParticleFlowReco/interface/PFBlockElement.h#L33 ELEM_LABELS_CMS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] -ELEM_NAMES_CMS = ["NONE", "TRACK", "PS1", "PS2", "ECAL", "HCAL", "GSF", "BREM", "HFEM", "HFHAD", "SC", "HO"] +ELEM_NAMES_CMS = [ + "NONE", + "TRACK", + "PS1", + "PS2", + "ECAL", + "HCAL", + "GSF", + "BREM", + "HFEM", + "HFHAD", + "SC", + "HO", +] # https://github.com/cms-sw/cmssw/blob/master/DataFormats/ParticleFlowCandidate/src/PFCandidate.cc#L254 CLASS_LABELS_CMS = [0, 211, 130, 1, 2, 22, 11, 13, 15] -CLASS_NAMES_CMS_LATEX = ["none", "chhad", "nhad", "HFEM", "HFHAD", "$\gamma$", "$e^\pm$", "$\mu^\pm$", r"$\tau$"] -CLASS_NAMES_CMS = ["none", "chhad", "nhad", "HFEM", "HFHAD", "gamma", "ele", "mu", "tau"] +CLASS_NAMES_CMS_LATEX = [ + r"none", + r"chhad", + r"nhad", + r"HFEM", + r"HFHAD", + r"$\gamma$", + r"$e^\pm$", + r"$\mu^\pm$", + r"$\tau$", +] +CLASS_NAMES_CMS = [ + "none", + "chhad", + "nhad", + "HFEM", + "HFHAD", + "gamma", + "ele", + "mu", + "tau", +] -CLASS_NAMES_LONG_CMS = ["none" "charged hadron", "neutral hadron", "hfem", "hfhad", "photon", "electron", "muon", "tau"] +CLASS_NAMES_LONG_CMS = [ + "none" "charged hadron", + "neutral hadron", + "hfem", + "hfhad", + "photon", + "electron", + "muon", + "tau", +] -CMS_PF_CLASS_NAMES = ["none" "charged hadron", "neutral hadron", "hfem", "hfhad", "photon", "electron", "muon"] +CMS_PF_CLASS_NAMES = [ + "none" "charged hadron", + "neutral hadron", + "hfem", + "hfhad", + "photon", + "electron", + "muon", +] X_FEATURES = [ "typ_idx", @@ -98,15 +148,28 @@ def prepare_data_cms(fn): ycand = ycand[~msk_ps] Xelem = append_fields( - Xelem, "typ_idx", np.array([ELEM_LABELS_CMS.index(int(i)) for i in Xelem["typ"]], dtype=np.float32) + Xelem, + "typ_idx", + np.array( + [ELEM_LABELS_CMS.index(int(i)) for i in Xelem["typ"]], + dtype=np.float32, + ), ) ygen = append_fields( - ygen, "typ_idx", np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32) + ygen, + "typ_idx", + np.array( + [CLASS_LABELS_CMS.index(abs(int(i))) for i in 
ygen["typ"]], + dtype=np.float32, + ), ) ycand = append_fields( ycand, "typ_idx", - np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32), + np.array( + [CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], + dtype=np.float32, + ), ) Xelem_flat = np.stack( diff --git a/mlpf/pyg/delphes_plots.py b/mlpf/pyg/delphes_plots.py index cec6624d0..c3ac52398 100644 --- a/mlpf/pyg/delphes_plots.py +++ b/mlpf/pyg/delphes_plots.py @@ -64,8 +64,8 @@ "energy": r"$E$", } var_names_bare = { - "pt": "p_\mathrm{T}", - "eta": "\eta", + "pt": r"p_\mathrm{T}", + "eta": r"\eta", "energy": "E", } var_indices = { @@ -105,7 +105,17 @@ def divide_zero(a, b): return out -def plot_distribution(data, pid, target, mlpf, var_name, rng, target_type, fname, legend_title=""): +def plot_distribution( + data, + pid, + target, + mlpf, + var_name, + rng, + target_type, + fname, + legend_title="", +): """ plot distributions for the target and mlpf of a given feature for a given PID """ @@ -119,9 +129,23 @@ def plot_distribution(data, pid, target, mlpf, var_name, rng, target_type, fname fig = plt.figure(figsize=(10, 10)) if target_type == "cand": - plt.hist(target, bins=rng, density=True, histtype="step", lw=2, label="cand") + plt.hist( + target, + bins=rng, + density=True, + histtype="step", + lw=2, + label="cand", + ) elif target_type == "gen": - plt.hist(target, bins=rng, density=True, histtype="step", lw=2, label="gen") + plt.hist( + target, + bins=rng, + density=True, + histtype="step", + lw=2, + label="gen", + ) plt.hist(mlpf, bins=rng, density=True, histtype="step", lw=2, label="MLPF") plt.xlabel(var_name) @@ -139,7 +163,18 @@ def plot_distribution(data, pid, target, mlpf, var_name, rng, target_type, fname def plot_distributions_pid( - data, pid, true_id, true_p4, pred_id, pred_p4, pf_id, cand_p4, target, epoch, outpath, legend_title="" + data, + pid, + true_id, + true_p4, + pred_id, + pred_p4, + pf_id, + cand_p4, + target, + epoch, + outpath, + legend_title="", ): """ plot distributions for the target and mlpf of the regressed features for a given PID @@ -168,7 +203,17 @@ def plot_distributions_pid( def plot_distributions_all( - data, true_id, true_p4, pred_id, pred_p4, pf_id, cand_p4, target, epoch, outpath, legend_title="" + data, + true_id, + true_p4, + pred_id, + pred_p4, + pf_id, + cand_p4, + target, + epoch, + outpath, + legend_title="", ): """ plot distributions for the target and mlpf of a all features, merging all PIDs @@ -228,7 +273,7 @@ def plot_particle_multiplicity(data, list, key, ax=None, legend_title=""): target_list[key], cand_list[key], marker="o", - label="Rule-based PF, $r={0:.3f}$\n$\mu={1:.3f}\\ \sigma={2:.3f}$".format( + label=r"Rule-based PF, $r={0:.3f}$\n$\mu={1:.3f}\\ \sigma={2:.3f}$".format( np.corrcoef(a, b)[0, 1], mu_dpf, sigma_dpf ), alpha=0.5, @@ -247,7 +292,7 @@ def plot_particle_multiplicity(data, list, key, ax=None, legend_title=""): target_list[key], cand_list[key], marker="^", - label="MLPF, $r={0:.3f}$\n$\mu={1:.3f}\\ \sigma={2:.3f}$".format(np.corrcoef(a, b)[0, 1], mu_mlpf, sigma_mlpf), + label=r"MLPF, $r={0:.3f}$\n$\mu={1:.3f}\\ \sigma={2:.3f}$".format(np.corrcoef(a, b)[0, 1], mu_mlpf, sigma_mlpf), alpha=0.5, ) @@ -267,7 +312,18 @@ def plot_particle_multiplicity(data, list, key, ax=None, legend_title=""): plt.title("Particle multiplicity") -def draw_efficiency_fakerate(data, ygen, ypred, ycand, pid, var, bins, outpath, both=True, legend_title=""): +def draw_efficiency_fakerate( + data, + ygen, + ypred, + ycand, + pid, + var, + bins, + outpath, 
+ both=True, + legend_title="", +): if data == "delphes": pid_to_name = pid_to_name_delphes elif data == "cms": @@ -398,25 +454,46 @@ def plot_reso(data, ygen, ypred, ycand, pfcand, var, outpath, legend_title=""): fig, ax = plt.subplots(1, 1, figsize=(8, 8)) ax.hist( - ratio_dpf, bins=bins, histtype="step", lw=2, label="Rule-based PF\n$\mu={:.2f},\\ \sigma={:.2f}$".format(*res_dpf) + ratio_dpf, + bins=bins, + histtype="step", + lw=2, + label=r"Rule-based PF\n$\mu={:.2f},\\ \sigma={:.2f}$".format(*res_dpf), + ) + ax.hist( + ratio_mlpf, + bins=bins, + histtype="step", + lw=2, + label=r"MLPF\n$\mu={:.2f},\\ \sigma={:.2f}$".format(*res_mlpf), ) - ax.hist(ratio_mlpf, bins=bins, histtype="step", lw=2, label="MLPF\n$\mu={:.2f},\\ \sigma={:.2f}$".format(*res_mlpf)) ax.legend(frameon=False, title=legend_title + pfcand) ax.set_xlabel( - "{nounit} resolution, $({bare}^\prime - {bare})/{bare}$".format( + r"{nounit} resolution, $({bare}^\prime - {bare})/{bare}$".format( nounit=var_names_nounit[var], bare=var_names_bare[var] ) ) ax.set_ylabel("Particles") ax.set_ylim(1, 1e10) ax.set_yscale("log") - plt.savefig(outpath + f"/resolution_plots/res_{pfcand}_{var}.pdf", bbox_inches="tight") + plt.savefig( + outpath + f"/resolution_plots/res_{pfcand}_{var}.pdf", + bbox_inches="tight", + ) plt.tight_layout() plt.close(fig) def plot_confusion_matrix( - cm, target_names, epoch, outpath, save_as, title="Confusion matrix", cmap=None, normalize=True, target=None + cm, + target_names, + epoch, + outpath, + save_as, + title="Confusion matrix", + cmap=None, + normalize=True, + target=None, ): """ given a sklearn confusion matrix (cm), make a nice plot @@ -494,7 +571,11 @@ def plot_confusion_matrix( ) else: plt.text( - j, i, "{:,}".format(cm[i, j]), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black" + j, + i, + "{:,}".format(cm[i, j]), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black", ) plt.ylabel("True label") diff --git a/mlpf/pyg/delphes_utils.py b/mlpf/pyg/delphes_utils.py index 0633400e2..a1fe2f499 100644 --- a/mlpf/pyg/delphes_utils.py +++ b/mlpf/pyg/delphes_utils.py @@ -23,9 +23,30 @@ def make_predictions_delphes(model, multi_gpu, test_loader, outpath, device, epo print("Making predictions...") t0 = time.time() - gen_list = {"null": [], "chhadron": [], "nhadron": [], "photon": [], "electron": [], "muon": []} - pred_list = {"null": [], "chhadron": [], "nhadron": [], "photon": [], "electron": [], "muon": []} - cand_list = {"null": [], "chhadron": [], "nhadron": [], "photon": [], "electron": [], "muon": []} + gen_list = { + "null": [], + "chhadron": [], + "nhadron": [], + "photon": [], + "electron": [], + "muon": [], + } + pred_list = { + "null": [], + "chhadron": [], + "nhadron": [], + "photon": [], + "electron": [], + "muon": [], + } + cand_list = { + "null": [], + "chhadron": [], + "nhadron": [], + "photon": [], + "electron": [], + "muon": [], + } t = [] @@ -105,11 +126,19 @@ def make_predictions_delphes(model, multi_gpu, test_loader, outpath, device, epo if i == 4999: break - print("Average Inference time per event is: ", round((sum(t) / len(t)), 2), "s") + print( + "Average Inference time per event is: ", + round((sum(t) / len(t)), 2), + "s", + ) t1 = time.time() - print("Time taken to make predictions is:", round(((t1 - t0) / 60), 2), "min") + print( + "Time taken to make predictions is:", + round(((t1 - t0) / 60), 2), + "min", + ) # store the 3 list dictionaries in a list (this is done only to compute the particle multiplicity plots) list = 
[pred_list, gen_list, cand_list] @@ -164,14 +193,23 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): conf_matrix_mlpf = sklearn.metrics.confusion_matrix(gen_ids.cpu(), pred_ids.cpu(), labels=range(6), normalize="true") plot_confusion_matrix( - conf_matrix_mlpf, target_names, epoch, outpath + "/confusion_matrix_plots/", f"cm_mlpf_epoch_{str(epoch)}" + conf_matrix_mlpf, + target_names, + epoch, + outpath + "/confusion_matrix_plots/", + f"cm_mlpf_epoch_{str(epoch)}", ) # make confusion matrix for rule based PF conf_matrix_cand = sklearn.metrics.confusion_matrix(gen_ids.cpu(), cand_ids.cpu(), labels=range(6), normalize="true") plot_confusion_matrix( - conf_matrix_cand, target_names, epoch, outpath + "/confusion_matrix_plots/", "cm_cand", target="rule-based" + conf_matrix_cand, + target_names, + epoch, + outpath + "/confusion_matrix_plots/", + "cm_cand", + target="rule-based", ) # making all the other plots @@ -269,14 +307,26 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): fig, ax = plt.subplots(1, 1, figsize=(8, 2 * 8)) plot_particle_multiplicity(list_for_multiplicities, "chhadron", ax) - plt.savefig(outpath + "/multiplicity_plots/num_chhadron.png", bbox_inches="tight") - plt.savefig(outpath + "/multiplicity_plots/num_chhadron.pdf", bbox_inches="tight") + plt.savefig( + outpath + "/multiplicity_plots/num_chhadron.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/multiplicity_plots/num_chhadron.pdf", + bbox_inches="tight", + ) plt.close(fig) fig, ax = plt.subplots(1, 1, figsize=(8, 2 * 8)) plot_particle_multiplicity(list_for_multiplicities, "nhadron", ax) - plt.savefig(outpath + "/multiplicity_plots/num_nhadron.png", bbox_inches="tight") - plt.savefig(outpath + "/multiplicity_plots/num_nhadron.pdf", bbox_inches="tight") + plt.savefig( + outpath + "/multiplicity_plots/num_nhadron.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/multiplicity_plots/num_nhadron.pdf", + bbox_inches="tight", + ) plt.close(fig) fig, ax = plt.subplots(1, 1, figsize=(8, 2 * 8)) @@ -287,8 +337,14 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): fig, ax = plt.subplots(1, 1, figsize=(8, 2 * 8)) plot_particle_multiplicity(list_for_multiplicities, "electron", ax) - plt.savefig(outpath + "/multiplicity_plots/num_electron.png", bbox_inches="tight") - plt.savefig(outpath + "/multiplicity_plots/num_electron.pdf", bbox_inches="tight") + plt.savefig( + outpath + "/multiplicity_plots/num_electron.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/multiplicity_plots/num_electron.pdf", + bbox_inches="tight", + ) plt.close(fig) fig, ax = plt.subplots(1, 1, figsize=(8, 2 * 8)) @@ -376,16 +432,40 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): plt.close(fig) fig, (ax2) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 1, "eta", 0.2, ax=ax2, legend_title=sample + "\n") + plot_reso( + ygen, + ypred, + ycand, + 1, + "eta", + 0.2, + ax=ax2, + legend_title=sample + "\n", + ) plt.savefig(outpath + "/resolution_plots/res_pid1_eta.png", bbox_inches="tight") plt.savefig(outpath + "/resolution_plots/res_pid1_eta.pdf", bbox_inches="tight") plt.tight_layout() plt.close(fig) fig, (ax3) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 1, "energy", 0.2, ax=ax3, legend_title=sample + "\n") - plt.savefig(outpath + "/resolution_plots/res_pid1_energy.png", bbox_inches="tight") - plt.savefig(outpath + "/resolution_plots/res_pid1_energy.pdf", 
bbox_inches="tight") + plot_reso( + ygen, + ypred, + ycand, + 1, + "energy", + 0.2, + ax=ax3, + legend_title=sample + "\n", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid1_energy.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid1_energy.pdf", + bbox_inches="tight", + ) plt.tight_layout() plt.close(fig) @@ -398,16 +478,40 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): plt.close(fig) fig, (ax2) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 2, "eta", 0.2, ax=ax2, legend_title=sample + "\n") + plot_reso( + ygen, + ypred, + ycand, + 2, + "eta", + 0.2, + ax=ax2, + legend_title=sample + "\n", + ) plt.savefig(outpath + "/resolution_plots/res_pid2_eta.png", bbox_inches="tight") plt.savefig(outpath + "/resolution_plots/res_pid2_eta.pdf", bbox_inches="tight") plt.tight_layout() plt.close(fig) fig, (ax3) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 2, "energy", 0.2, ax=ax3, legend_title=sample + "\n") - plt.savefig(outpath + "/resolution_plots/res_pid2_energy.png", bbox_inches="tight") - plt.savefig(outpath + "/resolution_plots/res_pid2_energy.pdf", bbox_inches="tight") + plot_reso( + ygen, + ypred, + ycand, + 2, + "energy", + 0.2, + ax=ax3, + legend_title=sample + "\n", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid2_energy.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid2_energy.pdf", + bbox_inches="tight", + ) plt.tight_layout() plt.close(fig) @@ -420,16 +524,40 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): plt.close(fig) fig, (ax2) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 3, "eta", 0.2, ax=ax2, legend_title=sample + "\n") + plot_reso( + ygen, + ypred, + ycand, + 3, + "eta", + 0.2, + ax=ax2, + legend_title=sample + "\n", + ) plt.savefig(outpath + "/resolution_plots/res_pid3_eta.png", bbox_inches="tight") plt.savefig(outpath + "/resolution_plots/res_pid3_eta.pdf", bbox_inches="tight") plt.tight_layout() plt.close(fig) fig, (ax3) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 3, "energy", 0.2, ax=ax3, legend_title=sample + "\n") - plt.savefig(outpath + "/resolution_plots/res_pid3_energy.png", bbox_inches="tight") - plt.savefig(outpath + "/resolution_plots/res_pid3_energy.pdf", bbox_inches="tight") + plot_reso( + ygen, + ypred, + ycand, + 3, + "energy", + 0.2, + ax=ax3, + legend_title=sample + "\n", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid3_energy.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid3_energy.pdf", + bbox_inches="tight", + ) plt.tight_layout() plt.close(fig) @@ -442,16 +570,40 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): plt.close(fig) fig, (ax2) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 4, "eta", 0.2, ax=ax2, legend_title=sample + "\n") + plot_reso( + ygen, + ypred, + ycand, + 4, + "eta", + 0.2, + ax=ax2, + legend_title=sample + "\n", + ) plt.savefig(outpath + "/resolution_plots/res_pid4_eta.png", bbox_inches="tight") plt.savefig(outpath + "/resolution_plots/res_pid4_eta.pdf", bbox_inches="tight") plt.tight_layout() plt.close(fig) fig, (ax3) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 4, "energy", 0.2, ax=ax3, legend_title=sample + "\n") - plt.savefig(outpath + "/resolution_plots/res_pid4_energy.png", bbox_inches="tight") - plt.savefig(outpath + 
"/resolution_plots/res_pid4_energy.pdf", bbox_inches="tight") + plot_reso( + ygen, + ypred, + ycand, + 4, + "energy", + 0.2, + ax=ax3, + legend_title=sample + "\n", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid4_energy.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid4_energy.pdf", + bbox_inches="tight", + ) plt.tight_layout() plt.close(fig) @@ -464,16 +616,40 @@ def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): plt.close(fig) fig, (ax2) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 5, "eta", 0.2, ax=ax2, legend_title=sample + "\n") + plot_reso( + ygen, + ypred, + ycand, + 5, + "eta", + 0.2, + ax=ax2, + legend_title=sample + "\n", + ) plt.savefig(outpath + "/resolution_plots/res_pid5_eta.png", bbox_inches="tight") plt.savefig(outpath + "/resolution_plots/res_pid5_eta.pdf", bbox_inches="tight") plt.tight_layout() plt.close(fig) fig, (ax3) = plt.subplots(1, 1, figsize=(8, 8)) - plot_reso(ygen, ypred, ycand, 5, "energy", 0.2, ax=ax3, legend_title=sample + "\n") - plt.savefig(outpath + "/resolution_plots/res_pid5_energy.png", bbox_inches="tight") - plt.savefig(outpath + "/resolution_plots/res_pid5_energy.pdf", bbox_inches="tight") + plot_reso( + ygen, + ypred, + ycand, + 5, + "energy", + 0.2, + ax=ax3, + legend_title=sample + "\n", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid5_energy.png", + bbox_inches="tight", + ) + plt.savefig( + outpath + "/resolution_plots/res_pid5_energy.pdf", + bbox_inches="tight", + ) plt.tight_layout() plt.close(fig) diff --git a/mlpf/pyg/evaluate.py b/mlpf/pyg/evaluate.py index 096db25a6..39de10fdb 100644 --- a/mlpf/pyg/evaluate.py +++ b/mlpf/pyg/evaluate.py @@ -91,17 +91,32 @@ def make_predictions(rank, model, file_loader, batch_size, num_classes, PATH): for key, var in vars.items(): var = var[:padded_num_elem_size] var = torch.nn.functional.pad( - var, (0, 0, 0, padded_num_elem_size - var.shape[0]), mode="constant", value=0 + var, + (0, 0, 0, padded_num_elem_size - var.shape[0]), + mode="constant", + value=0, ).unsqueeze(0) vars_padded[key] = var X.append(vars_padded["X"]) Y_pid.append( torch.cat( - [vars_padded["gen_ids_one_hot"], vars_padded["cand_ids_one_hot"], vars_padded["pred_ids_one_hot"]] + [ + vars_padded["gen_ids_one_hot"], + vars_padded["cand_ids_one_hot"], + vars_padded["pred_ids_one_hot"], + ] + ).unsqueeze(0) + ) + Y_p4.append( + torch.cat( + [ + vars_padded["ygen"], + vars_padded["ycand"], + vars_padded["pred_p4"], + ] ).unsqueeze(0) ) - Y_p4.append(torch.cat([vars_padded["ygen"], vars_padded["ycand"], vars_padded["pred_p4"]]).unsqueeze(0)) outfile = f"{PATH}/predictions/pred_batch{ibatch}_{rank}.pt" print(f"saving predictions at {outfile}") @@ -200,9 +215,17 @@ def flatten(arr): t0 = time.time() torch.save(Xs, f"{pred_path}/post_processed_Xs.pt", pickle_protocol=4) torch.save(X_f, f"{pred_path}/post_processed_X_f.pt", pickle_protocol=4) - torch.save(msk_X_f, f"{pred_path}/post_processed_msk_X_f.pt", pickle_protocol=4) + torch.save( + msk_X_f, + f"{pred_path}/post_processed_msk_X_f.pt", + pickle_protocol=4, + ) torch.save(yvals, f"{pred_path}/post_processed_yvals.pt", pickle_protocol=4) - torch.save(yvals_f, f"{pred_path}/post_processed_yvals_f.pt", pickle_protocol=4) + torch.save( + yvals_f, + f"{pred_path}/post_processed_yvals_f.pt", + pickle_protocol=4, + ) print(f"Time taken to save the predictions is: {round(((time.time() - t0) / 60), 2)} min") return Xs, X_f, msk_X_f, yvals, yvals_f @@ -225,8 +248,22 @@ def 
make_plots_cms(pred_path, plot_path, sample): # plot distributions print("plot_dist...") plot_dist(yvals_f, "pt", np.linspace(0, 200, 61), r"$p_T$", plot_path, sample) - plot_dist(yvals_f, "energy", np.linspace(0, 2000, 61), r"$E$", plot_path, sample) - plot_dist(yvals_f, "eta", np.linspace(-6, 6, 61), r"$\eta$", plot_path, sample) + plot_dist( + yvals_f, + "energy", + np.linspace(0, 2000, 61), + r"$E$", + plot_path, + sample, + ) + plot_dist( + yvals_f, + "eta", + np.linspace(-6, 6, 61), + r"$\eta$", + plot_path, + sample, + ) # plot cm print("plot_cm...") @@ -235,7 +272,17 @@ def make_plots_cms(pred_path, plot_path, sample): # plot eff_and_fake_rate print("plot_eff_and_fake_rate...") - plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=1, ivar=4, ielem=1, bins=np.logspace(-1, 3, 41), log=True) + plot_eff_and_fake_rate( + X_f, + yvals_f, + plot_path, + sample, + icls=1, + ivar=4, + ielem=1, + bins=np.logspace(-1, 3, 41), + log=True, + ) plot_eff_and_fake_rate( X_f, yvals_f, @@ -246,9 +293,19 @@ def make_plots_cms(pred_path, plot_path, sample): ielem=1, bins=np.linspace(-4, 4, 41), log=False, - xlabel="PFElement $\eta$", + xlabel=r"PFElement $\eta$", + ) + plot_eff_and_fake_rate( + X_f, + yvals_f, + plot_path, + sample, + icls=2, + ivar=4, + ielem=5, + bins=np.logspace(-1, 3, 41), + log=True, ) - plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=2, ivar=4, ielem=5, bins=np.logspace(-1, 3, 41), log=True) plot_eff_and_fake_rate( X_f, yvals_f, @@ -259,9 +316,19 @@ def make_plots_cms(pred_path, plot_path, sample): ielem=5, bins=np.linspace(-5, 5, 41), log=False, - xlabel="PFElement $\eta$", + xlabel=r"PFElement $\eta$", + ) + plot_eff_and_fake_rate( + X_f, + yvals_f, + plot_path, + sample, + icls=5, + ivar=4, + ielem=4, + bins=np.logspace(-1, 2, 41), + log=True, ) - plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=5, ivar=4, ielem=4, bins=np.logspace(-1, 2, 41), log=True) plot_eff_and_fake_rate( X_f, yvals_f, @@ -272,7 +339,7 @@ def make_plots_cms(pred_path, plot_path, sample): ielem=4, bins=np.linspace(-5, 5, 41), log=False, - xlabel="PFElement $\eta$", + xlabel=r"PFElement $\eta$", ) # distribution_icls diff --git a/mlpf/pyg/model.py b/mlpf/pyg/model.py index 1331737e2..2f0e1d91b 100644 --- a/mlpf/pyg/model.py +++ b/mlpf/pyg/model.py @@ -48,7 +48,15 @@ def __init__( self.conv = nn.ModuleList() for i in range(num_convs): - self.conv.append(GravNetConv_MLPF(embedding_dim, embedding_dim, space_dim, propagate_dim, k)) + self.conv.append( + GravNetConv_MLPF( + embedding_dim, + embedding_dim, + space_dim, + propagate_dim, + k, + ) + ) # self.conv.append(GravNetConv_cmspepr(embedding_dim, embedding_dim, space_dim, propagate_dim, k)) # self.conv.append(EdgeConvBlock(embedding_dim, embedding_dim, k)) @@ -138,7 +146,11 @@ def reset_parameters(self): self.lin_p.reset_parameters() self.lin_out.reset_parameters() - def forward(self, x: Union[Tensor, PairTensor], batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor: + def forward( + self, + x: Union[Tensor, PairTensor], + batch: Union[OptTensor, Optional[PairTensor]] = None, + ) -> Tensor: is_bipartite: bool = True if isinstance(x, Tensor): @@ -175,7 +187,12 @@ def forward(self, x: Union[Tensor, PairTensor], batch: Union[OptTensor, Optional edge_weight = torch.exp(-10.0 * edge_weight) # 10 gives a better spread # message passing - out = self.propagate(edge_index, x=(msg_activations, None), edge_weight=edge_weight, size=(s_l.size(0), s_r.size(0))) + out = self.propagate( + edge_index, + x=(msg_activations, None), 
+ edge_weight=edge_weight, + size=(s_l.size(0), s_r.size(0)), + ) return self.lin_out(out) @@ -183,7 +200,13 @@ def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: return x_j * edge_weight.unsqueeze(1) def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor: - out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum") + out_mean = scatter( + inputs, + index, + dim=self.node_dim, + dim_size=dim_size, + reduce="sum", + ) return out_mean def __repr__(self) -> str: diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index e8a2fba9a..c9ea2a8f3 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -32,14 +32,46 @@ def compute_weights(rank, target_ids, num_classes): @torch.no_grad() -def validation_run(rank, model, train_loader, valid_loader, batch_size, alpha, target_type, num_classes, outpath): +def validation_run( + rank, + model, + train_loader, + valid_loader, + batch_size, + alpha, + target_type, + num_classes, + outpath, +): with torch.no_grad(): optimizer = None - ret = train(rank, model, train_loader, valid_loader, batch_size, optimizer, alpha, target_type, num_classes, outpath) + ret = train( + rank, + model, + train_loader, + valid_loader, + batch_size, + optimizer, + alpha, + target_type, + num_classes, + outpath, + ) return ret -def train(rank, model, train_loader, valid_loader, batch_size, optimizer, alpha, target_type, num_classes, outpath): +def train( + rank, + model, + train_loader, + valid_loader, + batch_size, + optimizer, + alpha, + target_type, + num_classes, + outpath, +): """ A training/validation run over a given epoch that gets called in the training_loop() function. When optimizer is set to None, it freezes the model for a validation_run. @@ -122,7 +154,9 @@ def train(rank, model, train_loader, valid_loader, batch_size, optimizer, alpha, losses_tot = losses_tot + loss_tot.detach() conf_matrix += sklearn.metrics.confusion_matrix( - target_ids.detach().cpu(), pred_ids.detach().cpu(), labels=range(num_classes) + target_ids.detach().cpu(), + pred_ids.detach().cpu(), + labels=range(num_classes), ) # if i == 2: @@ -202,7 +236,16 @@ def training_loop( # training step model.train() losses, conf_matrix_train = train( - rank, model, train_loader, valid_loader, batch_size, optimizer, alpha, target, num_classes, outpath + rank, + model, + train_loader, + valid_loader, + batch_size, + optimizer, + alpha, + target, + num_classes, + outpath, ) losses_clf_train.append(losses["losses_clf"]) @@ -212,7 +255,15 @@ def training_loop( # validation step model.eval() losses, conf_matrix_val = validation_run( - rank, model, train_loader, valid_loader, batch_size, alpha, target, num_classes, outpath + rank, + model, + train_loader, + valid_loader, + batch_size, + alpha, + target, + num_classes, + outpath, ) losses_clf_valid.append(losses["losses_clf"]) @@ -271,8 +322,20 @@ def training_loop( elif data == "cms": target_names = CLASS_NAMES_CMS - plot_confusion_matrix(conf_matrix_train, target_names, epoch + 1, cm_path, f"epoch_{str(epoch)}_cmTrain") - plot_confusion_matrix(conf_matrix_val, target_names, epoch + 1, cm_path, f"epoch_{str(epoch)}_cmValid") + plot_confusion_matrix( + conf_matrix_train, + target_names, + epoch + 1, + cm_path, + f"epoch_{str(epoch)}_cmTrain", + ) + plot_confusion_matrix( + conf_matrix_val, + target_names, + epoch + 1, + cm_path, + f"epoch_{str(epoch)}_cmValid", + ) # make loss plots make_plot_from_lists( diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py index fa84c96cc..a90c42d02 
100644 --- a/mlpf/pyg/utils.py +++ b/mlpf/pyg/utils.py @@ -191,7 +191,14 @@ def make_plot_from_lists(title, xaxis, yaxis, save_as, X, Xlabel, X_save_as, out pkl.dump(var, f) -def define_regions(num_eta_regions=10, num_phi_regions=10, max_eta=5, min_eta=-5, max_phi=1.5, min_phi=-1.5): +def define_regions( + num_eta_regions=10, + num_phi_regions=10, + max_eta=5, + min_eta=-5, + max_phi=1.5, + min_phi=-1.5, +): """ Defines regions in (eta,phi) space to make bins within an event and build graphs within these bins. @@ -259,7 +266,10 @@ def batch_event_into_regions(data, regions): ycand = torch.cat([ycand, data.ycand[in_region_msk]]) ycand_id = torch.cat([ycand_id, data.ycand_id[in_region_msk]]) batch = torch.cat( - [batch, region + torch.zeros([len(data.x[in_region_msk])])] + [ + batch, + region + torch.zeros([len(data.x[in_region_msk])]), + ] ) # assumes events were already fed one at a time (i.e. batch_size=1) data = Batch( diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index a2aca631c..ea574fa21 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -108,17 +108,30 @@ def train_ddp(rank, world_size, args, dataset, model, num_classes, outpath): hyper_train = int(args.n_train / world_size) hyper_valid = int(args.n_valid / world_size) - train_dataset = torch.utils.data.Subset(dataset, np.arange(start=rank * hyper_train, stop=(rank + 1) * hyper_train)) + train_dataset = torch.utils.data.Subset( + dataset, + np.arange(start=rank * hyper_train, stop=(rank + 1) * hyper_train), + ) valid_dataset = torch.utils.data.Subset( - dataset, np.arange(start=args.n_train + rank * hyper_valid, stop=args.n_train + (rank + 1) * hyper_valid) + dataset, + np.arange( + start=args.n_train + rank * hyper_valid, + stop=args.n_train + (rank + 1) * hyper_valid, + ), ) # construct file loaders file_loader_train = make_file_loaders( - world_size, train_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor + world_size, + train_dataset, + num_workers=args.num_workers, + prefetch_factor=args.prefetch_factor, ) file_loader_valid = make_file_loaders( - world_size, valid_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor + world_size, + valid_dataset, + num_workers=args.num_workers, + prefetch_factor=args.prefetch_factor, ) # copy the model to the GPU with id=rank @@ -164,11 +177,17 @@ def inference_ddp(rank, world_size, args, dataset, model, num_classes, PATH): # give each gpu a subset of the data hyper_test = int(args.n_test / world_size) - test_dataset = torch.utils.data.Subset(dataset, np.arange(start=rank * hyper_test, stop=(rank + 1) * hyper_test)) + test_dataset = torch.utils.data.Subset( + dataset, + np.arange(start=rank * hyper_test, stop=(rank + 1) * hyper_test), + ) # construct data loaders file_loader_test = make_file_loaders( - world_size, test_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor + world_size, + test_dataset, + num_workers=args.num_workers, + prefetch_factor=args.prefetch_factor, ) # copy the model to the GPU with id=rank @@ -177,7 +196,14 @@ def inference_ddp(rank, world_size, args, dataset, model, num_classes, PATH): model.eval() ddp_model = DDP(model, device_ids=[rank]) - make_predictions(rank, ddp_model, file_loader_test, args.batch_size, num_classes, PATH) + make_predictions( + rank, + ddp_model, + file_loader_test, + args.batch_size, + num_classes, + PATH, + ) cleanup() @@ -195,14 +221,23 @@ def train(device, world_size, args, dataset, model, num_classes, outpath): device = device.index train_dataset = 
torch.utils.data.Subset(dataset, np.arange(start=0, stop=args.n_train)) - valid_dataset = torch.utils.data.Subset(dataset, np.arange(start=args.n_train, stop=args.n_train + args.n_valid)) + valid_dataset = torch.utils.data.Subset( + dataset, + np.arange(start=args.n_train, stop=args.n_train + args.n_valid), + ) # construct file loaders file_loader_train = make_file_loaders( - world_size, train_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor + world_size, + train_dataset, + num_workers=args.num_workers, + prefetch_factor=args.prefetch_factor, ) file_loader_valid = make_file_loaders( - world_size, valid_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor + world_size, + valid_dataset, + num_workers=args.num_workers, + prefetch_factor=args.prefetch_factor, ) # move the model to the device (cuda or cpu) @@ -244,7 +279,10 @@ def inference(device, world_size, args, dataset, model, num_classes, PATH): # construct data loaders file_loader_test = make_file_loaders( - world_size, test_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor + world_size, + test_dataset, + num_workers=args.num_workers, + prefetch_factor=args.prefetch_factor, ) # copy the model to the GPU with id=rank @@ -309,9 +347,25 @@ def inference(device, world_size, args, dataset, model, num_classes, PATH): dataset = PFGraphDataset(args.dataset, args.data) if world_size >= 2: - run_demo(train_ddp, world_size, args, dataset, model, num_classes, outpath) + run_demo( + train_ddp, + world_size, + args, + dataset, + model, + num_classes, + outpath, + ) else: - train(device, world_size, args, dataset, model, num_classes, outpath) + train( + device, + world_size, + args, + dataset, + model, + num_classes, + outpath, + ) # load the best epoch state state_dict = torch.load(outpath + "/best_epoch_weights.pth", map_location=device) @@ -341,9 +395,25 @@ def inference(device, world_size, args, dataset, model, num_classes, PATH): dataset_test = PFGraphDataset(args.dataset_test, args.data) if world_size >= 2: - run_demo(inference_ddp, world_size, args, dataset_test, model, num_classes, PATH) + run_demo( + inference_ddp, + world_size, + args, + dataset_test, + model, + num_classes, + PATH, + ) else: - inference(device, world_size, args, dataset_test, model, num_classes, PATH) + inference( + device, + world_size, + args, + dataset_test, + model, + num_classes, + PATH, + ) postprocess_predictions(pred_path) diff --git a/mlpf/raytune/utils.py b/mlpf/raytune/utils.py index d9ab6298b..b71978676 100644 --- a/mlpf/raytune/utils.py +++ b/mlpf/raytune/utils.py @@ -29,7 +29,11 @@ def get_raytune_search_alg(raytune_cfg, seeds=False): seed = 1234 else: seed = None - return TuneBOHB(metric=raytune_cfg["default_metric"], mode=raytune_cfg["default_mode"], seed=seed) + return TuneBOHB( + metric=raytune_cfg["default_metric"], + mode=raytune_cfg["default_mode"], + seed=seed, + ) # requires pip install bayesian-optimization if raytune_cfg["search_alg"] == "bayes": @@ -61,7 +65,10 @@ def get_raytune_search_alg(raytune_cfg, seeds=False): import nevergrad as ng return NevergradSearch( - optimizer=ng.optimizers.BayesOptim(pca=False, init_budget=raytune_cfg["nevergrad"]["n_random_steps"]), + optimizer=ng.optimizers.BayesOptim( + pca=False, + init_budget=raytune_cfg["nevergrad"]["n_random_steps"], + ), metric=raytune_cfg["default_metric"], mode=raytune_cfg["default_mode"], ) diff --git a/mlpf/tfmodel/analysis.py b/mlpf/tfmodel/analysis.py index 7ca3e8298..f42e67d4f 100644 --- 
a/mlpf/tfmodel/analysis.py +++ b/mlpf/tfmodel/analysis.py @@ -13,11 +13,22 @@ def main(): @main.command() @click.help_option("-h", "--help") -@click.option("-p", "--path", help="path to json file or dir containing json files", type=click.Path()) +@click.option( + "-p", + "--path", + help="path to json file or dir containing json files", + type=click.Path(), +) @click.option("-y", "--ylabel", default=None, help="Y-axis label", type=str) @click.option("-x", "--xlabel", default="Step", help="X-axis label", type=str) @click.option("-t", "--title", default=None, help="X-axis label", type=str) -@click.option("-s", "--save_dir", default=None, help="X-axis label", type=click.Path()) +@click.option( + "-s", + "--save_dir", + default=None, + help="X-axis label", + type=click.Path(), +) def plot_cometml_json(path, ylabel, xlabel, title=None, save_dir=None): path = Path(path) @@ -50,9 +61,20 @@ def plot_cometml_json(path, ylabel, xlabel, title=None, save_dir=None): ) ) - pp = plt.plot(metric["x"], metric["y"], label=metric["name"], linestyle="-") + pp = plt.plot( + metric["x"], + metric["y"], + label=metric["name"], + linestyle="-", + ) color = pp[0].get_color() - plt.plot(val_metric["x"], val_metric["y"], label=val_metric["name"], linestyle="--", color=color) + plt.plot( + val_metric["x"], + val_metric["y"], + label=val_metric["name"], + linestyle="--", + color=color, + ) plt.legend() plt.xlabel(xlabel) diff --git a/mlpf/tfmodel/callbacks.py b/mlpf/tfmodel/callbacks.py index cdbf4608b..f290460b0 100644 --- a/mlpf/tfmodel/callbacks.py +++ b/mlpf/tfmodel/callbacks.py @@ -32,7 +32,10 @@ def _collect_learning_rate(self, logs): if hasattr(opt, "lr"): lr_schedule = getattr(opt, "lr", None) - if isinstance(lr_schedule, tf.keras.optimizers.schedules.LearningRateSchedule): + if isinstance( + lr_schedule, + tf.keras.optimizers.schedules.LearningRateSchedule, + ): logs["learning_rate"] = np.float64(tf.keras.backend.get_value(lr_schedule(opt.iterations))) else: logs.update({"learning_rate": np.float64(tf.keras.backend.eval(opt.lr))}) diff --git a/mlpf/tfmodel/datasets/BaseDatasetFactory.py b/mlpf/tfmodel/datasets/BaseDatasetFactory.py index 48c6cdfd2..9cd37f2f1 100644 --- a/mlpf/tfmodel/datasets/BaseDatasetFactory.py +++ b/mlpf/tfmodel/datasets/BaseDatasetFactory.py @@ -113,7 +113,10 @@ def mlpf_dataset_from_config(dataset_name, full_config, split, max_events=None): def get_map_to_supervised(config): target_particles = config["dataset"]["target_particles"] num_output_classes = config["dataset"]["num_output_classes"] - assert target_particles in ["gen", "cand"], "Target particles has to be 'cand' or 'gen'." + assert target_particles in [ + "gen", + "cand", + ], "Target particles has to be 'cand' or 'gen'." 
def func(data_item): X = data_item["X"] @@ -171,7 +174,12 @@ def interleave_datasets(joint_dataset_name, split, datasets): [ds.tensorflow_dataset for ds in datasets], choice_dataset ) - ds = MLPFDataset(joint_dataset_name, split, interleaved_tensorflow_dataset, sum([ds.num_samples for ds in datasets])) + ds = MLPFDataset( + joint_dataset_name, + split, + interleaved_tensorflow_dataset, + sum([ds.num_samples for ds in datasets]), + ) ds._num_steps = num_steps_total logging.info( "Interleaved joint dataset {}:{} with {} steps, {} samples".format(ds.name, ds.split, ds.num_steps(), ds.num_samples) @@ -198,7 +206,10 @@ def num_steps(self): # In case dynamic batching was applied, we don't know the number of steps for the dataset # compute it using https://stackoverflow.com/a/61019377 self._num_steps = ( - self.tensorflow_dataset.map(lambda *args: 1, num_parallel_calls=tf.data.AUTOTUNE) + self.tensorflow_dataset.map( + lambda *args: 1, + num_parallel_calls=tf.data.AUTOTUNE, + ) .reduce(tf.constant(0), lambda x, _: x + 1) .numpy() ) diff --git a/mlpf/tfmodel/hypertuning.py b/mlpf/tfmodel/hypertuning.py index b7b1c9370..f0e2fdf1c 100644 --- a/mlpf/tfmodel/hypertuning.py +++ b/mlpf/tfmodel/hypertuning.py @@ -31,7 +31,13 @@ def model_builder(hp): config["parameters"]["output_decoding"]["mask_reg_cls0"] = hp.Choice("output_mask_reg_cls0", values=[True, False]) model = make_model(config, dtype="float32") - model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"])) + model.build( + ( + 1, + config["dataset"]["padded_num_elem_size"], + config["dataset"]["num_input_features"], + ) + ) opt = get_optimizer(config, lr_schedule) diff --git a/mlpf/tfmodel/kernel_attention.py b/mlpf/tfmodel/kernel_attention.py index 7bdab6b24..61835fd80 100644 --- a/mlpf/tfmodel/kernel_attention.py +++ b/mlpf/tfmodel/kernel_attention.py @@ -76,7 +76,12 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None): else: raise ValueError('Illegal padding value; must be one of "left"' '"right" or None.') paddings = tf.concat( - [tf.zeros([axis, 2], dtype=tf.int32), axis_paddings, tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0 + [ + tf.zeros([axis, 2], dtype=tf.int32), + axis_paddings, + tf.zeros([rank - axis - 1, 2], dtype=tf.int32), + ], + axis=0, ) return tf.pad(tensor, paddings) @@ -94,11 +99,21 @@ def split_tensor_into_chunks(tensor, axis, chunk_length): """ shape = tf.shape(tensor) num_chunks = shape[axis] // chunk_length - new_shape = tf.concat([shape[:axis], [num_chunks, chunk_length], shape[(axis + 1) :]], axis=0) + new_shape = tf.concat( + [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1) :]], + axis=0, + ) return tf.reshape(tensor, new_shape) -def causal_windowed_performer_attention(query_matrix, key_matrix, value_matrix, chunk_length, window_length, padding=None): +def causal_windowed_performer_attention( + query_matrix, + key_matrix, + value_matrix, + chunk_length, + window_length, + padding=None, +): """Applies windowed causal kernel attention with query, key, value tensors. 
We partition the T-length input sequence into N chunks, each of chunk_length @@ -159,7 +174,13 @@ def causal_windowed_performer_attention(query_matrix, key_matrix, value_matrix, kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix, chunked_value_matrix) kp_v_cumsum = tf.cumsum(kp_v, axis=-4) - kp_v_winsum = kp_v_cumsum - tf.pad(kp_v_cumsum, [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length] + kp_v_winsum = ( + kp_v_cumsum + - tf.pad( + kp_v_cumsum, + [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]], + )[:, :-window_length] + ) numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum) k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3) @@ -350,7 +371,11 @@ def expplus( # pylint: disable=g-long-lambda _TRANSFORM_MAP = { - "elu": functools.partial(_generalized_kernel, f=lambda x: tf.keras.activations.elu(x) + 1, h=lambda x: 1), + "elu": functools.partial( + _generalized_kernel, + f=lambda x: tf.keras.activations.elu(x) + 1, + h=lambda x: 1, + ), "relu": functools.partial( _generalized_kernel, # Improve numerical stability and avoid NaNs in some cases by adding @@ -483,7 +508,9 @@ def __init__( self._projection_matrix = None if num_random_features > 0: self._projection_matrix = create_projection_matrix( - self._num_random_features, self._key_dim, tf.constant([self._seed, self._seed + 1]) + self._num_random_features, + self._key_dim, + tf.constant([self._seed, self._seed + 1]), ) self.use_causal_windowed = use_causal_windowed self.causal_chunk_length = causal_chunk_length @@ -581,7 +608,12 @@ def _compute_attention( else: kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value) denominator = 1.0 / ( - tf.einsum("BTNH,BNH->BTN", query_prime, tf.reduce_sum(key_prime, axis=1)) + _NUMERIC_STABLER + tf.einsum( + "BTNH,BNH->BTN", + query_prime, + tf.reduce_sum(key_prime, axis=1), + ) + + _NUMERIC_STABLER ) attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv, denominator) return attention_output @@ -599,7 +631,9 @@ def _build_from_signature(self, query, value, key=None): bias_constraint=self._bias_constraint, ) self._output_dense_softmax = self._make_output_dense( - self._query_shape.rank - 1, common_kwargs, name="attention_output_softmax" + self._query_shape.rank - 1, + common_kwargs, + name="attention_output_softmax", ) self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout) @@ -638,7 +672,13 @@ def call(self, query, value, key=None, attention_mask=None, training=False): if self._begin_kernel > 0: attention_output_softmax = self._compute_attention( - query[:, : self._begin_kernel], key, value, "identity", True, attention_mask, training + query[:, : self._begin_kernel], + key, + value, + "identity", + True, + attention_mask, + training, ) attention_output_softmax = self._dropout_softmax(attention_output_softmax) attention_output_softmax = self._output_dense_softmax(attention_output_softmax) @@ -657,7 +697,13 @@ def call(self, query, value, key=None, attention_mask=None, training=False): attention_output = tf.concat([attention_output_softmax, attention_output_kernel], axis=1) else: attention_output = self._compute_attention( - query, key, value, self._feature_transform, self._is_short_seq, attention_mask, training + query, + key, + value, + self._feature_transform, + self._is_short_seq, + attention_mask, + training, ) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
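Aside: the einsum chain in _compute_attention above ("BSNH,BSND->BNDH", then "BTNH,BNH->BTN", then "BTNH,BNDH,BTN->BTND") is the batched, multi-head form of the linear kernel-attention identity. Below is a minimal single-head NumPy sketch of the same math, using the elu(x)+1 feature map listed in _TRANSFORM_MAP; the helper name, toy shapes and the 1e-6 stabiliser (standing in for _NUMERIC_STABLER) are illustrative assumptions, not part of the repository.

import numpy as np

def kernel_attention(q, k, v):
    # q: (T, H) queries, k: (S, H) keys, v: (S, D) values
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1, keeps features non-negative
    q_prime, k_prime = phi(q), phi(k)
    kv = k_prime.T @ v                                    # (H, D): sum_s phi(k_s) v_s^T
    denom = q_prime @ k_prime.sum(axis=0) + 1e-6          # (T,): sum_s phi(q_t) . phi(k_s), stabilised
    return (q_prime @ kv) / denom[:, None]                # (T, D), linear in sequence length

out = kernel_attention(np.random.randn(6, 4), np.random.randn(9, 4), np.random.randn(9, 8))

The windowed causal variant above applies the same identity per chunk, with cumulative sums over chunks replacing the global sums so each position only attends to its own and earlier windows.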
diff --git a/mlpf/tfmodel/lr_finder.py b/mlpf/tfmodel/lr_finder.py index f49814cf1..08857837c 100644 --- a/mlpf/tfmodel/lr_finder.py +++ b/mlpf/tfmodel/lr_finder.py @@ -15,7 +15,13 @@ class LRFinder(Callback): paper: https://arxiv.org/pdf/1803.09820.pdf. """ - def __init__(self, start_lr: float = 1e-7, end_lr: float = 1e-2, max_steps: int = 200, smoothing=0.9): + def __init__( + self, + start_lr: float = 1e-7, + end_lr: float = 1e-2, + max_steps: int = 200, + smoothing=0.9, + ): super(LRFinder, self).__init__() self.start_lr, self.end_lr = start_lr, end_lr self.max_steps = max_steps diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py index 59b79465e..b64b028b1 100644 --- a/mlpf/tfmodel/model.py +++ b/mlpf/tfmodel/model.py @@ -74,7 +74,13 @@ def pairwise_learnable_dist(A, B, ffn, training=False): shp = tf.shape(A) # stack node feature vectors of src[i], dst[j] into a matrix res[i,j] = (src[i], dst[j]) - mg = tf.meshgrid(tf.range(shp[0]), tf.range(shp[1]), tf.range(shp[2]), tf.range(shp[2]), indexing="ij") + mg = tf.meshgrid( + tf.range(shp[0]), + tf.range(shp[1]), + tf.range(shp[2]), + tf.range(shp[2]), + indexing="ij", + ) inds1 = tf.stack([mg[0], mg[1], mg[2]], axis=-1) inds2 = tf.stack([mg[0], mg[1], mg[3]], axis=-1) res = tf.concat([tf.gather_nd(A, inds1), tf.gather_nd(B, inds2)], axis=-1) # (batch, bin, elem, elem, feat) @@ -106,7 +112,10 @@ def reverse_lsh(bins_split, points_binned_enc, small_graph_opt=False): tf.debugging.assert_shapes( [ (bins_split, ("n_batch", "n_bins", "n_points_bin")), - (points_binned_enc, ("n_batch", "n_bins", "n_points_bin", "n_features")), + ( + points_binned_enc, + ("n_batch", "n_bins", "n_points_bin", "n_features"), + ), ] ) @@ -125,7 +134,11 @@ def multiple_bins(): batch_inds = tf.reshape(tf.repeat(tf.range(batch_dim), n_points), (batch_dim, n_points)) bins_split_flat_batch = tf.stack([batch_inds, bins_split_flat], axis=-1) - ret = tf.scatter_nd(bins_split_flat_batch, points_binned_enc_flat, shape=(batch_dim, n_points, n_features)) + ret = tf.scatter_nd( + bins_split_flat_batch, + points_binned_enc_flat, + shape=(batch_dim, n_points, n_features), + ) return ret # in case of n_bins==1, we can just remove the bin dimension @@ -158,7 +171,10 @@ def __init__(self, num_input_classes): def call(self, X): # X[:, :, 0] - categorical index of the element type - Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=X.dtype) + Xid = tf.cast( + tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), + dtype=X.dtype, + ) # X[:, :, 1:] - all the other non-categorical features Xprop = X[:, :, 1:] @@ -178,7 +194,10 @@ def __init__(self, num_input_classes): def call(self, X): # X[:, :, 0] - categorical index of the element type - Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=X.dtype) + Xid = tf.cast( + tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), + dtype=X.dtype, + ) # X[:, :, 1:] - all the other non-categorical features Xprop = X[:, :, 1:] @@ -218,7 +237,10 @@ def __init__(self, num_input_classes): def call(self, X): # X[:, :, 0] - categorical index of the element type - Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=X.dtype) + Xid = tf.cast( + tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), + dtype=X.dtype, + ) Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1) Xe = tf.expand_dims(tf.math.log(X[:, :, 4] + 1.0), axis=-1) @@ -333,9 +355,15 @@ def call(self, inputs): 
tf.debugging.assert_shapes( [ (x, ("n_batch", "n_bins", "n_points_bin", "num_features")), - (adj, ("n_batch", "n_bins", "n_points_bin", "n_points_bin")), + ( + adj, + ("n_batch", "n_bins", "n_points_bin", "n_points_bin"), + ), (msk, ("n_batch", "n_bins", "n_points_bin", 1)), - (out, ("n_batch", "n_bins", "n_points_bin", self.output_dim)), + ( + out, + ("n_batch", "n_bins", "n_points_bin", self.output_dim), + ), ] ) # tf.print("GHConvDense.call:out", out.shape) @@ -373,7 +401,14 @@ def call(self, inputs): def point_wise_feed_forward_network( - d_model, dff, name, num_layers=1, activation="elu", dtype=tf.dtypes.float32, dim_decrease=False, dropout=0.0 + d_model, + dff, + name, + num_layers=1, + activation="elu", + dtype=tf.dtypes.float32, + dim_decrease=False, + dropout=0.0, ): if regularizer_weight > 0: @@ -403,14 +438,23 @@ def point_wise_feed_forward_network( if dim_decrease: dff = dff // 2 - layers.append(tf.keras.layers.Dense(d_model, dtype=dtype, name="{}_dense_{}".format(name, ilayer + 1))) + layers.append( + tf.keras.layers.Dense( + d_model, + dtype=dtype, + name="{}_dense_{}".format(name, ilayer + 1), + ) + ) return tf.keras.Sequential(layers, name=name) def get_message_layer(config_dict, name): config_dict = config_dict.copy() class_name = config_dict.pop("type") - classes = {"NodeMessageLearnable": NodeMessageLearnable, "GHConvDense": GHConvDense} + classes = { + "NodeMessageLearnable": NodeMessageLearnable, + "GHConvDense": GHConvDense, + } conv_cls = classes[class_name] return conv_cls(name=name, **config_dict) @@ -480,7 +524,10 @@ def call(self, x_msg_binned, msk, training=False): node_proj = self.activation(self.ffn_node(x_msg_binned)) - dm = tf.cast(pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training), x_msg_binned.dtype) + dm = tf.cast( + pairwise_learnable_dist(node_proj, node_proj, self.pair_kernel, training=training), + x_msg_binned.dtype, + ) return dm @@ -488,7 +535,10 @@ def build_kernel_from_conf(kernel_dict, name): kernel_dict = kernel_dict.copy() cls_type = kernel_dict.pop("type") - clss = {"NodePairGaussianKernel": NodePairGaussianKernel, "NodePairTrainableKernel": NodePairTrainableKernel} + clss = { + "NodePairGaussianKernel": NodePairGaussianKernel, + "NodePairTrainableKernel": NodePairTrainableKernel, + } return clss[cls_type](name=name, **kernel_dict) @@ -556,20 +606,30 @@ def dobin(): n_bins = tf.math.floordiv(n_points, self.bin_size) tf.debugging.assert_greater( - n_bins, 0, "number of points (dim 1) must be greater than bin_size={}".format(self.bin_size) + n_bins, + 0, + "number of points (dim 1) must be greater than bin_size={}".format(self.bin_size), ) tf.debugging.assert_equal( tf.math.floormod(n_points, self.bin_size), 0, "number of points (dim 1) must be an integer multiple of bin_size={}".format(self.bin_size), ) - mul = tf.linalg.matmul(x_msg, self.codebook_random_rotations[:, : tf.math.maximum(1, n_bins // 2)]) + mul = tf.linalg.matmul( + x_msg, + self.codebook_random_rotations[:, : tf.math.maximum(1, n_bins // 2)], + ) cmul = tf.concat([mul, -mul], axis=-1) bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk) x_msg_binned = tf.gather(x_msg, bins_split, batch_dims=1) x_features_binned = tf.gather(x_node, bins_split, batch_dims=1) msk_f_binned = tf.gather(msk_f, bins_split, batch_dims=1) - return bins_split, x_msg_binned, x_features_binned, msk_f_binned + return ( + bins_split, + x_msg_binned, + x_features_binned, + msk_f_binned, + ) # if we only have one bin, just add a new dimension # (n_batch, 
n_points, n_features) -> (n_batch, 1, n_points, n_features) @@ -579,13 +639,28 @@ def nobin(): msk_f_binned = tf.expand_dims(msk_f, axis=1) shp = tf.shape(x_msg_binned) bins_split = tf.zeros([shp[0], shp[1], shp[2]], dtype=tf.int32) - return bins_split, x_msg_binned, x_features_binned, msk_f_binned + return ( + bins_split, + x_msg_binned, + x_features_binned, + msk_f_binned, + ) # put each input item into a bin defined by the argmax output across the LSH embedding if self.small_graph_opt: - bins_split, x_msg_binned, x_features_binned, msk_f_binned = tf.cond(n_bins > 1, dobin, nobin) + ( + bins_split, + x_msg_binned, + x_features_binned, + msk_f_binned, + ) = tf.cond(n_bins > 1, dobin, nobin) else: - bins_split, x_msg_binned, x_features_binned, msk_f_binned = dobin() + ( + bins_split, + x_msg_binned, + x_features_binned, + msk_f_binned, + ) = dobin() # Run the node-to-node kernel (distance computation / graph building / attention) dm = self.kernel(x_msg_binned, msk_f_binned, training=training) @@ -601,10 +676,35 @@ def nobin(): dm = tf.math.multiply(dm, msk_col) tf.debugging.assert_shapes( [ - (x_msg_binned, ("n_batch", "n_bins", "n_points_bin", "n_msg_features")), - (x_features_binned, ("n_batch", "n_bins", "n_points_bin", "n_node_features")), + ( + x_msg_binned, + ( + "n_batch", + "n_bins", + "n_points_bin", + "n_msg_features", + ), + ), + ( + x_features_binned, + ( + "n_batch", + "n_bins", + "n_points_bin", + "n_node_features", + ), + ), (msk_f_binned, ("n_batch", "n_bins", "n_points_bin", 1)), - (dm, ("n_batch", "n_bins", "n_points_bin", "n_points_bin", 1)), + ( + dm, + ( + "n_batch", + "n_bins", + "n_points_bin", + "n_points_bin", + 1, + ), + ), ] ) @@ -804,7 +904,13 @@ def call(self, args, training=False): orig_energy = p if self.regression_use_classification: - X_encoded = tf.concat([X_encoded, tf.cast(tf.stop_gradient(out_id_logits), in_dtype)], axis=-1) + X_encoded = tf.concat( + [ + X_encoded, + tf.cast(tf.stop_gradient(out_id_logits), in_dtype), + ], + axis=-1, + ) pred_eta_corr = self.ffn_eta(X_encoded, training=training) pred_eta_corr = pred_eta_corr * msk_input_outtype @@ -820,7 +926,13 @@ def call(self, args, training=False): X_encoded_energy = tf.concat([X_encoded, X_encoded_energy], axis=-1) if self.regression_use_classification: - X_encoded_energy = tf.concat([X_encoded_energy, tf.cast(tf.stop_gradient(out_id_logits), in_dtype)], axis=-1) + X_encoded_energy = tf.concat( + [ + X_encoded_energy, + tf.cast(tf.stop_gradient(out_id_logits), in_dtype), + ], + axis=-1, + ) pred_energy_corr = self.ffn_energy(X_encoded_energy, training=training) pred_energy_corr = pred_energy_corr * msk_input_outtype @@ -910,7 +1022,9 @@ def __init__(self, *args, **kwargs): if self.do_layernorm: self.layernorm1 = tf.keras.layers.LayerNormalization( - axis=-1, epsilon=1e-6, name=kwargs.get("name") + "_layernorm1" + axis=-1, + epsilon=1e-6, + name=kwargs.get("name") + "_layernorm1", ) # self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01) @@ -932,7 +1046,10 @@ def __init__(self, *args, **kwargs): ) self.message_passing_layers = [ - get_message_layer(self.node_message, "{}_msg_{}".format(kwargs.get("name"), iconv)) + get_message_layer( + self.node_message, + "{}_msg_{}".format(kwargs.get("name"), iconv), + ) for iconv in range(self.num_node_messages) ] self.dropout_layer = None @@ -959,8 +1076,25 @@ def call(self, x, msk, training=False): tf.debugging.assert_shapes( [ (bins_split, ("n_batch", "n_bins", "n_points_bin")), - (x, ("n_batch", "n_bins", "n_points_bin", "n_node_features")), - (dm, 
("n_batch", "n_bins", "n_points_bin", "n_points_bin", 1)), + ( + x, + ( + "n_batch", + "n_bins", + "n_points_bin", + "n_node_features", + ), + ), + ( + dm, + ( + "n_batch", + "n_bins", + "n_points_bin", + "n_points_bin", + 1, + ), + ), (msk_f, ("n_batch", "n_bins", "n_points_bin", 1)), ] ) @@ -971,7 +1105,10 @@ def call(self, x, msk, training=False): tf.debugging.assert_shapes( [ (x, ("n_batch", "n_bins", "n_points_bin", "feat_in")), - (x_out, ("n_batch", "n_bins", "n_points_bin", "feat_out")), + ( + x_out, + ("n_batch", "n_bins", "n_points_bin", "feat_out"), + ), ] ) x = x_out @@ -1068,7 +1205,11 @@ def call(self, inputs, training=False): n_points = shp[1] bins_to_pad_to = -tf.math.floordiv(-n_points, self.bin_size) - pad_size = [[0, 0], [0, bins_to_pad_to * self.bin_size - n_points], [0, 0]] + pad_size = [ + [0, 0], + [0, bins_to_pad_to * self.bin_size - n_points], + [0, 0], + ] if self.small_graph_opt: X = tf.cond(bins_to_pad_to > 1, lambda: tf.pad(X, pad_size), lambda: X) @@ -1141,7 +1282,12 @@ def call(self, inputs, training=False): debugging_data["dec_output_reg"] = dec_output_reg ret = self.output_dec( - [X[:, :n_points], dec_output_id[:, :n_points], dec_output_reg[:, :n_points], msk_input[:, :n_points]], + [ + X[:, :n_points], + dec_output_id[:, :n_points], + dec_output_reg[:, :n_points], + msk_input[:, :n_points], + ], training=training, ) @@ -1153,7 +1299,16 @@ def call(self, inputs, training=False): return ret else: return tf.concat( - [ret["cls"], ret["charge"], ret["pt"], ret["eta"], ret["sin_phi"], ret["cos_phi"], ret["energy"]], axis=-1 + [ + ret["cls"], + ret["charge"], + ret["pt"], + ret["eta"], + ret["sin_phi"], + ret["cos_phi"], + ret["energy"], + ], + axis=-1, ) def set_trainable_named(self, layer_names): @@ -1190,7 +1345,11 @@ def __init__(self, *args, **kwargs): ) SEED_KERNELATTENTION += 1 self.ffn = point_wise_feed_forward_network( - self.key_dim, self.key_dim, kwargs.get("name") + "_ffn", num_layers=1, activation="elu" + self.key_dim, + self.key_dim, + kwargs.get("name") + "_ffn", + num_layers=1, + activation="elu", ) self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0") self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1") @@ -1200,7 +1359,16 @@ def call(self, args, training=False): Q, X, mask = args msk_input = tf.expand_dims(tf.cast(mask, X.dtype), -1) X = self.norm1(X, training=training) - attn_output = self.attn(query=Q, value=X, key=X, training=training, attention_mask=mask) * msk_input + attn_output = ( + self.attn( + query=Q, + value=X, + key=X, + training=training, + attention_mask=mask, + ) + * msk_input + ) out1 = self.norm2(X + attn_output, training=training) out2 = self.ffn(out1, training=training) return out2 @@ -1265,7 +1433,13 @@ def __init__( self.key_dim = hiddem_dim - self.ffn = point_wise_feed_forward_network(self.key_dim, self.key_dim, "ffn", num_layers=1, activation="elu") + self.ffn = point_wise_feed_forward_network( + self.key_dim, + self.key_dim, + "ffn", + num_layers=1, + activation="elu", + ) self.encoders = [] for i in range(num_layers_encoder): @@ -1351,7 +1525,16 @@ def call(self, inputs, training=False): return ret else: return tf.concat( - [ret["cls"], ret["charge"], ret["pt"], ret["eta"], ret["sin_phi"], ret["cos_phi"], ret["energy"]], axis=-1 + [ + ret["cls"], + ret["charge"], + ret["pt"], + ret["eta"], + ret["sin_phi"], + ret["cos_phi"], + ret["energy"], + ], + axis=-1, ) # def train_step(self, data): diff --git a/mlpf/tfmodel/model_setup.py 
b/mlpf/tfmodel/model_setup.py index 58d0ff6ec..f18438746 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -49,7 +49,15 @@ def on_epoch_end(self, epoch, logs=None): class CustomCallback(tf.keras.callbacks.Callback): - def __init__(self, outpath, dataset, config, plot_freq=1, horovod_enabled=False, comet_experiment=None): + def __init__( + self, + outpath, + dataset, + config, + plot_freq=1, + horovod_enabled=False, + comet_experiment=None, + ): super(CustomCallback, self).__init__() self.plot_freq = plot_freq self.dataset = dataset @@ -94,9 +102,15 @@ def epoch_end(self, epoch, logs, comet_experiment=None): plot_met_and_ratio(met_data, epoch, cp_dir, comet_experiment) jet_distances = compute_distances( - yvals["jet_gen_to_pred_genpt"], yvals["jet_gen_to_pred_predpt"], yvals["jet_ratio_pred"] + yvals["jet_gen_to_pred_genpt"], + yvals["jet_gen_to_pred_predpt"], + yvals["jet_ratio_pred"], + ) + met_distances = compute_distances( + met_data["gen_met"], + met_data["pred_met"], + met_data["ratio_pred"], ) - met_distances = compute_distances(met_data["gen_met"], met_data["pred_met"], met_data["ratio_pred"]) N_jets = len(awkward.flatten(yvals["jets_gen_pt"])) N_jets_matched_pred = len(yvals["jet_gen_to_pred_genpt"]) @@ -397,7 +411,15 @@ def match_two_jet_collections(jets_coll, name1, name2, jet_match_dr): # Given a model, evaluates it on each batch of the validation dataset # For each batch, save the inputs, the generator-level target, the candidate-level target, and the prediction -def eval_model(model, dataset, config, outdir, jet_ptcut=5.0, jet_match_dr=0.1, verbose=False): +def eval_model( + model, + dataset, + config, + outdir, + jet_ptcut=5.0, + jet_match_dr=0.1, + verbose=False, +): ibatch = 0 @@ -477,7 +499,10 @@ def eval_model(model, dataset, config, outdir, jet_ptcut=5.0, jet_match_dr=0.1, jets_coll[typ] = cluster.inclusive_jets(min_pt=jet_ptcut) if verbose: - print("jets {}".format(typ), awkward.to_numpy(awkward.count(jets_coll[typ].px, axis=1))) + print( + "jets {}".format(typ), + awkward.to_numpy(awkward.count(jets_coll[typ].px, axis=1)), + ) # DeltaR match between genjets and MLPF jets gen_to_pred = match_two_jet_collections(jets_coll, "gen", "pred", jet_match_dr) @@ -491,7 +516,14 @@ def eval_model(model, dataset, config, outdir, jet_ptcut=5.0, jet_match_dr=0.1, print("saving to {}".format(outfile)) awkward.to_parquet( - awkward.Array({"inputs": X, "particles": awkvals, "jets": jets_coll, "matched_jets": matched_jets}), + awkward.Array( + { + "inputs": X, + "particles": awkvals, + "jets": jets_coll, + "matched_jets": matched_jets, + } + ), outfile, ) @@ -505,7 +537,16 @@ def freeze_model(model, config, outdir): def model_output(ret): return tf.concat( - [ret["cls"], ret["charge"], ret["pt"], ret["eta"], ret["sin_phi"], ret["cos_phi"], ret["energy"]], axis=-1 + [ + ret["cls"], + ret["charge"], + ret["pt"], + ret["eta"], + ret["sin_phi"], + ret["cos_phi"], + ret["energy"], + ], + axis=-1, ) full_model = tf.function(lambda x: model_output(model(x, training=False))) diff --git a/mlpf/tfmodel/onecycle_scheduler.py b/mlpf/tfmodel/onecycle_scheduler.py index 3d0dd28d2..6b3ac6e9d 100644 --- a/mlpf/tfmodel/onecycle_scheduler.py +++ b/mlpf/tfmodel/onecycle_scheduler.py @@ -68,7 +68,10 @@ def __init__( self.final_div = final_div self.name = name - phases = [CosineAnnealer(lr_min, lr_max, phase_1_steps), CosineAnnealer(lr_max, final_lr, phase_2_steps)] + phases = [ + CosineAnnealer(lr_min, lr_max, phase_1_steps), + CosineAnnealer(lr_max, final_lr, phase_2_steps), + ] 
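# Illustrative aside (an assumption-labelled sketch, not part of the patch): the two
# phases built above give the one-cycle shape -- lr_min -> lr_max over the warmup
# steps, then lr_max -> final_lr for the remainder; the momentum scheduler in the
# next hunk mirrors the same two phases in reverse. CosineAnnealer is assumed to
# interpolate with the usual half-cosine; a minimal stand-in under that assumption:
import math

def cosine_anneal(start, end, step, total_steps):
    # half-cosine ramp: returns `start` at step 0 and `end` at step == total_steps
    cos = (1.0 + math.cos(math.pi * step / total_steps)) / 2.0
    return end + (start - end) * cos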
step = 0 phase = 0 @@ -116,7 +119,10 @@ def __init__(self, steps, mom_min=0.85, mom_max=0.95, warmup_ratio=0.3): self.phase = 0 self.step = 0 - self.phases = [CosineAnnealer(mom_max, mom_min, phase_1_steps), CosineAnnealer(mom_min, mom_max, phase_2_steps)] + self.phases = [ + CosineAnnealer(mom_max, mom_min, phase_1_steps), + CosineAnnealer(mom_min, mom_max, phase_2_steps), + ] def _get_opt(self): opt = self.model.optimizer diff --git a/mlpf/tfmodel/timing.py b/mlpf/tfmodel/timing.py index 481e1da31..52492f40f 100644 --- a/mlpf/tfmodel/timing.py +++ b/mlpf/tfmodel/timing.py @@ -47,7 +47,10 @@ print( "Nelem={} mean_time={:.2f} ms stddev_time={:.2f} ms mem_used={:.0f} MB".format( - num_elems, 1000.0 * np.mean(times), 1000.0 * np.std(times), np.max(mem_used) + num_elems, + 1000.0 * np.mean(times), + 1000.0 * np.std(times), + np.max(mem_used), ) ) time.sleep(5) diff --git a/mlpf/tfmodel/utils.py b/mlpf/tfmodel/utils.py index fb83ab8f7..996361d2f 100644 --- a/mlpf/tfmodel/utils.py +++ b/mlpf/tfmodel/utils.py @@ -27,11 +27,24 @@ mlpf_dataset_from_config, ) from tfmodel.model_setup import configure_model_weights, make_model -from tfmodel.onecycle_scheduler import MomentumOneCycleScheduler, OneCycleScheduler +from tfmodel.onecycle_scheduler import ( + MomentumOneCycleScheduler, + OneCycleScheduler, +) @tf.function -def histogram_2d(mask, eta, phi, weights_px, weights_py, eta_range, phi_range, nbins, bin_dtype=tf.float32): +def histogram_2d( + mask, + eta, + phi, + weights_px, + weights_py, + eta_range, + phi_range, + nbins, + bin_dtype=tf.float32, +): eta_bins = tf.histogram_fixed_width_bins(eta, eta_range, nbins=nbins, dtype=bin_dtype) phi_bins = tf.histogram_fixed_width_bins(phi, phi_range, nbins=nbins, dtype=bin_dtype) @@ -51,9 +64,29 @@ def histogram_2d(mask, eta, phi, weights_px, weights_py, eta_range, phi_range, n @tf.function -def batched_histogram_2d(mask, eta, phi, w_px, w_py, x_range, y_range, nbins, bin_dtype=tf.float32): +def batched_histogram_2d( + mask, + eta, + phi, + w_px, + w_py, + x_range, + y_range, + nbins, + bin_dtype=tf.float32, +): return tf.map_fn( - lambda a: histogram_2d(a[0], a[1], a[2], a[3], a[4], x_range, y_range, nbins, bin_dtype), + lambda a: histogram_2d( + a[0], + a[1], + a[2], + a[3], + a[4], + x_range, + y_range, + nbins, + bin_dtype, + ), (mask, eta, phi, w_px, w_py), fn_output_signature=tf.TensorSpec( [nbins, nbins], @@ -109,7 +142,7 @@ def create_experiment_dir(prefix=None, suffix=None): def get_best_checkpoint(train_dir): checkpoint_list = list(Path(Path(train_dir) / "weights").glob("weights*.hdf5")) # Sort the checkpoints according to the loss in their filenames - checkpoint_list.sort(key=lambda x: float(re.search("\d+-\d+.\d+", str(x.name))[0].split("-")[-1])) + checkpoint_list.sort(key=lambda x: float(re.search(r"\d+-\d+.\d+", str(x.name))[0].split("-")[-1])) # Return the checkpoint with smallest loss return str(checkpoint_list[0]) @@ -117,7 +150,7 @@ def get_best_checkpoint(train_dir): def get_latest_checkpoint(train_dir): checkpoint_list = list(Path(Path(train_dir) / "weights").glob("weights*.hdf5")) # Sort the checkpoints according to the epoch number in their filenames - checkpoint_list.sort(key=lambda x: int(re.search("\d+-\d+.\d+", str(x.name))[0].split("-")[0])) + checkpoint_list.sort(key=lambda x: int(re.search(r"\d+-\d+.\d+", str(x.name))[0].split("-")[0])) # Return the checkpoint with highest epoch number return str(checkpoint_list[-1]) @@ -131,7 +164,7 @@ def delete_all_but_best_checkpoint(train_dir, dry_run): raise UserWarning("Couldn't find 
any checkpoints. No deletion was made.") else: # Sort the checkpoints according to the loss in their filenames - checkpoint_list.sort(key=lambda x: float(re.search("\d+-\d+.\d+", str(x))[0].split("-")[-1])) + checkpoint_list.sort(key=lambda x: float(re.search(r"\d+-\d+.\d+", str(x))[0].split("-")[-1])) best_ckpt = checkpoint_list.pop(0) for ckpt in checkpoint_list: if not dry_run: @@ -260,10 +293,18 @@ def get_optimizer(config, lr_schedule=None): return opt elif config["setup"]["optimizer"] == "adamw": cfg_adamw = config["optimizer"]["adamw"] - return tfa.optimizers.AdamW(learning_rate=lr, weight_decay=cfg_adamw["weight_decay"], amsgrad=cfg_adamw["amsgrad"]) + return tfa.optimizers.AdamW( + learning_rate=lr, + weight_decay=cfg_adamw["weight_decay"], + amsgrad=cfg_adamw["amsgrad"], + ) elif config["setup"]["optimizer"] == "sgd": cfg_sgd = config["optimizer"]["sgd"] - return tf.keras.optimizers.legacy.SGD(learning_rate=lr, momentum=cfg_sgd["momentum"], nesterov=cfg_sgd["nesterov"]) + return tf.keras.optimizers.legacy.SGD( + learning_rate=lr, + momentum=cfg_sgd["momentum"], + nesterov=cfg_sgd["nesterov"], + ) else: raise ValueError( "Only 'adam', 'adamw' and 'sgd' are supported optimizers, got {}".format(config["setup"]["optimizer"]) @@ -332,7 +373,15 @@ def func(X, y, w): return func -def load_and_interleave(joint_dataset_name, dataset_names, config, num_batches_multiplier, split, batch_size, max_events): +def load_and_interleave( + joint_dataset_name, + dataset_names, + config, + num_batches_multiplier, + split, + batch_size, + max_events, +): datasets = [mlpf_dataset_from_config(ds_name, config, split, max_events) for ds_name in dataset_names] ds = interleave_datasets(joint_dataset_name, split, datasets) tensorflow_dataset = ds.tensorflow_dataset.map(get_map_to_supervised(config)) @@ -377,7 +426,13 @@ def load_and_interleave(joint_dataset_name, dataset_names, config, num_batches_m # Load multiple datasets and mix them together -def get_datasets(datasets_to_interleave, config, num_batches_multiplier, split, max_events=None): +def get_datasets( + datasets_to_interleave, + config, + num_batches_multiplier, + split, + max_events=None, +): datasets = [] for joint_dataset_name in datasets_to_interleave.keys(): ds_conf = datasets_to_interleave[joint_dataset_name] @@ -426,7 +481,8 @@ def set_config_loss(config, trainable): def get_class_loss(config): if config["setup"]["classification_loss_type"] == "categorical_cross_entropy": cls_loss = tf.keras.losses.CategoricalCrossentropy( - from_logits=False, label_smoothing=config["setup"].get("classification_label_smoothing", 0.0) + from_logits=False, + label_smoothing=config["setup"].get("classification_label_smoothing", 0.0), ) elif config["setup"]["classification_loss_type"] == "sigmoid_focal_crossentropy": cls_loss = tfa.losses.sigmoid_focal_crossentropy @@ -665,9 +721,26 @@ def get_loss_dict(config): # get the datasets for training, testing and validation def get_train_test_val_datasets(config, num_batches_multiplier, ntrain=None, ntest=None): - ds_train = get_datasets(config["train_test_datasets"], config, num_batches_multiplier, "train", ntrain) - ds_test = get_datasets(config["train_test_datasets"], config, num_batches_multiplier, "test", ntest) - ds_val = mlpf_dataset_from_config(config["validation_dataset"], config, "test", config["validation_num_events"]) + ds_train = get_datasets( + config["train_test_datasets"], + config, + num_batches_multiplier, + "train", + ntrain, + ) + ds_test = get_datasets( + config["train_test_datasets"], + config, + 
num_batches_multiplier, + "test", + ntest, + ) + ds_val = mlpf_dataset_from_config( + config["validation_dataset"], + config, + "test", + config["validation_num_events"], + ) ds_val.tensorflow_dataset = ds_val.tensorflow_dataset.padded_batch(config["validation_batch_size"]) return ds_train, ds_test, ds_val diff --git a/mlpf/tfmodel/utils_analysis.py b/mlpf/tfmodel/utils_analysis.py index b59a57eec..476c0662a 100644 --- a/mlpf/tfmodel/utils_analysis.py +++ b/mlpf/tfmodel/utils_analysis.py @@ -46,13 +46,21 @@ def plot_ray_analysis(analysis, save=False, skip=0): dfs = analysis.fetch_trial_dataframes() result_df = analysis.dataframe() - for key in tqdm(dfs.keys(), desc="Creating Ray analysis plots", total=len(dfs.keys())): + for key in tqdm( + dfs.keys(), + desc="Creating Ray analysis plots", + total=len(dfs.keys()), + ): result = result_df[result_df["logdir"] == key] fig, axs = plt.subplots(5, 4, figsize=(12, 9), tight_layout=True) for var, ax in zip(to_plot, axs.flat): # Skip first `skip` values so loss plots don't include the very large losses which occur at start of training - ax.plot(dfs[key].index.values[skip:], dfs[key][var][skip:], alpha=0.8) + ax.plot( + dfs[key].index.values[skip:], + dfs[key][var][skip:], + alpha=0.8, + ) ax.set_xlabel("Epoch") ax.set_ylabel(var) ax.grid(alpha=0.3) @@ -175,7 +183,12 @@ def topk_summary_plot_v2(analysis, k, save=False, save_dir=None): fig, axs = plt.subplots(len(to_plot), 1, figsize=(12, 9), tight_layout=True, sharex=True) for var, ax_row in zip(to_plot, axs): for ii, key in enumerate(dd["logdir"]): - ax_row.plot(dfs[key].index.values, dfs[key][var], alpha=0.8, label="#{}".format(ii + 1)) + ax_row.plot( + dfs[key].index.values, + dfs[key][var], + alpha=0.8, + label="#{}".format(ii + 1), + ) ax_row.set_ylabel(var) ax_row.grid(alpha=0.3) ax_row.legend() @@ -217,14 +230,24 @@ def summarize_top_k(analysis, k, save=False, save_dir=None): cm_green = sns.light_palette("green", as_cmap=True) cm_red = sns.light_palette("red", as_cmap=True) - max_is_better = ["cls_acc_unweighted", "val_cls_acc_weighted", "val_cls_acc_unweighted"] + max_is_better = [ + "cls_acc_unweighted", + "val_cls_acc_weighted", + "val_cls_acc_unweighted", + ] min_is_better = ["loss", "cls_loss", "val_loss", "val_cls_loss"] styled_summary = ( summary.style.background_gradient(cmap=cm_green, subset=max_is_better) .background_gradient(cmap=cm_red, subset=min_is_better) - .highlight_max(subset=max_is_better, props="color:black; font-weight:bold; background-color:yellow;") - .highlight_min(subset=min_is_better, props="color:black; font-weight:bold; background-color:yellow;") + .highlight_max( + subset=max_is_better, + props="color:black; font-weight:bold; background-color:yellow;", + ) + .highlight_min( + subset=min_is_better, + props="color:black; font-weight:bold; background-color:yellow;", + ) .set_caption("Top {} trials according to {}".format(k, analysis.default_metric)) .hide_index() ) diff --git a/scripts/plot_nvidiasmi_csv.py b/scripts/plot_nvidiasmi_csv.py index b49ca0ae4..85e8e95d5 100644 --- a/scripts/plot_nvidiasmi_csv.py +++ b/scripts/plot_nvidiasmi_csv.py @@ -11,7 +11,11 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "-d", "--dir", type=str, default="parameters/delphes-gnn-skipconn.yaml", help="dir containing csv files" + "-d", + "--dir", + type=str, + default="parameters/delphes-gnn-skipconn.yaml", + help="dir containing csv files", ) args = parser.parse_args() return args
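A closing note on the raw-string regex changes in get_best_checkpoint, get_latest_checkpoint and delete_all_but_best_checkpoint above: all three sort the weights*.hdf5 checkpoints by fields parsed out of the filename. The sketch below shows those sort keys, assuming names of the form weights-<epoch>-<loss>.hdf5; the naming convention and the sample filenames are assumptions inferred from the regex, not taken from the repository.

import re

names = [
    "weights-01-3.500000.hdf5",
    "weights-07-1.250000.hdf5",
    "weights-12-2.750000.hdf5",
]

def loss_key(name):
    # r"\d+-\d+.\d+" captures "<epoch>-<loss>"; the last '-'-separated field is the loss
    return float(re.search(r"\d+-\d+.\d+", name)[0].split("-")[-1])

def epoch_key(name):
    # the first '-'-separated field of the same match is the epoch number
    return int(re.search(r"\d+-\d+.\d+", name)[0].split("-")[0])

best = sorted(names, key=loss_key)[0]       # smallest loss  -> "weights-07-1.250000.hdf5"
latest = sorted(names, key=epoch_key)[-1]   # highest epoch  -> "weights-12-2.750000.hdf5"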