From 950ce0415aeff09aae187e88336b8ef1ed3a5db9 Mon Sep 17 00:00:00 2001 From: Farouk Mokhtar Date: Fri, 1 Jul 2022 13:24:40 +0200 Subject: [PATCH] optimized pytorch geometric pipeline using DDP (#118) * update the pytorch pipeline --- .gitignore | 15 +- mlpf/lrp/lrp_mlpf.py | 17 +- mlpf/lrp/model.py | 50 +- .../{lrp_pipeline.py => lrp_mlpf_pipeline.py} | 46 +- .../PFGraphDataset.py} | 114 +-- mlpf/pyg/README.md | 59 ++ mlpf/pyg/__init__.py | 17 + mlpf/{pytorch_delphes => pyg}/args.py | 42 +- mlpf/pyg/cms_plots.py | 478 +++++++++ mlpf/pyg/cms_utils.py | 128 +++ .../utils_plots.py => pyg/delphes_plots.py} | 207 ++-- .../evaluate.py => pyg/delphes_utils.py} | 34 +- mlpf/pyg/environment.yml | 20 + mlpf/pyg/evaluate.py | 265 +++++ mlpf/pyg/get_data_cms.sh | 28 + mlpf/pyg/get_data_delphes.sh | 45 + mlpf/pyg/model.py | 295 ++++++ mlpf/pyg/training.py | 279 +++++ mlpf/pyg/utils.py | 320 ++++++ mlpf/pyg_pipeline.py | 316 ++++++ mlpf/pytorch_cms/README.md | 5 - mlpf/pytorch_cms/eval_end2end_cms.py | 67 -- mlpf/pytorch_cms/graph_data_cms.py | 202 ---- mlpf/pytorch_cms/gravnet.py | 106 -- mlpf/pytorch_cms/train_end2end_cms.py | 670 ------------- mlpf/pytorch_delphes/README.md | 23 - mlpf/pytorch_delphes/__init__.py | 10 - mlpf/pytorch_delphes/model.py | 225 ----- mlpf/pytorch_delphes/training.py | 220 ---- mlpf/pytorch_delphes/utils.py | 185 ---- mlpf/pytorch_pipeline.py | 149 --- notebooks/cms-mlpf.ipynb | 172 ++-- notebooks/delphes-lrp-playground.ipynb | 949 ++++++++++++++++++ 33 files changed, 3526 insertions(+), 2232 deletions(-) rename mlpf/{lrp_pipeline.py => lrp_mlpf_pipeline.py} (71%) rename mlpf/{pytorch_delphes/graph_data_delphes.py => pyg/PFGraphDataset.py} (54%) create mode 100644 mlpf/pyg/README.md create mode 100644 mlpf/pyg/__init__.py rename mlpf/{pytorch_delphes => pyg}/args.py (52%) create mode 100644 mlpf/pyg/cms_plots.py create mode 100644 mlpf/pyg/cms_utils.py rename mlpf/{pytorch_delphes/utils_plots.py => pyg/delphes_plots.py} (64%) rename 
mlpf/{pytorch_delphes/evaluate.py => pyg/delphes_utils.py} (93%) create mode 100644 mlpf/pyg/environment.yml create mode 100644 mlpf/pyg/evaluate.py create mode 100755 mlpf/pyg/get_data_cms.sh create mode 100755 mlpf/pyg/get_data_delphes.sh create mode 100644 mlpf/pyg/model.py create mode 100644 mlpf/pyg/training.py create mode 100644 mlpf/pyg/utils.py create mode 100644 mlpf/pyg_pipeline.py delete mode 100644 mlpf/pytorch_cms/README.md delete mode 100644 mlpf/pytorch_cms/eval_end2end_cms.py delete mode 100644 mlpf/pytorch_cms/graph_data_cms.py delete mode 100644 mlpf/pytorch_cms/gravnet.py delete mode 100644 mlpf/pytorch_cms/train_end2end_cms.py delete mode 100644 mlpf/pytorch_delphes/README.md delete mode 100644 mlpf/pytorch_delphes/__init__.py delete mode 100644 mlpf/pytorch_delphes/model.py delete mode 100644 mlpf/pytorch_delphes/training.py delete mode 100644 mlpf/pytorch_delphes/utils.py delete mode 100644 mlpf/pytorch_pipeline.py create mode 100644 notebooks/delphes-lrp-playground.ipynb diff --git a/.gitignore b/.gitignore index 56c75e84c..a378a4529 100644 --- a/.gitignore +++ b/.gitignore @@ -4,21 +4,18 @@ *.pt *.pdf data/* +experiments/* +prp/* *.pth test/__pycache__/ -mlpf/pytorch/__pycache__/* -mlpf/plotting/__pycache__/* -mlpf/pytorch/data -test_tmp/ -test_tmp_delphes/ +*/__pycache__/* .DS_Store -prp *.pyc *.pyo -mlpf/updated/LRP/pid* -mlpf/updated/LRP/class* - *.ipynb_checkpoints + +*playground.py +nohup.out diff --git a/mlpf/lrp/lrp_mlpf.py b/mlpf/lrp/lrp_mlpf.py index 51da79dd2..466dc1283 100644 --- a/mlpf/lrp/lrp_mlpf.py +++ b/mlpf/lrp/lrp_mlpf.py @@ -14,14 +14,15 @@ class LRP_MLPF(): """ - A class that act on graph datasets and GNNs based on the Gravnet layer (e.g. 
the MLPF model) - The main trick is to realize that the ".lin_s" layers in Gravnet are irrelevant for explanations so shall be skipped - The hack, however, is to substitute them precisely with the message_passing step + A class that introduces useful functionality to perform layerwise-relevance propagation (LRP) on MLPF. + This class is meant to work for any model with any number of GravNetConv layers. + The main trick is to realize that the ".lin_s" layers in GravNetConv are irrelevant for explanations so shall be skipped. + The hack, however, is to substitute them precisely with the message_passing step. Differences from standard LRP - - Rscores become tensors/graphs of input features per output neuron instead of vectors - - accomodates message passing steps by using the adjacency matrix as the weight matrix in standard LRP, - and redistributing Rscores over the other dimension (over nodes instead of features) + a. Rscores become tensors/graphs of input features per output neuron instead of vectors + b. accomodates message passing steps by using the adjacency matrix as the weight matrix in standard LRP, + and redistributing Rscores over the other dimension (over nodes instead of features) """ def __init__(self, device, model, epsilon): @@ -153,8 +154,8 @@ def eps_rule(self, layer, layer_name, x, R_tensor_old, neuron_to_explain, msg_pa Can accomodate message_passing layers if the adjacency matrix and the activations before the message_passing are provided. The trick (or as we like to call it, the message_passing hack) is in - (1) using the adjacency matrix as the weight matrix in the standard lrp rule - (2) transposing the activations to distribute the Rscores over the other dimension (over nodes instead of features) + a. using the adjacency matrix as the weight matrix in the standard lrp rule + b. 
transposing the activations to distribute the Rscores over the other dimension (over nodes instead of features) Args: layer: a torch.nn module with a corresponding weight matrix W diff --git a/mlpf/lrp/model.py b/mlpf/lrp/model.py index 72c4e3ed7..0cb9c4ac5 100644 --- a/mlpf/lrp/model.py +++ b/mlpf/lrp/model.py @@ -103,51 +103,11 @@ def forward(self, batch): class GravNetConv_LRP(MessagePassing): """ - Copied from pytorch_geometric source code - Edits: - - retrieve adjacency matrix (we call A), and the activations before the message passing step (we call msg_activations) - - switched the execution of self.lin_s & self.lin_p so that the message passing step can substitute out of the box self.lin_s for lrp purposes - - used reduce='sum' instead of reduce='mean' in the message passing - - removed skip connection - - The GravNet operator from the `"Learning Representations of Irregular - Particle-detector Geometry with Distance-weighted Graph - Networks" `_ paper, where the graph is - dynamically constructed using nearest neighbors. - The neighbors are constructed in a learnable low-dimensional projection of - the feature space. - A second projection of the input feature space is then propagated from the - neighbors to each vertex using distance weights that are derived by - applying a Gaussian function to the distances. - - Args: - in_channels (int): Size of each input sample, or :obj:`-1` to derive - the size from the first input(s) to the forward method. - out_channels (int): The number of output channels. - space_dimensions (int): The dimensionality of the space used to - construct the neighbors; referred to as :math:`S` in the paper. - propagate_dimensions (int): The number of features to be propagated - between the vertices; referred to as :math:`F_{\textrm{LR}}` in the - paper. - k (int): The number of nearest neighbors. - num_workers (int): Number of workers to use for k-NN computation. 
- Has no effect in case :obj:`batch` is not :obj:`None`, or the input - lies on the GPU. (default: :obj:`1`) - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. - - Shapes: - - **input:** - node features :math:`(|\mathcal{V}|, F_{in})` or - :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` - if bipartite, - batch vector :math:`(|\mathcal{V}|)` or - :math:`((|\mathcal{V}_s|), (|\mathcal{V}_t|))` if bipartite - *(optional)* - - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or - :math:`(|\mathcal{V}_t|, F_{out})` if bipartite - - + Copied from pytorch_geometric source code, with the following edits + a. used reduce='sum' instead of reduce='mean' in the message passing + b. removed skip connection + c. retrieved adjacency matrix and the activations before the message passing, both are useful only for LRP purposes + d. switched the execution of self.lin_s & self.lin_p so that the message passing step can substitute out of the box self.lin_s for lrp purposes """ def __init__(self, in_channels: int, out_channels: int, diff --git a/mlpf/lrp_pipeline.py b/mlpf/lrp_mlpf_pipeline.py similarity index 71% rename from mlpf/lrp_pipeline.py rename to mlpf/lrp_mlpf_pipeline.py index 807e0a02c..61e9840f3 100644 --- a/mlpf/lrp_pipeline.py +++ b/mlpf/lrp_mlpf_pipeline.py @@ -1,11 +1,11 @@ -from pytorch_delphes import PFGraphDataset, dataloader_qcd, load_model +from pyg import PFGraphDataset, dataloader_qcd, load_model from lrp import MLPF, LRP_MLPF, make_Rmaps + import argparse import pickle as pkl import os.path as osp import os import sys -from glob import glob import numpy as np import mplhep as hep @@ -13,13 +13,6 @@ import torch import torch_geometric -from torch_geometric.nn import GravNetConv - -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Sequential as Seq, Linear as Lin, ReLU -from sklearn.metrics import accuracy_score -import matplotlib.pyplot as plt from 
torch_geometric.data import Data, DataLoader, DataListLoader, Batch @@ -28,13 +21,13 @@ parser = argparse.ArgumentParser() # for saving the model -parser.add_argument("--dataset_qcd", type=str, default='../data/test_tmp_delphes/data/pythia8_qcd', help="testing dataset path") -parser.add_argument("--outpath", type=str, default='../data/test_tmp_delphes/experiments/', help="path to the trained model directory") -parser.add_argument("--load_model", type=str, default="", help="Which model to load") -parser.add_argument("--load_epoch", type=int, default=0, help="Which epoch of the model to load") -parser.add_argument("--out_neuron", type=int, default=0, help="the output neuron you wish to explain") -parser.add_argument("--pid", type=str, default="chhadron", help="Which model to load") -parser.add_argument("--n_test", type=int, default=50, help="number of data files to use for testing.. each file contains 100 events") +parser.add_argument("--dataset_qcd", type=str, default='../data/delphes/pythia8_qcd', help="testing dataset path") +parser.add_argument("--outpath", type=str, default='../experiments/', help="path to the trained model directory") +parser.add_argument("--load_model", type=str, default="", help="Which model to load") +parser.add_argument("--load_epoch", type=int, default=0, help="Which epoch of the model to load") +parser.add_argument("--out_neuron", type=int, default=0, help="the output neuron you wish to explain") +parser.add_argument("--pid", type=str, default="chhadron", help="Which model to load") +parser.add_argument("--n_test", type=int, default=50, help="number of data files to use for testing.. each file contains 100 events") parser.add_argument("--run_lrp", dest='run_lrp', action='store_true', help="runs lrp") parser.add_argument("--make_rmaps", dest='make_rmaps', action='store_true', help="makes rmaps") @@ -42,13 +35,6 @@ if __name__ == "__main__": - """ - e.g. 
to run lrp and make Rmaps - python -u lrp_pipeline.py --run_lrp --make_rmaps --load_model='MLPF_gen_ntrain_1_nepochs_1_clf_reg' --load_epoch=0 --n_test=1 --pid='chhadron' - - e.g. to only make Rmaps - python -u lrp_pipeline.py --make_rmaps --load_model='MLPF_gen_ntrain_1_nepochs_1_clf_reg' --load_epoch=0 --n_test=1 --out_neuron=0 --pid='chhadron' - """ if args.run_lrp: # Check if the GPU configuration and define the global base device @@ -71,24 +57,26 @@ model = MLPF(**model_kwargs) model.load_state_dict(state_dict) model.to(device) + model.eval() - # run lrp + # initialize placeholders for Rscores, the event inputs, and the event predictions Rtensors_list, preds_list, inputs_list = [], [], [] + # define the lrp instance + lrp_instance = LRP_MLPF(device, model, epsilon=1e-9) + + # loop over events to explain them for i, event in enumerate(loader): print(f'Explaining event # {i}') - # run lrp on sample model - model.eval() - lrp_instance = LRP_MLPF(device, model, epsilon=1e-9) + # run lrp on the event Rtensor, pred, input = lrp_instance.explain(event, neuron_to_explain=args.out_neuron) + # store the Rscores, the event inputs, and the event predictions Rtensors_list.append(Rtensor.detach().to('cpu')) preds_list.append(pred.detach().to('cpu')) inputs_list.append(input.detach().to('cpu').to_dict()) - break - with open(f'{outpath}/Rtensors_list.pkl', 'wb') as f: pkl.dump(Rtensors_list, f) with open(f'{outpath}/inputs_list.pkl', 'wb') as f: diff --git a/mlpf/pytorch_delphes/graph_data_delphes.py b/mlpf/pyg/PFGraphDataset.py similarity index 54% rename from mlpf/pytorch_delphes/graph_data_delphes.py rename to mlpf/pyg/PFGraphDataset.py index 68d2efb9f..a91db2aba 100644 --- a/mlpf/pytorch_delphes/graph_data_delphes.py +++ b/mlpf/pyg/PFGraphDataset.py @@ -1,3 +1,13 @@ +try: + from pyg.cms_utils import prepare_data_cms +except: + from cms_utils import prepare_data_cms + +from numpy.lib.recfunctions import append_fields +import bz2 +import h5py +import pandas +import pandas 
as pd import numpy as np import os import os.path as osp @@ -10,23 +20,6 @@ import pickle import multiprocessing -# assumes pkl files exist in /test_tmp_delphes/data/pythia8_ttbar/raw -# they are processed and saved as pt files in /test_tmp_delphes/data/pythia8_ttbar/processed -# PFGraphDataset -> returns for 1 event: Data(x=[5139, 12], ycand=[5139, 6], ycand_id=[5139, 6], ygen=[5139, 6], ygen_id=[5139, 6]) - - -def one_hot_embedding(labels, num_classes): - """ - Embedding labels to one-hot form. - Args: - labels: (LongTensor) class labels, sized [N,]. - num_classes: (int) number of classes. - Returns: - (tensor) encoded labels, sized [N, #classes]. - """ - y = torch.eye(num_classes) - return y[labels] - def process_func(args): self, fns, idx_file = args @@ -46,9 +39,10 @@ class PFGraphDataset(Dataset): root (str): path """ - def __init__(self, root, transform=None, pre_transform=None): + def __init__(self, root, data, transform=None, pre_transform=None): super(PFGraphDataset, self).__init__(root, transform, pre_transform) self._processed_dir = Dataset.processed_dir.fget(self) + self.data = data @property def raw_file_names(self): @@ -79,41 +73,40 @@ def download(self): pass def process_single_file(self, raw_file_name): - with open(osp.join(self.raw_dir, raw_file_name), "rb") as fi: - data = pickle.load(fi, encoding='iso-8859-1') - - x = [] - ygen = [] - ycand = [] - d = [] - batch_data = [] - ygen_id = [] - ycand_id = [] - - for i in range(len(data['X'])): - x.append(torch.tensor(data['X'][i], dtype=torch.float)) - ygen.append(torch.tensor(data['ygen'][i], dtype=torch.float)) - ycand.append(torch.tensor(data['ycand'][i], dtype=torch.float)) - - # one-hot encoding the first element in ygen & ycand (which is the PID) and store it in ygen_id & ycand_id - ygen_id.append(ygen[i][:, 0]) - ycand_id.append(ycand[i][:, 0]) - - ygen_id[i] = ygen_id[i].long() - ycand_id[i] = ycand_id[i].long() - - ygen_id[i] = one_hot_embedding(ygen_id[i], 6) - ycand_id[i] = 
one_hot_embedding(ycand_id[i], 6) - - # remove from ygen & ycand the first element (PID) so that they only contain the regression variables - d = Data( - x=x[i], - ygen=ygen[i][:, 1:], ygen_id=ygen_id[i], - ycand=ycand[i][:, 1:], ycand_id=ycand_id[i] - ) - - batch_data.append(d) - return batch_data + """ + Loads a list of 100 events from a pkl file and generates pytorch geometric Data() objects and stores them in .pt format. + For cms data, each element is assumed to be a dict('Xelem', 'ygen', ycand') of numpy rec_arrays with the first element in ygen/ycand is the pid + For delphes data, each element is assumed to be a dict('X', 'ygen', ycand') of numpy standard arrays with the first element in ygen/ycand is the pid + + Args + raw_file_name: a pkl file + Returns + batched_data: a list of Data() objects of the form + cms ~ Data(x=[#elem, 41], ygen=[#elem, 6], ygen_id=[#elem, 9], ycand=[#elem, 6], ycand_id=[#elem, 9]) + delphes ~ Data(x=[#elem, 12], ygen=[#elem, 6], ygen_id=[#elem, 6], ycand=[#elem, 6], ycand_id=[#elem, 6]) + """ + + if self.data == 'cms': + return prepare_data_cms(osp.join(self.raw_dir, raw_file_name)) + + elif self.data == 'delphes': + # load the data pkl file + with open(osp.join(self.raw_dir, raw_file_name), "rb") as fi: + data = pickle.load(fi, encoding='iso-8859-1') + + for i in range(len(data['X'])): + # remove from ygen & ycand the first element (PID) so that they only contain the regression variables + d = Data( + x=torch.tensor(data['X'][i], dtype=torch.float), + ygen=torch.tensor(data['ygen'][i], dtype=torch.float)[:, 1:], + ygen_id=torch.tensor(data['ygen'][i], dtype=torch.float)[:, 0].long(), + ycand=torch.tensor(data['ycand'][i], dtype=torch.float)[:, 1:], + ycand_id=torch.tensor(data['ycand'][i], dtype=torch.float)[:, 0].long(), + ) + + batched_data.append(d) + + return batched_data def process_multiple_files(self, filenames, idx_file): datas = [self.process_single_file(fn) for fn in filenames] @@ -139,7 +132,7 @@ def 
process_parallel(self, num_files_to_batch, num_proc): def get(self, idx): p = osp.join(self.processed_dir, 'data_{}.pt'.format(idx)) - data = torch.load(p) + data = torch.load(p, map_location='cpu') return data def __getitem__(self, idx): @@ -149,6 +142,7 @@ def __getitem__(self, idx): def parse_args(): import argparse parser = argparse.ArgumentParser() + parser.add_argument("--data", type=str, required=True, help="'cms' or 'delphes'?") parser.add_argument("--dataset", type=str, required=True, help="Input data path") parser.add_argument("--processed_dir", type=str, help="processed", required=False, default=None) parser.add_argument("--num-files-merge", type=int, default=10, help="number of files to merge") @@ -159,12 +153,20 @@ def parse_args(): if __name__ == "__main__": + """ + e.g. to run for cms + python PFGraphDataset.py --data cms --dataset ../../data/cms/TTbar_14TeV_TuneCUETP8M1_cfi --processed_dir ../../data/cms/TTbar_14TeV_TuneCUETP8M1_cfi/processed --num-files-merge 1 --num-proc 1 + + e.g. to run for delphes + python3 PFGraphDataset.py --data delphes --dataset $sample --processed_dir $sample/processed --num-files-merge 1 --num-proc 1 + + """ + args = parse_args() - pfgraphdataset = PFGraphDataset(root=args.dataset) + pfgraphdataset = PFGraphDataset(root=args.dataset, data=args.data) if args.processed_dir: pfgraphdataset._processed_dir = args.processed_dir pfgraphdataset.process_parallel(args.num_files_merge, args.num_proc) - # pfgraphdataset.process(args.num_files_merge) diff --git a/mlpf/pyg/README.md b/mlpf/pyg/README.md new file mode 100644 index 000000000..3f064c376 --- /dev/null +++ b/mlpf/pyg/README.md @@ -0,0 +1,59 @@ +# Setup + +Have conda installed. +```bash +conda env create -f environment.yml +conda activate mlpf +``` + +# Training + +### DELPHES training +The dataset is available from zenodo: https://doi.org/10.5281/zenodo.4452283. 
+ +To download and process the full DELPHES dataset: +```bash +./get_data_delphes.sh +``` + +This script will download and process the data under a directory called `data/delphes` under `particleflow`. + +To perform a quick training on the dataset: +```bash +cd ../ +python -u pyg_pipeline.py --data delphes --dataset= --dataset_qcd= +``` + +To load a pretrained model which is stored in a directory under `particleflow/experiments` for evaluation: +```bash +cd ../ +python -u pyg_pipeline.py --data delphes --load --load_model= --load_epoch= --dataset= --dataset_qcd= +``` + +### CMS training + +To download and process the full CMS dataset: +```bash +./get_data_cms.sh +``` +This script will download and process the data under a directory called `data/cms` under `particleflow`. + +To perform a quick training on the dataset: +```bash +cd ../ +python -u pyg_pipeline.py --data cms --dataset= --dataset_qcd= +``` + +To load a pretrained model which is stored in a directory under `particleflow/experiments` for evaluation: +```bash +cd ../ +python -u pyg_pipeline.py --data cms --load --load_model= --load_epoch= --dataset= --dataset_qcd= +``` + +### XAI and LRP studies on MLPF + +You must have a pre-trained model under `particleflow/experiments`: +```bash +cd ../ +python -u lrp_mlpf_pipeline.py --run_lrp --make_rmaps --load_model= --load_epoch= +``` diff --git a/mlpf/pyg/__init__.py b/mlpf/pyg/__init__.py new file mode 100644 index 000000000..958e0ed27 --- /dev/null +++ b/mlpf/pyg/__init__.py @@ -0,0 +1,17 @@ +from pyg.args import parse_args +from pyg.PFGraphDataset import PFGraphDataset +from pyg.utils import one_hot_embedding, save_model, load_model +from pyg.utils import make_plot_from_lists +from pyg.utils import features_delphes, features_cms, target_p4 +from pyg.utils import make_file_loaders, dataloader_ttbar, dataloader_qcd + +from pyg.cms_utils import prepare_data_cms, CLASS_NAMES_CMS, CLASS_NAMES_CMS_LATEX +from pyg.delphes_plots import pid_to_name_delphes, 
name_to_pid_delphes, pid_to_name_cms, name_to_pid_cms + +from pyg.model import MLPF + +from pyg.training import training_loop +from pyg.evaluate import make_predictions, postprocess_predictions, make_plots_cms + +from pyg.cms_plots import plot_numPFelements, plot_met, plot_sum_energy, plot_sum_pt, plot_energy_res, plot_eta_res, plot_multiplicity +from pyg.cms_plots import plot_dist, plot_cm, plot_eff_and_fake_rate, distribution_icls diff --git a/mlpf/pytorch_delphes/args.py b/mlpf/pyg/args.py similarity index 52% rename from mlpf/pytorch_delphes/args.py rename to mlpf/pyg/args.py index 6e336b854..593246a61 100644 --- a/mlpf/pytorch_delphes/args.py +++ b/mlpf/pyg/args.py @@ -5,40 +5,50 @@ def parse_args(): parser = argparse.ArgumentParser() + # for data loading + parser.add_argument("--num_workers", type=int, default=2, help="number of subprocesses used for data loading") + parser.add_argument("--prefetch_factor", type=int, default=4, help="number of samples loaded in advance by each worker") + # for saving the model - parser.add_argument("--outpath", type=str, default='../data/test_tmp_delphes/experiments/', help="output folder") + parser.add_argument("--outpath", type=str, default='../experiments/', help="output folder") + parser.add_argument("--model_prefix", type=str, default='MLPF_model', help="directory to hold the model and all plots under args.outpath/args.model_prefix") # for loading the data - parser.add_argument("--dataset", type=str, default='../data/test_tmp_delphes/data/pythia8_ttbar', help="training dataset path") - parser.add_argument("--dataset_qcd", type=str, default='../data/test_tmp_delphes/data/pythia8_qcd', help="testing dataset path") + parser.add_argument("--data", type=str, required=True, help="cms or delphes?") + parser.add_argument("--dataset", type=str, default='../data/delphes/pythia8_ttbar', help="training dataset path") + parser.add_argument("--dataset_test", type=str, default='../data/delphes/pythia8_qcd', help="testing dataset path") 
+ parser.add_argument("--sample", type=str, default='QCD', help="sample to test on") parser.add_argument("--n_train", type=int, default=1, help="number of data files to use for training.. each file contains 100 events") parser.add_argument("--n_valid", type=int, default=1, help="number of data files to use for validation.. each file contains 100 events") parser.add_argument("--n_test", type=int, default=1, help="number of data files to use for testing.. each file contains 100 events") - parser.add_argument("--title", type=str, default=None, help="Appends this title to the model's name") parser.add_argument("--overwrite", dest='overwrite', action='store_true', help="Overwrites the model if True") # for loading a pre-trained model parser.add_argument("--load", dest='load', action='store_true', help="Load the model (no training)") - parser.add_argument("--load_model", type=str, default="", help="Which model to load") - parser.add_argument("--load_epoch", type=int, default=0, help="Which epoch of the model to load for evaluation") + parser.add_argument("--load_epoch", type=int, default=-1, help="Which epoch of the model to load for evaluation") # for training hyperparameters - parser.add_argument("--n_epochs", type=int, default=1, help="number of training epochs") - parser.add_argument("--batch_size", type=int, default=1, help="Number of .pt files to load in parallel") - parser.add_argument("--patience", type=int, default=100, help="patience before early stopping") + parser.add_argument("--n_epochs", type=int, default=3, help="number of training epochs") + parser.add_argument("--batch_size", type=int, default=1, help="number of events to run inference on before updating the loss") + parser.add_argument("--patience", type=int, default=30, help="patience before early stopping") parser.add_argument("--target", type=str, default="gen", choices=["cand", "gen"], help="Regress to PFCandidates or GenParticles", ) parser.add_argument("--lr", type=float, default=1e-4, 
help="learning rate") - parser.add_argument("--alpha", type=float, default=2e-4, help="loss = clf + alpha*reg.. if set to 0 model only does trains for classification") + parser.add_argument("--alpha", type=float, default=2e-5, help="loss = clf + alpha*reg.. if set to 0 model only does trains for classification") + parser.add_argument("--batch_events", dest='batch_events', action='store_true', help="batches the event in eta,phi space to build the graphs") # for model architecture - parser.add_argument("--hidden_dim1", type=int, default=120, help="hidden dimension of layers before the graph convolutions") - parser.add_argument("--hidden_dim2", type=int, default=256, help="hidden dimension of layers after the graph convolutions") - parser.add_argument("--embedding_dim", type=int, default=64, help="encoded element dimension") - parser.add_argument("--num_convs", type=int, default=2, help="number of graph convolutions") + parser.add_argument("--hidden_dim1", type=int, default=126, help="hidden dimension of layers before the graph convolutions") + parser.add_argument("--hidden_dim2", type=int, default=256, help="hidden dimension of layers after the graph convolutions") + parser.add_argument("--embedding_dim", type=int, default=32, help="encoded element dimension") + parser.add_argument("--num_convs", type=int, default=3, help="number of graph convolutions") parser.add_argument("--space_dim", type=int, default=4, help="Spatial dimension for clustering in gravnet layer") - parser.add_argument("--propagate_dim", type=int, default=22, help="The number of features to be propagated between the vertices") - parser.add_argument("--nearest", type=int, default=16, help="k nearest neighbors in gravnet layer") + parser.add_argument("--propagate_dim", type=int, default=8, help="The number of features to be propagated between the vertices") + parser.add_argument("--nearest", type=int, default=4, help="k nearest neighbors in gravnet layer") + + # for testing the model + 
parser.add_argument("--make_predictions", dest='make_predictions', action='store_true', help="run inference on the test data") + parser.add_argument("--make_plots", dest='make_plots', action='store_true', help="makes plots of the test predictions") args = parser.parse_args() diff --git a/mlpf/pyg/cms_plots.py b/mlpf/pyg/cms_plots.py new file mode 100644 index 000000000..7bc34f281 --- /dev/null +++ b/mlpf/pyg/cms_plots.py @@ -0,0 +1,478 @@ +from pyg.cms_utils import CLASS_LABELS_CMS, CLASS_NAMES_CMS, CLASS_NAMES_CMS_LATEX + +import pickle +import bz2 + +import numpy as np +from numpy.lib.recfunctions import append_fields + +import torch +import torch_geometric +import torch_geometric.utils +from torch_geometric.data import Dataset, Data, Batch + +import pandas as pd +import json +import glob +import matplotlib.pyplot as plt +import numpy as np + +import sklearn +import sklearn.metrics +import matplotlib +import scipy +import mplhep + +import pandas +import itertools +import mplhep +import boost_histogram as bh +mplhep.style.use(mplhep.styles.CMS) + + +def cms_label(ax, x0=0.01, x1=0.15, x2=0.98, y=0.94): + plt.figtext(x0, y, 'CMS', fontweight='bold', wrap=True, horizontalalignment='left', transform=ax.transAxes) + plt.figtext(x1, y, 'Simulation Preliminary', style='italic', wrap=True, horizontalalignment='left', transform=ax.transAxes) + plt.figtext(x2, y, 'Run 3 (14 TeV)', wrap=False, horizontalalignment='right', transform=ax.transAxes) + +# def cms_label_sample_label(x0=0.12, x1=0.23, x2=0.67, y=0.90): +# plt.figtext(x0, y,'CMS',fontweight='bold', wrap=True, horizontalalignment='left') +# plt.figtext(x1, y,'Simulation Preliminary', style='italic', wrap=True, horizontalalignment='left') +# plt.figtext(x2, y,'Run 3 (14 TeV), $\mathrm{t}\overline{\mathrm{t}}$ events', wrap=False, horizontalalignment='left') + + +def sample_label(sample, ax, additional_text="", x=0.01, y=0.87): + if sample == 'QCD': + plt.text(x, y, "QCD events" + additional_text, ha="left", 
transform=ax.transAxes) + else: + plt.text(x, y, "$\mathrm{t}\overline{\mathrm{t}}$ events" + additional_text, ha="left", transform=ax.transAxes) + + +def apply_thresholds_f(ypred_raw_f, thresholds): + msk = np.ones_like(ypred_raw_f) + for i in range(len(thresholds)): + msk[:, i + 1] = ypred_raw_f[:, i + 1] > thresholds[i] + ypred_id_f = np.argmax(ypred_raw_f * msk, axis=-1) + +# best_2 = np.partition(ypred_raw_f, -2, axis=-1)[..., -2:] +# diff = np.abs(best_2[:, -1] - best_2[:, -2]) +# ypred_id_f[diff<0.05] = 0 + + return ypred_id_f + + +def apply_thresholds(ypred_raw, thresholds): + msk = np.ones_like(ypred_raw) + for i in range(len(thresholds)): + msk[:, :, i + 1] = ypred_raw[:, :, i + 1] > thresholds[i] + ypred_id = np.argmax(ypred_raw * msk, axis=-1) + +# best_2 = np.partition(ypred_raw, -2, axis=-1)[..., -2:] +# diff = np.abs(best_2[:, :, -1] - best_2[:, :, -2]) +# ypred_id[diff<0.05] = 0 + + return ypred_id + + +class_names = {k: v for k, v in zip(CLASS_LABELS_CMS, CLASS_NAMES_CMS)} + + +def plot_numPFelements(X, outpath, sample): + plt.figure() + ax = plt.axes() + plt.hist(np.sum(X[:, :, 0] != 0, axis=1), bins=100) + plt.axvline(6400, ls="--", color="black") + plt.xlabel("number of input PFElements") + plt.ylabel("number of events / bin") + cms_label(ax) + sample_label(sample, ax) + plt.savefig(f"{outpath}/num_PFelements.pdf", bbox_inches="tight") + plt.close() + + +def plot_met(X, yvals, outpath, sample): + sum_px = np.sum(yvals["gen_px"], axis=1) + sum_py = np.sum(yvals["gen_py"], axis=1) + gen_met = np.sqrt(sum_px**2 + sum_py**2)[:, 0] + + sum_px = np.sum(yvals["cand_px"], axis=1) + sum_py = np.sum(yvals["cand_py"], axis=1) + cand_met = np.sqrt(sum_px**2 + sum_py**2)[:, 0] + + sum_px = np.sum(yvals["pred_px"], axis=1) + sum_py = np.sum(yvals["pred_py"], axis=1) + pred_met = np.sqrt(sum_px**2 + sum_py**2)[:, 0] + + fig = plt.figure() + ax = plt.axes() + b = np.linspace(-2, 5, 101) + vals_a = (cand_met - gen_met) / gen_met + vals_b = (pred_met - gen_met) / 
gen_met + plt.hist(vals_a, bins=b, histtype="step", lw=2, label="PF, $\mu={:.2f}$, $\sigma={:.2f}$".format(np.mean(vals_a), np.std(vals_a))) + plt.hist(vals_b, bins=b, histtype="step", lw=2, label="MLPF, $\mu={:.2f}$, $\sigma={:.2f}$".format(np.mean(vals_b), np.std(vals_b))) + plt.yscale("log") + cms_label(ax) + sample_label(sample, ax) + plt.ylim(10, 1e3) + plt.legend(loc=(0.4, 0.7)) + plt.xlabel(r"$\frac{\mathrm{MET}_{\mathrm{reco}} - \mathrm{MET}_{\mathrm{gen}}}{\mathrm{MET}_{\mathrm{gen}}}$") + plt.ylabel("Number of events / bin") + plt.savefig(f"{outpath}/met.pdf", bbox_inches="tight") + plt.close() + + +def plot_sum_energy(X, yvals, outpath, sample): + fig = plt.figure() + ax = plt.axes() + + plt.scatter( + np.sum(yvals["gen_energy"], axis=1), + np.sum(yvals["cand_energy"], axis=1), + alpha=0.5, + label="PF" + ) + plt.scatter( + np.sum(yvals["gen_energy"], axis=1), + np.sum(yvals["pred_energy"], axis=1), + alpha=0.5, + label="MLPF" + ) + plt.plot([10000, 80000], [10000, 80000], color="black") + plt.legend(loc=4) + cms_label(ax) + sample_label(sample, ax) + plt.xlabel("Gen $\sum E$ [GeV]") + plt.ylabel("Reconstructed $\sum E$ [GeV]") + plt.savefig(f"{outpath}/sum_energy.pdf", bbox_inches="tight") + plt.close() + + +def plot_sum_pt(X, yvals, outpath, sample): + + fig = plt.figure() + ax = plt.axes() + + plt.scatter( + np.sum(yvals["gen_pt"], axis=1), + np.sum(yvals["cand_pt"], axis=1), + alpha=0.5, + label="PF" + ) + plt.scatter( + np.sum(yvals["gen_pt"], axis=1), + np.sum(yvals["pred_pt"], axis=1), + alpha=0.5, + label="PF" + ) + plt.plot([1000, 6000], [1000, 6000], color="black") + plt.legend(loc=4) + cms_label(ax) + sample_label(sample, ax) + plt.xlabel("Gen $\sum p_T$ [GeV]") + plt.ylabel("Reconstructed $\sum p_T$ [GeV]") + plt.savefig(f"{outpath}/sum_pt.pdf", bbox_inches="tight") + plt.close() + + +def plot_energy_res(X, yvals_f, pid, b, ylim, outpath, sample): + + fig = plt.figure() + ax = plt.axes() + + msk = (yvals_f["gen_cls_id"] == pid) & 
(yvals_f["cand_cls_id"] == pid) & (yvals_f["pred_cls_id"] == pid) + vals_gen = yvals_f["gen_energy"][msk] + vals_cand = yvals_f["cand_energy"][msk] + vals_mlpf = yvals_f["pred_energy"][msk] + + reso_1 = (vals_cand - vals_gen) / vals_gen + reso_2 = (vals_mlpf - vals_gen) / vals_gen + plt.hist(reso_1, bins=b, histtype="step", lw=2, label="PF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_1), np.std(reso_1))) + plt.hist(reso_2, bins=b, histtype="step", lw=2, label="MLPF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_2), np.std(reso_2))) + plt.yscale("log") + plt.xlabel(r"$\frac{E_\mathrm{reco} - E_\mathrm{gen}}{E_\mathrm{gen}}$") + plt.ylabel("Number of particles / bin") + cms_label(ax) + sample_label(sample, ax, f", {CLASS_NAMES_CMS_LATEX[pid]}") + plt.legend(loc=(0.4, 0.7)) + plt.ylim(1, ylim) + plt.savefig(f"{outpath}/energy_res_{CLASS_NAMES_CMS[pid]}.pdf", bbox_inches="tight") + plt.close() + + +def plot_eta_res(X, yvals_f, pid, ylim, outpath, sample): + + fig = plt.figure() + ax = plt.axes() + + msk = (yvals_f["gen_cls_id"] == pid) & (yvals_f["cand_cls_id"] == pid) & (yvals_f["pred_cls_id"] == pid) + vals_gen = yvals_f["gen_eta"][msk] + vals_cand = yvals_f["cand_eta"][msk] + vals_mlpf = yvals_f["pred_eta"][msk] + + b = np.linspace(-10, 10, 100) + + reso_1 = (vals_cand - vals_gen) + reso_2 = (vals_mlpf - vals_gen) + plt.hist(reso_1, bins=b, histtype="step", lw=2, label="PF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_1), np.std(reso_1))) + plt.hist(reso_2, bins=b, histtype="step", lw=2, label="MLPF, $\mu={:.2f}, \sigma={:.2f}$".format(np.mean(reso_2), np.std(reso_2))) + plt.yscale("log") + plt.xlabel(r"$\eta_\mathrm{reco} - \eta_\mathrm{gen}$") + plt.ylabel("Number of particles / bin") + cms_label(ax) + sample_label(sample, ax, f", {CLASS_NAMES_CMS_LATEX[pid]}") + plt.legend(loc=(0.0, 0.7)) + plt.ylim(1, ylim) + plt.savefig(f"{outpath}/eta_res_{CLASS_NAMES_CMS[pid]}.pdf", bbox_inches="tight") + plt.close() + + +def plot_multiplicity(X, yvals, outpath, 
sample): + for icls in range(1, 8): + # Plot the particle multiplicities + npred = np.sum(yvals["pred_cls_id"] == icls, axis=1) + ngen = np.sum(yvals["gen_cls_id"] == icls, axis=1) + ncand = np.sum(yvals["cand_cls_id"] == icls, axis=1) + fig = plt.figure() + ax = plt.axes() + plt.scatter(ngen, ncand, marker=".", alpha=0.4, label="PF") + plt.scatter(ngen, npred, marker=".", alpha=0.4, label="MLPF") + a = 0.5 * min(np.min(npred), np.min(ngen)) + b = 1.5 * max(np.max(npred), np.max(ngen)) + plt.xlim(a, b) + plt.ylim(a, b) + plt.plot([a, b], [a, b], color="black", ls="--") + plt.xlabel("number of truth particles") + plt.ylabel("number of reconstructed particles") + plt.legend(loc=4) + cms_label(ax) + sample_label(sample, ax, ", " + CLASS_NAMES_CMS[icls]) + plt.savefig(f"{outpath}/num_cls{icls}.pdf", bbox_inches="tight") + plt.close() + + # Plot the sum of particle energies + msk = yvals["gen_cls_id"][:, :, 0] == icls + vals_gen = np.sum(np.ma.MaskedArray(yvals["gen_energy"], ~msk), axis=1)[:, 0] + msk = yvals["pred_cls_id"][:, :, 0] == icls + vals_pred = np.sum(np.ma.MaskedArray(yvals["pred_energy"], ~msk), axis=1)[:, 0] + msk = yvals["cand_cls_id"][:, :, 0] == icls + vals_cand = np.sum(np.ma.MaskedArray(yvals["cand_energy"], ~msk), axis=1)[:, 0] + fig = plt.figure() + ax = plt.axes() + plt.scatter(vals_gen, vals_cand, alpha=0.2, label="PF") + plt.scatter(vals_gen, vals_pred, alpha=0.2, label="MLPF") + minval = min(np.min(vals_gen), np.min(vals_cand), np.min(vals_pred)) + maxval = max(np.max(vals_gen), np.max(vals_cand), np.max(vals_pred)) + plt.plot([minval, maxval], [minval, maxval], color="black") + plt.xlim(minval, maxval) + plt.ylim(minval, maxval) + plt.xlabel("true $\sum E$ [GeV]") + plt.xlabel("reconstructed $\sum E$ [GeV]") + plt.legend(loc=4) + cms_label(ax) + sample_label(sample, ax, f", {CLASS_NAMES_CMS_LATEX[icls]}") + plt.savefig(f"{outpath}/energy_cls{icls}.pdf", bbox_inches="tight") + plt.close() + + +def get_distribution(yvals_f, prefix, bins, var): + 
+ hists = [] + for pid in [13, 11, 22, 1, 2, 130, 211]: + icls = CLASS_LABELS_CMS.index(pid) + msk_pid = (yvals_f[prefix + "_cls_id"] == icls) + h = bh.Histogram(bh.axis.Variable(bins)) + d = yvals_f[prefix + "_" + var][msk_pid] + h.fill(d.flatten()) + hists.append(h) + return hists + + +def plot_dist(yvals_f, var, bin, label, outpath, sample): + + hists_gen = get_distribution(yvals_f, "gen", bin, var) + hists_cand = get_distribution(yvals_f, "cand", bin, var) + hists_pred = get_distribution(yvals_f, "pred", bin, var) + + plt.figure() + ax = plt.axes() + v1 = mplhep.histplot([h[bh.rebin(2)] for h in hists_gen], stack=True, label=[class_names[k] for k in [13, 11, 22, 1, 2, 130, 211]], lw=1) + v2 = mplhep.histplot([h[bh.rebin(2)] for h in hists_pred], stack=True, color=[x.stairs.get_edgecolor() for x in v1], lw=2, histtype="errorbar") + + legend1 = plt.legend(v1, [x.legend_artist.get_label() for x in v1], loc=(0.60, 0.44), title="true") + legend2 = plt.legend(v2, [x.legend_artist.get_label() for x in v1], loc=(0.8, 0.44), title="pred") + plt.gca().add_artist(legend1) + plt.ylabel("Total number of particles / bin") + cms_label(ax) + sample_label(sample, ax) + + plt.yscale("log") + plt.ylim(top=1e9) + plt.xlabel(f"PFCandidate {label} [GeV]") + plt.savefig(f"{outpath}/pfcand_{var}.pdf", bbox_inches="tight") + plt.close() + + +def plot_eff_and_fake_rate(X_f, yvals_f, outpath, sample, + icls=1, + ivar=4, + ielem=1, + bins=np.linspace(-3, 6, 100), + xlabel="PFElement log[E/GeV]", log=True + ): + + values = X_f[:, ivar] + + hist_gen = np.histogram(values[(yvals_f["gen_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + hist_gen_pred = np.histogram(values[(yvals_f["gen_cls_id"] == icls) & (yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + hist_gen_cand = np.histogram(values[(yvals_f["gen_cls_id"] == icls) & (yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + + hist_pred = np.histogram(values[(yvals_f["pred_cls_id"] == icls) & 
(X_f[:, 0] == ielem)], bins=bins) + hist_cand = np.histogram(values[(yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + hist_pred_fake = np.histogram(values[(yvals_f["gen_cls_id"] != icls) & (yvals_f["pred_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + hist_cand_fake = np.histogram(values[(yvals_f["gen_cls_id"] != icls) & (yvals_f["cand_cls_id"] == icls) & (X_f[:, 0] == ielem)], bins=bins) + + eff_mlpf = hist_gen_pred[0] / hist_gen[0] + eff_pf = hist_gen_cand[0] / hist_gen[0] + fake_pf = hist_cand_fake[0] / hist_cand[0] + fake_mlpf = hist_pred_fake[0] / hist_pred[0] + + plt.figure() + ax = plt.axes() + mplhep.histplot(hist_gen, label="Gen", color="black") + mplhep.histplot(hist_cand, label="PF") + mplhep.histplot(hist_pred, label="MLPF") + plt.ylabel("Number of PFElements / bin") + plt.xlabel(xlabel) + if ivar == 3: # eta + plt.xlim(-6, 6) + cms_label(ax) + sample_label(sample, ax, ", " + CLASS_NAMES_CMS[icls]) + plt.legend(loc=(0.75, 0.65)) + if log: + plt.xscale("log") + plt.savefig(f"{outpath}/distr_icls{icls}_ivar{ivar}.pdf", bbox_inches="tight") + plt.close() + + plt.figure() + ax = plt.axes(sharex=ax) + mplhep.histplot(eff_pf, bins=hist_gen[1], label="PF") + mplhep.histplot(eff_mlpf, bins=hist_gen[1], label="MLPF") + plt.ylim(0, 1.2) + plt.ylabel("Efficiency") + plt.xlabel(xlabel) + cms_label(ax) + sample_label(sample, ax, ", " + CLASS_NAMES_CMS[icls]) + plt.legend(loc=(0.75, 0.75)) + if log: + plt.xscale("log") + plt.savefig(f"{outpath}/eff_icls{icls}_ivar{ivar}.pdf", bbox_inches="tight") + plt.close() + + plt.figure() + ax = plt.axes(sharex=ax) + mplhep.histplot(fake_pf, bins=hist_gen[1], label="PF") + mplhep.histplot(fake_mlpf, bins=hist_gen[1], label="MLPF") + plt.ylim(0, 1.2) + plt.ylabel("Fake rate") + plt.xlabel(xlabel) + cms_label(ax) + sample_label(sample, ax, ", " + CLASS_NAMES_CMS[icls]) + plt.legend(loc=(0.75, 0.75)) + if log: + plt.xscale("log") + plt.savefig(f"{outpath}/fake_icls{icls}_ivar{ivar}.pdf", 
bbox_inches="tight") + plt.close() + + #mplhep.histplot(fake, bins=hist_gen[1], label="fake rate", color="red") +# plt.legend(frameon=False) +# plt.ylim(0,1.4) +# plt.xlabel(xlabel) +# plt.ylabel("Fraction of particles / bin") + + +def plot_cm(yvals_f, msk_X_f, label, outpath): + + fig = plt.figure(figsize=(12, 12)) + ax = plt.axes() + + if label == 'MLPF': + Y = yvals_f["pred_cls_id"][msk_X_f] + else: + Y = yvals_f["cand_cls_id"][msk_X_f] + + cm_norm = sklearn.metrics.confusion_matrix( + yvals_f["gen_cls_id"][msk_X_f], + Y, + labels=range(0, len(CLASS_LABELS_CMS)), + normalize="true" + ) + + plt.imshow(cm_norm, cmap="Blues", origin="lower") + plt.colorbar() + + thresh = cm_norm.max() / 1.5 + for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])): + plt.text(j, i, "{:0.2f}".format(cm_norm[i, j]), + horizontalalignment="center", + color="white" if cm_norm[i, j] > thresh else "black", fontsize=12) + + cms_label(ax, y=1.01) + #cms_label_sample_label(x1=0.18, x2=0.52, y=0.82) + plt.xticks(range(len(CLASS_NAMES_CMS_LATEX)), CLASS_NAMES_CMS_LATEX, rotation=45) + plt.yticks(range(len(CLASS_NAMES_CMS_LATEX)), CLASS_NAMES_CMS_LATEX) + + plt.xlabel(f"{label} candidate ID") + plt.ylabel("Truth ID") + #plt.ylim(-0.5, 6.9) + #plt.title("MLPF trained on PF") + plt.savefig(f"{outpath}/cm_normed_{label}.pdf", bbox_inches="tight") + plt.close() + + +def distribution_icls(yvals_f, outpath): + for icls in range(0, 8): + fig, axs = plt.subplots( + 2, 2, + figsize=(2 * mplhep.styles.CMS["figure.figsize"][0], 2 * mplhep.styles.CMS["figure.figsize"][1]) + ) + + for ax, ivar in zip(axs.flatten(), ["pt", "energy", "eta", "phi"]): + + plt.sca(ax) + + if icls == 0: + vals_true = yvals_f["gen_" + ivar][yvals_f["gen_cls_id"] != 0] + vals_pf = yvals_f["cand_" + ivar][yvals_f["cand_cls_id"] != 0] + vals_pred = yvals_f["pred_" + ivar][yvals_f["pred_cls_id"] != 0] + else: + vals_true = yvals_f["gen_" + ivar][yvals_f["gen_cls_id"] == icls] + vals_pf = yvals_f["cand_" + 
ivar][yvals_f["cand_cls_id"] == icls] + vals_pred = yvals_f["pred_" + ivar][yvals_f["pred_cls_id"] == icls] + + if ivar == "pt" or ivar == "energy": + b = np.logspace(-3, 4, 61) + log = True + else: + b = np.linspace(np.min(vals_true), np.max(vals_true), 41) + log = False + + plt.hist(vals_true, bins=b, histtype="step", lw=2, label="gen", color="black") + plt.hist(vals_pf, bins=b, histtype="step", lw=2, label="PF") + plt.hist(vals_pred, bins=b, histtype="step", lw=2, label="MLPF") + plt.legend(loc=(0.75, 0.75)) + + ylim = ax.get_ylim() + + cls_name = CLASS_NAMES_CMS[icls] if icls > 0 else "all" + plt.xlabel(f"{cls_name} {ivar}") + + plt.yscale("log") + plt.ylim(10, 10 * ylim[1]) + + if log: + plt.xscale("log") + cms_label(ax) + + plt.tight_layout() + plt.savefig(f"{outpath}/distribution_icls{icls}.pdf", bbox_inches="tight") + plt.close() diff --git a/mlpf/pyg/cms_utils.py b/mlpf/pyg/cms_utils.py new file mode 100644 index 000000000..79e42cd87 --- /dev/null +++ b/mlpf/pyg/cms_utils.py @@ -0,0 +1,128 @@ +import pickle +import bz2 + +import numpy as np +from numpy.lib.recfunctions import append_fields + +import torch +import torch_geometric +import torch_geometric.utils +from torch_geometric.data import Dataset, Data, Batch + + +"""Based on https://github.com/jpata/hep_tfds/blob/master/heptfds/cms_utils.py#L10""" + +# https://github.com/ahlinist/cmssw/blob/1df62491f48ef964d198f574cdfcccfd17c70425/DataFormats/ParticleFlowReco/interface/PFBlockElement.h#L33 +ELEM_LABELS_CMS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +ELEM_NAMES_CMS = ["NONE", "TRACK", "PS1", "PS2", "ECAL", "HCAL", "GSF", "BREM", "HFEM", "HFHAD", "SC", "HO"] + +# https://github.com/cms-sw/cmssw/blob/master/DataFormats/ParticleFlowCandidate/src/PFCandidate.cc#L254 +CLASS_LABELS_CMS = [0, 211, 130, 1, 2, 22, 11, 13, 15] +CLASS_NAMES_CMS_LATEX = ["none", "chhad", "nhad", "HFEM", "HFHAD", "$\gamma$", "$e^\pm$", "$\mu^\pm$", r"$\tau$"] +CLASS_NAMES_CMS = ["none", "chhad", "nhad", "HFEM", "HFHAD", "gamma", 
"ele", "mu", "tau"] + +CLASS_NAMES_LONG_CMS = ["none" "charged hadron", "neutral hadron", "hfem", "hfhad", "photon", "electron", "muon", "tau"] + +CMS_PF_CLASS_NAMES = ["none" "charged hadron", "neutral hadron", "hfem", "hfhad", "photon", "electron", "muon"] + +X_FEATURES = [ + "typ_idx", "pt", "eta", "phi", "e", + "layer", "depth", "charge", "trajpoint", + "eta_ecal", "phi_ecal", "eta_hcal", "phi_hcal", "muon_dt_hits", "muon_csc_hits", "muon_type", + "px", "py", "pz", "deltap", "sigmadeltap", + "gsf_electronseed_trkorecal", + "gsf_electronseed_dnn1", + "gsf_electronseed_dnn2", + "gsf_electronseed_dnn3", + "gsf_electronseed_dnn4", + "gsf_electronseed_dnn5", + "num_hits", "cluster_flags", "corr_energy", + "corr_energy_err", "vx", "vy", "vz", "pterror", "etaerror", "phierror", "lambd", "lambdaerror", "theta", "thetaerror" +] + +Y_FEATURES = [ + "typ_idx", + "charge", + "pt", + "eta", + "sin_phi", + "cos_phi", + "e", +] + + +def prepare_data_cms(fn): + + batched_data = [] + + if fn.endswith(".pkl"): + data = pickle.load(open(fn, "rb"), encoding="iso-8859-1") + elif fn.endswith(".pkl.bz2"): + data = pickle.load(bz2.BZ2File(fn, "rb")) + + for event in data: + Xelem = event["Xelem"] + ygen = event["ygen"] + ycand = event["ycand"] + + # remove PS and BREM from inputs + msk_ps = (Xelem["typ"] == 2) | (Xelem["typ"] == 3) | (Xelem["typ"] == 7) + + Xelem = Xelem[~msk_ps] + ygen = ygen[~msk_ps] + ycand = ycand[~msk_ps] + + Xelem = append_fields( + Xelem, "typ_idx", np.array([ELEM_LABELS_CMS.index(int(i)) for i in Xelem["typ"]], dtype=np.float32) + ) + ygen = append_fields( + ygen, "typ_idx", np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32) + ) + ycand = append_fields( + ycand, + "typ_idx", + np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32), + ) + + Xelem_flat = np.stack( + [ + Xelem[k].view(np.float32).data + for k in X_FEATURES + ], + axis=-1, + ) + ygen_flat = np.stack( + [ + 
ygen[k].view(np.float32).data + for k in Y_FEATURES + ], + axis=-1, + ) + ycand_flat = np.stack( + [ + ycand[k].view(np.float32).data + for k in Y_FEATURES + ], + axis=-1, + ) + + # take care of outliers + Xelem_flat[np.isnan(Xelem_flat)] = 0 + Xelem_flat[np.abs(Xelem_flat) > 1e4] = 0 + + ygen_flat[np.isnan(ygen_flat)] = 0 + ygen_flat[np.abs(ygen_flat) > 1e4] = 0 + + ycand_flat[np.isnan(ycand_flat)] = 0 + ycand_flat[np.abs(ycand_flat) > 1e4] = 0 + + d = Data( + x=torch.tensor(Xelem_flat), + ygen=torch.tensor(ygen_flat[:, 1:]), + ygen_id=torch.tensor(ygen_flat[:, 0]).long(), + ycand=torch.tensor(ycand_flat[:, 1:]), + ycand_id=torch.tensor(ycand_flat[:, 0]).long(), + ) + batched_data.append(d) + + return batched_data diff --git a/mlpf/pytorch_delphes/utils_plots.py b/mlpf/pyg/delphes_plots.py similarity index 64% rename from mlpf/pytorch_delphes/utils_plots.py rename to mlpf/pyg/delphes_plots.py index c494731cb..6fd77dcdc 100644 --- a/mlpf/pytorch_delphes/utils_plots.py +++ b/mlpf/pyg/delphes_plots.py @@ -11,7 +11,7 @@ plt.style.use(hep.style.ROOT) -pid_names = { +pid_to_name_delphes = { 0: "Null", 1: "Charged hadrons", 2: "Neutral hadrons", @@ -19,14 +19,37 @@ 4: "Electrons", 5: "Muons", } -key_to_pid = { - "null": 0, - "chhadron": 1, - "nhadron": 2, - "photon": 3, - "electron": 4, - "muon": 5, -} + +pid_to_name_cms = {0: 'null', + 1: 'chhadron', + 2: 'nhadron', + 3: 'HFHAD', + 4: 'HFEM', + 5: 'photon', + 6: 'ele', + 7: 'mu', + 8: 'tau', + } + +name_to_pid_delphes = {'null': 0, + 'chhadron': 1, + 'nhadron': 2, + 'photon': 3, + 'ele': 4, + 'mu': 5, + } + +name_to_pid_cms = {'null': 0, + 'chhadron': 1, + 'nhadron': 2, + 'HFHAD': 3, + 'HFEM': 4, + 'photon': 5, + 'ele': 6, + 'mu': 7, + 'tau': 8, + } + var_names = { "pt": r"$p_\mathrm{T}$ [GeV]", "eta": r"$\eta$", @@ -47,13 +70,22 @@ "energy": "E", } var_indices = { - "pt": 2, - "eta": 3, - "sphi": 4, - "cphi": 5, - "energy": 6 + "charge": 0, + "pt": 1, + "eta": 2, + "sphi": 3, + "cphi": 4, + "energy": 5, } +bins = 
{'charge': np.linspace(0, 5, 100), + 'pt': np.linspace(0, 5, 100), + 'eta': np.linspace(-5, 5, 100), + 'sin phi': np.linspace(-2, 2, 100), + 'cos phi': np.linspace(-2, 2, 100), + 'E': np.linspace(-1, 5, 100), + } + def midpoints(x): return x[:-1] + np.diff(x) / 2 @@ -73,12 +105,17 @@ def divide_zero(a, b): return out -def plot_distribution(pid, target, mlpf, var_name, rng, target_type, fname, legend_title=""): +def plot_distribution(data, pid, target, mlpf, var_name, rng, target_type, fname, legend_title=""): """ plot distributions for the target and mlpf of a given feature for a given PID """ plt.style.use(hep.style.CMS) + if data == 'delphes': + pid_to_name = pid_to_name_delphes + elif data == 'cms': + pid_to_name = pid_to_name_cms + fig = plt.figure(figsize=(10, 10)) if target_type == 'cand': @@ -90,90 +127,66 @@ def plot_distribution(pid, target, mlpf, var_name, rng, target_type, fname, lege plt.xlabel(var_name) if pid != -1: - plt.legend(frameon=False, title=legend_title + pid_names[pid]) + plt.legend(frameon=False, title=legend_title + pid_to_name[pid]) else: plt.legend(frameon=False, title=legend_title) plt.ylim(0, 1.5) - plt.savefig(fname + '.png') + plt.savefig(fname + '.pdf') plt.close(fig) return fig -def plot_distributions_pid(pid, true_id, true_p4, pred_id, pred_p4, pf_id, cand_p4, target, epoch, outpath, legend_title=""): +def plot_distributions_pid(data, pid, true_id, true_p4, pred_id, pred_p4, pf_id, cand_p4, target, epoch, outpath, legend_title=""): """ - plot distributions for the target and mlpf of all features for a given PID + plot distributions for the target and mlpf of the regressed features for a given PID """ plt.style.use("default") - ch_true = true_p4[true_id == pid, 0].flatten().detach().cpu().numpy() - ch_pred = pred_p4[pred_id == pid, 0].flatten().detach().cpu().numpy() - - pt_true = true_p4[true_id == pid, 1].flatten().detach().cpu().numpy() - pt_pred = pred_p4[pred_id == pid, 1].flatten().detach().cpu().numpy() - - eta_true = 
true_p4[true_id == pid, 2].flatten().detach().cpu().numpy() - eta_pred = pred_p4[pred_id == pid, 2].flatten().detach().cpu().numpy() + if data == 'delphes': + pid_to_name = pid_to_name_delphes + elif data == 'cms': + pid_to_name = pid_to_name_cms - sphi_true = true_p4[true_id == pid, 3].flatten().detach().cpu().numpy() - sphi_pred = pred_p4[pred_id == pid, 3].flatten().detach().cpu().numpy() + for i, bin_dict in enumerate(bins.items()): + true = true_p4[true_id == pid, i].flatten().detach().cpu().numpy() + pred = pred_p4[pred_id == pid, i].flatten().detach().cpu().numpy() + figure = plot_distribution(data, pid, true, pred, bin_dict[0], bin_dict[1], target, fname=outpath + '/distribution_plots/' + pid_to_name[pid] + f'_{bin_dict[0]}_distribution', legend_title=legend_title) - cphi_true = true_p4[true_id == pid, 4].flatten().detach().cpu().numpy() - cphi_pred = pred_p4[pred_id == pid, 4].flatten().detach().cpu().numpy() - e_true = true_p4[true_id == pid, 5].flatten().detach().cpu().numpy() - e_pred = pred_p4[pred_id == pid, 5].flatten().detach().cpu().numpy() - - figure = plot_distribution(pid, ch_true, ch_pred, "charge", np.linspace(0, 5, 100), target, fname=outpath + '/distribution_plots/' + pid_names[pid] + '_charge_distribution', legend_title=legend_title) - figure = plot_distribution(pid, pt_true, pt_pred, "pt", np.linspace(0, 5, 100), target, fname=outpath + '/distribution_plots/' + pid_names[pid] + '_pt_distribution', legend_title=legend_title) - figure = plot_distribution(pid, e_true, e_pred, "E", np.linspace(-1, 5, 100), target, fname=outpath + '/distribution_plots/' + pid_names[pid] + '_energy_distribution', legend_title=legend_title) - figure = plot_distribution(pid, eta_true, eta_pred, "eta", np.linspace(-5, 5, 100), target, fname=outpath + '/distribution_plots/' + pid_names[pid] + '_eta_distribution', legend_title=legend_title) - figure = plot_distribution(pid, sphi_true, sphi_pred, "sin phi", np.linspace(-2, 2, 100), target, fname=outpath + 
'/distribution_plots/' + pid_names[pid] + '_sphi_distribution', legend_title=legend_title) - figure = plot_distribution(pid, cphi_true, cphi_pred, "cos phi", np.linspace(-2, 2, 100), target, fname=outpath + '/distribution_plots/' + pid_names[pid] + '_cphi_distribution', legend_title=legend_title) - - -def plot_distributions_all(true_id, true_p4, pred_id, pred_p4, pf_id, cand_p4, target, epoch, outpath, legend_title=""): +def plot_distributions_all(data, true_id, true_p4, pred_id, pred_p4, pf_id, cand_p4, target, epoch, outpath, legend_title=""): """ plot distributions for the target and mlpf of a all features, merging all PIDs """ plt.style.use("default") msk = (pred_id != 0) & (true_id != 0) + if data == 'delphes': + pid_to_name = pid_to_name_delphes + elif data == 'cms': + pid_to_name = pid_to_name_cms - ch_true = true_p4[msk, 0].flatten().detach().cpu().numpy() - ch_pred = pred_p4[msk, 0].flatten().detach().cpu().numpy() - - pt_true = true_p4[msk, 1].flatten().detach().cpu().numpy() - pt_pred = pred_p4[msk, 1].flatten().detach().cpu().numpy() - - eta_true = true_p4[msk, 2].flatten().detach().cpu().numpy() - eta_pred = pred_p4[msk, 2].flatten().detach().cpu().numpy() - - sphi_true = true_p4[msk, 3].flatten().detach().cpu().numpy() - sphi_pred = pred_p4[msk, 3].flatten().detach().cpu().numpy() - - cphi_true = true_p4[msk, 4].flatten().detach().cpu().numpy() - cphi_pred = pred_p4[msk, 4].flatten().detach().cpu().numpy() + for i, bin_dict in enumerate(bins.items()): + true = true_p4[msk, i].flatten().detach().cpu().numpy() + pred = pred_p4[msk, i].flatten().detach().cpu().numpy() + figure = plot_distribution(data, -1, true, pred, bin_dict[0], bin_dict[1], target, fname=outpath + f'/distribution_plots/all_{bin_dict[0]}_distribution', legend_title=legend_title) - e_true = true_p4[msk, 5].flatten().detach().cpu().numpy() - e_pred = pred_p4[msk, 5].flatten().detach().cpu().numpy() - figure = plot_distribution(-1, ch_true, ch_pred, "charge", np.linspace(0, 5, 100), 
target, fname=outpath + '/distribution_plots/all_charge_distribution', legend_title=legend_title) - figure = plot_distribution(-1, pt_true, pt_pred, "pt", np.linspace(0, 5, 100), target, fname=outpath + '/distribution_plots/all_pt_distribution', legend_title=legend_title) - figure = plot_distribution(-1, e_true, e_pred, "E", np.linspace(-1, 5, 100), target, fname=outpath + '/distribution_plots/all_energy_distribution', legend_title=legend_title) - figure = plot_distribution(-1, eta_true, eta_pred, "eta", np.linspace(-5, 5, 100), target, fname=outpath + '/distribution_plots/all_eta_distribution', legend_title=legend_title) - figure = plot_distribution(-1, sphi_true, sphi_pred, "sin phi", np.linspace(-2, 2, 100), target, fname=outpath + '/distribution_plots/all_sphi_distribution', legend_title=legend_title) - figure = plot_distribution(-1, cphi_true, cphi_pred, "cos phi", np.linspace(-2, 2, 100), target, fname=outpath + '/distribution_plots/all_cphi_distribution', legend_title=legend_title) - - -def plot_particle_multiplicity(list, key, ax=None, legend_title=""): +def plot_particle_multiplicity(data, list, key, ax=None, legend_title=""): """ plot particle multiplicity for PF and mlpf """ plt.style.use(hep.style.ROOT) - pid = key_to_pid[key] + if data == 'delphes': + name_to_pid = name_to_pid_delphes + pid_to_name = pid_to_name_delphes + elif data == 'cms': + name_to_pid = name_to_pid_cms + pid_to_name = pid_to_name_cms + + pid = name_to_pid[key] if not ax: plt.figure(figsize=(4, 4)) ax = plt.axes() @@ -230,13 +243,18 @@ def plot_particle_multiplicity(list, key, ax=None, legend_title=""): ax.set_xlim(lims) ax.set_ylim(lims) plt.tight_layout() - ax.legend(frameon=False, title=legend_title + pid_names[pid]) + ax.legend(frameon=False, title=legend_title + pid_to_name[pid]) ax.set_xlabel("Truth particles / event") ax.set_ylabel("Reconstructed particles / event") plt.title("Particle multiplicity") -def draw_efficiency_fakerate(ygen, ypred, ycand, pid, var, bins, outpath, 
both=True, legend_title=""): +def draw_efficiency_fakerate(data, ygen, ypred, ycand, pid, var, bins, outpath, both=True, legend_title=""): + if data == 'delphes': + pid_to_name = pid_to_name_delphes + elif data == 'cms': + pid_to_name = pid_to_name_cms + var_idx = var_indices[var] msk_gen = ygen[:, 0] == pid @@ -258,7 +276,7 @@ def draw_efficiency_fakerate(ygen, ypred, ycand, pid, var, bins, outpath, both=T fig, ax1 = plt.subplots(1, 1, figsize=(8, 1 * 8)) ax2 = None - #ax1.set_title("reco efficiency for {}".format(pid_names[pid])) + # ax1.set_title("reco efficiency for {}".format(pid_to_name_delphes[pid])) ax1.errorbar( midpoints(hist_gen[1]), divide_zero(hist_cand[0], hist_gen[0]), @@ -269,7 +287,7 @@ def draw_efficiency_fakerate(ygen, ypred, ycand, pid, var, bins, outpath, both=T divide_zero(hist_pred[0], hist_gen[0]), divide_zero(np.sqrt(hist_gen[0]), hist_gen[0]) * divide_zero(hist_pred[0], hist_gen[0]), lw=0, label="MLPF", elinewidth=2, marker=".", markersize=10) - ax1.legend(frameon=False, loc=0, title=legend_title + pid_names[pid]) + ax1.legend(frameon=False, loc=0, title=legend_title + pid_to_name[pid]) ax1.set_ylim(0, 1.2) # if var=="energy": # ax1.set_xlim(0,30) @@ -288,7 +306,7 @@ def draw_efficiency_fakerate(ygen, ypred, ycand, pid, var, bins, outpath, both=T if both: # fake rate plot - #ax2.set_title("reco fake rate for {}".format(pid_names[pid])) + # ax2.set_title("reco fake rate for {}".format(pid_to_name_delphes[pid])) ax2.errorbar( midpoints(hist_cand2[1]), divide_zero(hist_cand_gen2[0], hist_cand2[0]), @@ -299,7 +317,7 @@ def draw_efficiency_fakerate(ygen, ypred, ycand, pid, var, bins, outpath, both=T divide_zero(hist_pred_gen2[0], hist_pred2[0]), divide_zero(np.sqrt(hist_pred_gen2[0]), hist_pred2[0]), lw=0, label="MLPF", elinewidth=2, marker=".", markersize=10) - ax2.legend(frameon=False, loc=0, title=legend_title + pid_names[pid]) + ax2.legend(frameon=False, loc=0, title=legend_title + pid_to_name[pid]) ax2.set_ylim(0, 1.0) # plt.yscale("log") 
ax2.set_xlabel(var_names[var]) @@ -311,12 +329,20 @@ def draw_efficiency_fakerate(ygen, ypred, ycand, pid, var, bins, outpath, both=T return ax1, ax2 -def plot_reso(ygen, ypred, ycand, pid, var, rng, ax=None, legend_title=""): +def plot_reso(data, ygen, ypred, ycand, pfcand, var, outpath, legend_title=""): plt.style.use(hep.style.ROOT) + if data == 'delphes': + name_to_pid = name_to_pid_delphes + pid_to_name = pid_to_name_delphes + elif data == 'cms': + name_to_pid = name_to_pid_cms + pid_to_name = pid_to_name_cms + pid = name_to_pid[pfcand] + var_idx = var_indices[var] msk = (ygen[:, 0] == pid) & (ycand[:, 0] == pid) - bins = np.linspace(-rng, rng, 100) + bins = np.linspace(-2, 2, 100) yg = ygen[msk, var_idx] yp = ypred[msk, var_idx] @@ -334,21 +360,17 @@ def plot_reso(ygen, ypred, ycand, pid, var, rng, ax=None, legend_title=""): res_dpf = np.mean(ratio_dpf), np.std(ratio_dpf) res_mlpf = np.mean(ratio_mlpf), np.std(ratio_mlpf) - if ax is None: - plt.figure(figsize=(4, 4)) - ax = plt.axes() - - #plt.title("{} resolution for {}".format(var_names_nounit[var], pid_names[pid])) + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) ax.hist(ratio_dpf, bins=bins, histtype="step", lw=2, label="Rule-based PF\n$\mu={:.2f},\\ \sigma={:.2f}$".format(*res_dpf)) ax.hist(ratio_mlpf, bins=bins, histtype="step", lw=2, label="MLPF\n$\mu={:.2f},\\ \sigma={:.2f}$".format(*res_mlpf)) - ax.legend(frameon=False, title=legend_title + pid_names[pid]) + ax.legend(frameon=False, title=legend_title + pfcand) ax.set_xlabel("{nounit} resolution, $({bare}^\prime - {bare})/{bare}$".format(nounit=var_names_nounit[var], bare=var_names_bare[var])) ax.set_ylabel("Particles") - #plt.ylim(0, ax.get_ylim()[1]*2) ax.set_ylim(1, 1e10) ax.set_yscale("log") - - return {"dpf": res_dpf, "mlpf": res_mlpf} + plt.savefig(outpath + f"/resolution_plots/res_{pfcand}_{var}.pdf", bbox_inches="tight") + plt.tight_layout() + plt.close(fig) def plot_confusion_matrix(cm, target_names, @@ -401,13 +423,17 @@ def 
plot_confusion_matrix(cm, target_names, cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] cm[np.isnan(cm)] = 0.0 - fig = plt.figure(figsize=(5, 4)) + if len(target_names) > 6: + fig = plt.figure(figsize=(8, 6)) + else: + fig = plt.figure(figsize=(5, 4)) + ax = plt.axes() plt.imshow(cm, interpolation='nearest', cmap=cmap) if target == "rule-based": - plt.title(title + ' for rule-based PF') + plt.title(title + ' for rule-based PF', fontsize=20) else: - plt.title(title + ' for MLPF at epoch ' + str(epoch)) + plt.title(title + ' for MLPF at epoch ' + str(epoch), fontsize=20) plt.colorbar() @@ -433,10 +459,9 @@ def plot_confusion_matrix(cm, target_names, plt.xlabel('Predicted label') # plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass)) plt.tight_layout() - plt.savefig(outpath + save_as + '.png') plt.savefig(outpath + save_as + '.pdf') plt.close(fig) - torch.save(cm, outpath + save_as + '.pt') + # torch.save(cm, outpath + save_as + '.pt') return fig, ax diff --git a/mlpf/pytorch_delphes/evaluate.py b/mlpf/pyg/delphes_utils.py similarity index 93% rename from mlpf/pytorch_delphes/evaluate.py rename to mlpf/pyg/delphes_utils.py index da5fcc9c7..4cbf3cf5f 100644 --- a/mlpf/pytorch_delphes/evaluate.py +++ b/mlpf/pyg/delphes_utils.py @@ -1,10 +1,25 @@ -from pytorch_delphes.utils_plots import plot_confusion_matrix, plot_distributions_pid, plot_distributions_all, plot_particle_multiplicity, draw_efficiency_fakerate, plot_reso +from pyg.delphes_plots import plot_confusion_matrix +from pyg.delphes_plots import plot_distributions_pid, plot_distributions_all, plot_particle_multiplicity +from pyg.delphes_plots import draw_efficiency_fakerate, plot_reso +from pyg.delphes_plots import pid_to_name_delphes, name_to_pid_delphes, pid_to_name_cms, name_to_pid_cms +from pyg.utils import define_regions, batch_event_into_regions +from pyg.utils import one_hot_embedding, target_p4 +from pyg.cms_utils import CLASS_NAMES_CMS +from pyg.cms_plots 
import plot_numPFelements, plot_met, plot_sum_energy, plot_sum_pt, plot_energy_res, plot_eta_res, plot_multiplicity +from pyg.cms_plots import plot_dist, plot_cm, plot_eff_and_fake_rate, distribution_icls + import torch +import torch_geometric from torch_geometric.data import Batch +from torch_geometric.loader import DataLoader, DataListLoader + +import glob import mplhep as hep import matplotlib import matplotlib.pyplot as plt import pickle as pkl +import os +import os.path as osp import math import time import tqdm @@ -13,10 +28,9 @@ import sklearn import matplotlib matplotlib.use("Agg") -matplotlib.rcParams['pdf.fonttype'] = 42 -def make_predictions(model, multi_gpu, test_loader, outpath, device, epoch): +def make_predictions_delphes(model, multi_gpu, test_loader, outpath, device, epoch): print('Making predictions...') t0 = time.time() @@ -35,15 +49,11 @@ def make_predictions(model, multi_gpu, test_loader, outpath, device, epoch): ti = time.time() - pred, target = model(X) - - gen_ids_one_hot = target['ygen_id'] - gen_p4 = target['ygen'] - cand_ids_one_hot = target['ycand_id'] - cand_p4 = target['ycand'] + pred_ids_one_hot, pred_p4 = model(X) - pred_ids_one_hot = pred[:, :6] - pred_p4 = pred[:, 6:] + gen_p4 = X.ygen.detach().to('cpu') + cand_ids_one_hot = one_hot_embedding(X.ycand_id.detach().to('cpu'), num_classes) + cand_p4 = X.ycand.detach().to('cpu') tf = time.time() if i != 0: @@ -136,7 +146,7 @@ def make_predictions(model, multi_gpu, test_loader, outpath, device, epoch): torch.save(predictions, outpath + '/predictions.pt') -def make_plots(model, test_loader, outpath, target, device, epoch, tag): +def make_plots_delphes(model, test_loader, outpath, target, device, epoch, tag): print('Making plots...') t0 = time.time() diff --git a/mlpf/pyg/environment.yml b/mlpf/pyg/environment.yml new file mode 100644 index 000000000..da9b68d7e --- /dev/null +++ b/mlpf/pyg/environment.yml @@ -0,0 +1,20 @@ +name: mlpf +channels: + - pyg + - pytorch + - conda-forge 
+dependencies: + - python=3.9 + - numpy + - pandas + - scikit-learn + - matplotlib + - pytorch + - pyg + - cpuonly + - mplhep + - tqdm + - autopep8 + - pip + - pip: + - jupyter-book diff --git a/mlpf/pyg/evaluate.py b/mlpf/pyg/evaluate.py new file mode 100644 index 000000000..656af613e --- /dev/null +++ b/mlpf/pyg/evaluate.py @@ -0,0 +1,265 @@ +from pyg.utils import define_regions, batch_event_into_regions +from pyg.utils import one_hot_embedding, target_p4 +from pyg.cms_utils import CLASS_NAMES_CMS +from pyg.cms_plots import plot_numPFelements, plot_met, plot_sum_energy, plot_sum_pt, plot_energy_res, plot_eta_res, plot_multiplicity +from pyg.cms_plots import plot_dist, plot_cm, plot_eff_and_fake_rate, distribution_icls + +import torch +import torch_geometric +from torch_geometric.data import Batch +from torch_geometric.loader import DataLoader, DataListLoader + +import glob +import sys +import mplhep as hep +import matplotlib +import matplotlib.pyplot as plt +import pickle as pkl +import os +import os.path as osp +import math +import time +import tqdm +import numpy as np +import pandas as pd +import sklearn +import matplotlib +matplotlib.use("Agg") + + +def make_predictions(rank, model, file_loader, batch_size, num_classes, PATH): + """ + Runs inference on the qcd test dataset to evaluate performance. Saves the predictions as .pt files. + Each .pt file will contain a dict() object with keys X, Y_pid, Y_p4; contains all the necessary event information to make plots. 
+ + Args + rank: int representing the gpu device id, or str=='cpu' (both work, trust me) + model: pytorch model + file_loader: a pytorch Dataloader that loads .pt files for training when you invoke the get() method + """ + + ti = time.time() + + t0, tf = time.time(), 0 + + ibatch = 0 + for num, file in enumerate(file_loader): + print(f'Time to load file {num+1}/{len(file_loader)} on rank {rank} is {round(time.time() - t0, 3)}s') + tf = tf + (time.time() - t0) + + file = [x for t in file for x in t] # unpack the list of tuples to a list + + loader = torch_geometric.loader.DataLoader(file, batch_size=batch_size) + + t = 0 + for i, batch in enumerate(loader): + + t0 = time.time() + pred_ids_one_hot, pred_p4 = model(batch.to(rank)) + t1 = time.time() + # print(f'batch {i}/{len(loader)}, forward pass on rank {rank} = {round(t1 - t0, 3)}s, for batch with {batch.num_nodes} nodes') + t = t + (t1 - t0) + + # zero pad the events to use the same plotting scripts as the tf pipeline + padded_num_elem_size = 6400 + + # must zero pad each event individually so must unpack the batches + pred_ids_one_hot_list = [] + pred_p4_list = [] + for z in range(batch_size): + pred_ids_one_hot_list.append(pred_ids_one_hot[batch.batch == z]) + pred_p4_list.append(pred_p4[batch.batch == z]) + + X = [] + Y_pid = [] + Y_p4 = [] + batch_list = batch.to_data_list() + for j, event in enumerate(batch_list): + vars = {'X': event.x.detach().to('cpu'), + 'ygen': event.ygen.detach().to('cpu'), + 'ycand': event.ycand.detach().to('cpu'), + 'pred_p4': pred_p4_list[j].detach().to('cpu'), + 'gen_ids_one_hot': one_hot_embedding(event.ygen_id.detach().to('cpu'), num_classes), + 'cand_ids_one_hot': one_hot_embedding(event.ycand_id.detach().to('cpu'), num_classes), + 'pred_ids_one_hot': pred_ids_one_hot_list[j].detach().to('cpu') + } + + vars_padded = {} + for key, var in vars.items(): + var = var[:padded_num_elem_size] + var = torch.nn.functional.pad(var, (0, 0, 0, padded_num_elem_size - var.shape[0]), 
mode='constant', value=0).unsqueeze(0) + vars_padded[key] = var + + X.append(vars_padded['X']) + Y_pid.append(torch.cat([vars_padded['gen_ids_one_hot'], vars_padded['cand_ids_one_hot'], vars_padded['pred_ids_one_hot']]).unsqueeze(0)) + Y_p4.append(torch.cat([vars_padded['ygen'], vars_padded['ycand'], vars_padded['pred_p4']]).unsqueeze(0)) + + outfile = f'{PATH}/predictions/pred_batch{ibatch}_{rank}.pt' + print(f'saving predictions at {outfile}') + torch.save({'X': torch.cat(X), # [batch_size, 6400, 41] + 'Y_pid': torch.cat(Y_pid), # [batch_size, 3, 6400, 41] + 'Y_p4': torch.cat(Y_p4)}, # [batch_size, 3, 6400, 41] + outfile) + + ibatch += 1 + + # if i == 2: + # break + # if num == 2: + # break + + print(f'Average inference time per batch on rank {rank} is {round((t / len(loader)), 3)}s') + + t0 = time.time() + + print(f'Average time to load a file on rank {rank} is {round((tf / len(file_loader)), 3)}s') + + print(f'Time taken to make predictions on rank {rank} is: {round(((time.time() - ti) / 60), 2)} min') + + +def postprocess_predictions(pred_path): + """ + Loads all the predictions .pt files and combines them after some necessary processing to make plots. + Saves the processed predictions. 
+ """ + + print('--> Concatenating all predictions...') + t0 = time.time() + + Xs = [] + Y_pids = [] + Y_p4s = [] + + PATH = list(glob.glob(f'{pred_path}/pred_batch*.pt')) + for i, fi in enumerate(PATH): + print(f'loading prediction # {i+1}/{len(PATH)}') + dd = torch.load(fi) + Xs.append(dd["X"]) + Y_pids.append(dd["Y_pid"]) + Y_p4s.append(dd["Y_p4"]) + + Xs = torch.cat(Xs).numpy() + Y_pids = torch.cat(Y_pids) + Y_p4s = torch.cat(Y_p4s) + + # reformat the loaded files for convenient plotting + yvals = {} + yvals[f'gen_cls'] = Y_pids[:, 0, :, :].numpy() + yvals[f'cand_cls'] = Y_pids[:, 1, :, :].numpy() + yvals[f'pred_cls'] = Y_pids[:, 2, :, :].numpy() + + for feat, key in enumerate(target_p4): + yvals[f'gen_{key}'] = Y_p4s[:, 0, :, feat].unsqueeze(-1).numpy() + yvals[f'cand_{key}'] = Y_p4s[:, 1, :, feat].unsqueeze(-1).numpy() + yvals[f'pred_{key}'] = Y_p4s[:, 2, :, feat].unsqueeze(-1).numpy() + + print(f'Time taken to concatenate all predictions is: {round(((time.time() - t0) / 60), 2)} min') + + print('--> Further processing for convenient plotting') + t0 = time.time() + + def flatten(arr): + return arr.reshape(-1, arr.shape[-1]) + + X_f = flatten(Xs) + + msk_X_f = X_f[:, 0] != 0 + + for val in ["gen", "cand", "pred"]: + yvals[f"{val}_phi"] = np.arctan2(yvals[f"{val}_sin_phi"], yvals[f"{val}_cos_phi"]) + yvals[f"{val}_cls_id"] = np.argmax(yvals[f"{val}_cls"], axis=-1).reshape(yvals[f"{val}_cls"].shape[0], yvals[f"{val}_cls"].shape[1], 1) # cz for some reason keepdims doesn't work + + yvals[f"{val}_px"] = np.sin(yvals[f"{val}_phi"]) * yvals[f"{val}_pt"] + yvals[f"{val}_py"] = np.cos(yvals[f"{val}_phi"]) * yvals[f"{val}_pt"] + + yvals_f = {k: flatten(v) for k, v in yvals.items()} + + # remove the last dim + for k in yvals_f.keys(): + if yvals_f[k].shape[-1] == 1: + yvals_f[k] = yvals_f[k][..., -1] + + print(f'Time taken to process the predictions is: {round(((time.time() - t0) / 60), 2)} min') + + print(f'-->Saving the processed events') + t0 = time.time() + 
torch.save(Xs, f'{pred_path}/post_processed_Xs.pt', pickle_protocol=4) + torch.save(X_f, f'{pred_path}/post_processed_X_f.pt', pickle_protocol=4) + torch.save(msk_X_f, f'{pred_path}/post_processed_msk_X_f.pt', pickle_protocol=4) + torch.save(yvals, f'{pred_path}/post_processed_yvals.pt', pickle_protocol=4) + torch.save(yvals_f, f'{pred_path}/post_processed_yvals_f.pt', pickle_protocol=4) + print(f'Time taken to save the predictions is: {round(((time.time() - t0) / 60), 2)} min') + + return Xs, X_f, msk_X_f, yvals, yvals_f + + +def make_plots_cms(pred_path, plot_path, sample): + + t0 = time.time() + + print(f'--> Loading the processed predictions') + X = torch.load(f'{pred_path}/post_processed_Xs.pt') + X_f = torch.load(f'{pred_path}/post_processed_X_f.pt') + msk_X_f = torch.load(f'{pred_path}/post_processed_msk_X_f.pt') + yvals = torch.load(f'{pred_path}/post_processed_yvals.pt') + yvals_f = torch.load(f'{pred_path}/post_processed_yvals_f.pt') + print(f'Time taken to load the processed predictions is: {round(((time.time() - t0) / 60), 2)} min') + + print(f'--> Making plots using {len(X)} events...') + + # plot distributions + print('plot_dist...') + plot_dist(yvals_f, 'pt', np.linspace(0, 200, 61), r'$p_T$', plot_path, sample) + plot_dist(yvals_f, 'energy', np.linspace(0, 2000, 61), r'$E$', plot_path, sample) + plot_dist(yvals_f, 'eta', np.linspace(-6, 6, 61), r'$\eta$', plot_path, sample) + + # plot cm + print('plot_cm...') + plot_cm(yvals_f, msk_X_f, 'MLPF', plot_path) + plot_cm(yvals_f, msk_X_f, 'PF', plot_path) + + # plot eff_and_fake_rate + print('plot_eff_and_fake_rate...') + plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=1, ivar=4, ielem=1, bins=np.logspace(-1, 3, 41), log=True) + plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=1, ivar=3, ielem=1, bins=np.linspace(-4, 4, 41), log=False, xlabel="PFElement $\eta$") + plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=2, ivar=4, ielem=5, bins=np.logspace(-1, 3, 41), 
log=True) + plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=2, ivar=3, ielem=5, bins=np.linspace(-5, 5, 41), log=False, xlabel="PFElement $\eta$") + plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=5, ivar=4, ielem=4, bins=np.logspace(-1, 2, 41), log=True) + plot_eff_and_fake_rate(X_f, yvals_f, plot_path, sample, icls=5, ivar=3, ielem=4, bins=np.linspace(-5, 5, 41), log=False, xlabel="PFElement $\eta$") + + # distribution_icls + print('distribution_icls...') + distribution_icls(yvals_f, plot_path) + + print('plot_numPFelements...') + plot_numPFelements(X, plot_path, sample) + print('plot_met...') + plot_met(X, yvals, plot_path, sample) + print('plot_sum_energy...') + plot_sum_energy(X, yvals, plot_path, sample) + print('plot_sum_pt...') + plot_sum_pt(X, yvals, plot_path, sample) + print('plot_multiplicity...') + plot_multiplicity(X, yvals, plot_path, sample) + + # for energy resolution plotting purposes, initialize pid -> (ylim, bins) dictionary + print('plot_energy_res...') + dic = {1: (1e9, np.linspace(-2, 15, 100)), + 2: (1e7, np.linspace(-2, 15, 100)), + 3: (1e7, np.linspace(-2, 40, 100)), + 4: (1e7, np.linspace(-2, 30, 100)), + 5: (1e7, np.linspace(-2, 10, 100)), + 6: (1e4, np.linspace(-1, 1, 100)), + 7: (1e4, np.linspace(-0.1, 0.1, 100)) + } + for pid, tuple in dic.items(): + plot_energy_res(X, yvals_f, pid, tuple[1], tuple[0], plot_path, sample) + + # for eta resolution plotting purposes, initialize pid -> (ylim) dictionary + print('plot_eta_res...') + dic = {1: 1e10, + 2: 1e8} + for pid, ylim in dic.items(): + plot_eta_res(X, yvals_f, pid, ylim, plot_path, sample) + + print(f'Time taken to make plots is: {round(((time.time() - t0) / 60), 2)} min') diff --git a/mlpf/pyg/get_data_cms.sh b/mlpf/pyg/get_data_cms.sh new file mode 100755 index 000000000..cd057293b --- /dev/null +++ b/mlpf/pyg/get_data_cms.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +set -e + +# make data/ directory to hold the cms/ directory of datafiles under particleflow/ +mkdir 
-p ../../data + +# get the cms data +rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/cms . + +# unzip the data +for sample in cms/* ; do + echo $sample + mv $sample/root $sample/processed + cd $sample/raw/ + bzip2 -d * + cd ../../../ +done + +# process the cms data +for sample in cms/* ; do + echo $sample + #generate pytorch data files from pkl files + python3 PFGraphDataset.py --data cms --dataset $sample \ + --processed_dir $sample/processed --num-files-merge 1 --num-proc 1 +done + +mv cms ../../data/ diff --git a/mlpf/pyg/get_data_delphes.sh b/mlpf/pyg/get_data_delphes.sh new file mode 100755 index 000000000..e8fb39f37 --- /dev/null +++ b/mlpf/pyg/get_data_delphes.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +set -e + +# make data/ directory to hold the delphes/ directory of datafiles under particleflow/ +mkdir -p ../../data + +# make delphes directories +mkdir -p delphes/pythia8_ttbar/raw +mkdir -p delphes/pythia8_ttbar/processed + +mkdir -p delphes/pythia8_qcd/raw +mkdir -p delphes/pythia8_qcd/processed + +# get the ttbar data for training +cd delphes/pythia8_ttbar/raw/ +for j in {0..9} +do + for i in {0..49} + do + wget --no-check-certificate -nc https://zenodo.org/record/4559324/files/tev14_pythia8_ttbar_"$j"_"$i".pkl.bz2 + done +done +bzip2 -d * + +# get the qcd data for extra validation +cd ../../pythia8_qcd/raw/ +for i in {0..49} +do + wget --no-check-certificate -nc https://zenodo.org/record/4559324/files/tev14_pythia8_qcd_10_"$i".pkl.bz2 +done +bzip2 -d * + +# get back in the pytorch directory +cd ../../../../ + +#generate pytorch data files from pkl files +python3 PFGraphDataset.py --data delphes --dataset delphes/pythia8_ttbar \ + --processed_dir delphes/pythia8_ttbar/processed --num-files-merge 1 --num-proc 1 + +#generate pytorch data files from pkl files +python3 PFGraphDataset.py --data delphes --dataset delphes/pythia8_qcd \ + --processed_dir delphes/pythia8_qcd/processed --num-files-merge 1 --num-proc 1 + +mv delphes ../../data/ diff --git 
a/mlpf/pyg/model.py b/mlpf/pyg/model.py new file mode 100644 index 000000000..58106c62a --- /dev/null +++ b/mlpf/pyg/model.py @@ -0,0 +1,295 @@ +from torch_geometric.nn.conv import MessagePassing +from typing import Optional +import scipy.spatial +import pickle as pkl +import os.path as osp +import os +import sys +from glob import glob + +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn import Linear +from torch_scatter import scatter +from torch_geometric.nn.conv import MessagePassing, GCNConv, GraphConv, DynamicEdgeConv +from torch_geometric.utils import to_dense_adj +import torch.nn.functional as F + +from typing import Optional, Union +from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor + + +class MLPF(nn.Module): + """ + GNN model based on Gravnet... + Forward pass returns + preds: tensor of predictions containing a concatenated representation of the pids and p4 + target: dict() object containing gen and cand target information + """ + + def __init__(self, + input_dim=12, num_classes=6, output_dim_p4=6, + embedding_dim=32, hidden_dim1=126, hidden_dim2=256, + num_convs=3, space_dim=4, propagate_dim=8, k=4): + super(MLPF, self).__init__() + + # self.act = nn.ReLU + self.act = nn.ELU + + # (1) embedding + self.nn1 = nn.Sequential( + nn.Linear(input_dim, hidden_dim1), + self.act(), + nn.Linear(hidden_dim1, hidden_dim1), + self.act(), + nn.Linear(hidden_dim1, hidden_dim1), + self.act(), + nn.Linear(hidden_dim1, embedding_dim), + ) + + self.conv = nn.ModuleList() + for i in range(num_convs): + self.conv.append(GravNetConv_MLPF(embedding_dim, embedding_dim, space_dim, propagate_dim, k)) + # self.conv.append(GravNetConv_cmspepr(embedding_dim, embedding_dim, space_dim, propagate_dim, k)) + # self.conv.append(EdgeConvBlock(embedding_dim, embedding_dim, k)) + + # (3) DNN layer: classifiying pid + self.nn2 = nn.Sequential( + nn.Linear(input_dim + embedding_dim, hidden_dim2), + self.act(), + nn.Linear(hidden_dim2, 
hidden_dim2), + self.act(), + nn.Linear(hidden_dim2, hidden_dim2), + self.act(), + nn.Linear(hidden_dim2, num_classes), + ) + + # (4) DNN layer: regressing p4 + self.nn3 = nn.Sequential( + nn.Linear(input_dim + embedding_dim + num_classes, hidden_dim2), + self.act(), + nn.Linear(hidden_dim2, hidden_dim2), + self.act(), + nn.Linear(hidden_dim2, hidden_dim2), + self.act(), + nn.Linear(hidden_dim2, output_dim_p4), + ) + + def forward(self, batch): + + # unfold the Batch object + input = batch.x + + # embed the inputs + embedding = self.nn1(input) + + # perform a series of graph convolutions + for num, conv in enumerate(self.conv): + embedding = conv(embedding, batch.batch) + + # predict the pid's + preds_id = self.nn2(torch.cat([input, embedding], axis=-1)) + + # predict the p4's + preds_p4 = self.nn3(torch.cat([input, embedding, preds_id], axis=-1)) + + return preds_id, preds_p4 + + +try: + from torch_cluster import knn +except ImportError: + knn = None + +# propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + + +class GravNetConv_MLPF(MessagePassing): + """ + Copied from pytorch_geometric source code, with the following edits + a. used reduce='sum' instead of reduce='mean' in the message passing + b. 
removed skip connection + """ + + def __init__(self, in_channels: int, out_channels: int, + space_dimensions: int, propagate_dimensions: int, k: int, + num_workers: int = 1, **kwargs): + super().__init__(flow='source_to_target', **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.k = k + self.num_workers = num_workers + + self.lin_p = Linear(in_channels, propagate_dimensions) + self.lin_s = Linear(in_channels, space_dimensions) + self.lin_out = Linear(propagate_dimensions, out_channels) + + self.reset_parameters() + + def reset_parameters(self): + self.lin_s.reset_parameters() + self.lin_p.reset_parameters() + self.lin_out.reset_parameters() + + def forward( + self, x: Union[Tensor, PairTensor], + batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor: + + is_bipartite: bool = True + if isinstance(x, Tensor): + x: PairTensor = (x, x) + is_bipartite = False + + if x[0].dim() != 2: + raise ValueError("Static graphs not supported in 'GravNetConv'") + + b: PairOptTensor = (None, None) + if isinstance(batch, Tensor): + b = (batch, batch) + elif isinstance(batch, tuple): + assert batch is not None + b = (batch[0], batch[1]) + + # embed the inputs before message passing + msg_activations = self.lin_p(x[0]) + + # transform to the space dimension to build the graph + s_l: Tensor = self.lin_s(x[0]) + s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l + + # add error message when trying to preform knn without enough neighbors in the region + if (torch.unique(b[0], return_counts=True)[1] < self.k).sum() != 0: + raise RuntimeError(f'Not enough elements in a region to perform the k-nearest neighbors. Current k-value={self.k}') + + edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0]) + # edge_index = knn_graph(s_l, self.k, b[0]) # cmspepr + + edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1) + edge_weight = torch.exp(-10. 
* edge_weight) # 10 gives a better spread + + # message passing + out = self.propagate(edge_index, x=(msg_activations, None), + edge_weight=edge_weight, + size=(s_l.size(0), s_r.size(0))) + + return self.lin_out(out) + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return x_j * edge_weight.unsqueeze(1) + + def aggregate(self, inputs: Tensor, index: Tensor, + dim_size: Optional[int] = None) -> Tensor: + out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, + reduce='sum') + return out_mean + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, k={self.k})') + + +# try: +# from torch_cmspepr import knn_graph +# except ImportError: +# knn_graph = None +# +# +# class GravNetConv_cmspepr(MessagePassing): +# """ +# Gravnet implementation that uses an optimized version of knn. +# Copied from https://github.com/cms-pepr/pytorch_cmspepr/tree/main/torch_cmspepr +# """ +# +# def __init__(self, in_channels: int, out_channels: int, +# space_dimensions: int, propagate_dimensions: int, k: int, +# num_workers: int = 1, **kwargs): +# super(GravNetConv_cmspepr, self).__init__(flow='target_to_source', **kwargs) +# +# if knn_graph is None: +# raise ImportError('`GravNetConv` requires `torch-cluster`.') +# +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.k = k +# self.num_workers = num_workers +# +# self.lin_s = Linear(in_channels, space_dimensions) +# self.lin_h = Linear(in_channels, propagate_dimensions) +# self.lin = Linear(in_channels + 2 * propagate_dimensions, out_channels) +# +# self.reset_parameters() +# +# def reset_parameters(self): +# self.lin_s.reset_parameters() +# self.lin_h.reset_parameters() +# self.lin.reset_parameters() +# +# def forward( +# self, x: Tensor, +# batch: OptTensor = None) -> Tensor: +# """""" +# +# assert x.dim() == 2, 'Static graphs not supported in `GravNetConv`.' 
+# +# b: OptTensor = None +# if isinstance(batch, Tensor): +# b = batch +# +# h_l: Tensor = self.lin_h(x) +# +# s_l: Tensor = self.lin_s(x) +# +# edge_index = knn_graph(s_l, self.k, b) +# +# edge_weight = (s_l[edge_index[1]] - s_l[edge_index[0]]).pow(2).sum(-1) +# edge_weight = torch.exp(-10. * edge_weight) # 10 gives a better spread +# +# # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) +# out = self.propagate(edge_index, x=(h_l, None), +# edge_weight=edge_weight, +# size=(s_l.size(0), s_l.size(0))) +# +# return self.lin(torch.cat([out, x], dim=-1)) +# +# def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: +# return x_j * edge_weight.unsqueeze(1) +# +# def aggregate(self, inputs: Tensor, index: Tensor, +# dim_size: Optional[int] = None) -> Tensor: +# out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, +# reduce='mean') +# out_max = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, +# reduce='max') +# return torch.cat([out_mean, out_max], dim=-1) +# +# def __repr__(self): +# return '{}({}, {}, k={})'.format(self.__class__.__name__, +# self.in_channels, self.out_channels, +# self.k) +# +# +# class EdgeConvBlock(nn.Module): +# """ +# EdgeConv implementation as an alternative to GravnetConv. 
+# """ +# +# def __init__(self, in_size, layer_size, k): +# super(EdgeConvBlock, self).__init__() +# +# layers = [] +# +# layers.append(nn.Linear(in_size * 2, layer_size)) +# layers.append(nn.BatchNorm1d(layer_size)) +# layers.append(nn.ReLU()) +# +# # for i in range(2): +# # layers.append(nn.Linear(layer_size, layer_size)) +# # layers.append(nn.BatchNorm1d(layer_size)) +# # layers.append(nn.ReLU()) +# +# self.edge_conv = DynamicEdgeConv(nn.Sequential(*layers), k=k, aggr="mean") +# +# def forward(self, x, batch): +# return self.edge_conv(x, batch) diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py new file mode 100644 index 000000000..83171189d --- /dev/null +++ b/mlpf/pyg/training.py @@ -0,0 +1,279 @@ +from pyg import make_plot_from_lists +from pyg.utils import define_regions, batch_event_into_regions, one_hot_embedding +from pyg.delphes_plots import plot_confusion_matrix +from pyg.cms_utils import CLASS_NAMES_CMS + +import torch +import torch_geometric +from torch_geometric.utils import to_dense_adj, dense_to_sparse +from torch_geometric.data import Data, Batch +from torch_geometric.loader import DataLoader, DataListLoader + +import mplhep as hep +import matplotlib.pyplot as plt +import json +import os +import pickle as pkl +import math +import time +import tqdm +import numpy as np +import pandas as pd +import sklearn +import sklearn.metrics +import matplotlib +matplotlib.use("Agg") + +# Ignore divide by 0 errors +np.seterr(divide='ignore', invalid='ignore') + + +def compute_weights(rank, target_ids, num_classes): + """ + computes necessary weights to accomodate class imbalance in the loss function + """ + + vs, cs = torch.unique(target_ids, return_counts=True) + weights = torch.zeros(num_classes).to(device=rank) + for k, v in zip(vs, cs): + weights[k] = 1.0 / math.sqrt(float(v)) + # weights[2] = weights[2] * 3 # emphasize nhadrons + return weights + + +@torch.no_grad() +def validation_run(rank, model, train_loader, valid_loader, batch_size, + alpha, 
target_type, num_classes, outpath): + with torch.no_grad(): + optimizer = None + ret = train(rank, model, train_loader, valid_loader, batch_size, + optimizer, alpha, target_type, num_classes, outpath) + return ret + + +def train(rank, model, train_loader, valid_loader, batch_size, + optimizer, alpha, target_type, num_classes, outpath): + """ + A training/validation run over a given epoch that gets called in the training_loop() function. + When optimizer is set to None, it freezes the model for a validation_run. + """ + + is_train = not (optimizer is None) + + if is_train: + print(f'---->Initiating a training run on rank {rank}') + model.train() + file_loader = train_loader + else: + print(f'---->Initiating a validation run rank {rank}') + model.eval() + file_loader = valid_loader + + # initialize loss counters + losses_clf, losses_reg, losses_tot = 0, 0, 0 + + # setup confusion matrix + conf_matrix = np.zeros((num_classes, num_classes)) + + t0, tf = time.time(), 0 + for num, file in enumerate(file_loader): + print(f'Time to load file {num+1}/{len(file_loader)} on rank {rank} is {round(time.time() - t0, 3)}s') + tf = tf + (time.time() - t0) + + file = [x for t in file for x in t] # unpack the list of tuples to a list + + loader = torch_geometric.loader.DataLoader(file, batch_size=batch_size) + + t = 0 + for i, batch in enumerate(loader): + + # run forward pass + t0 = time.time() + pred_ids_one_hot, pred_p4 = model(batch.to(rank)) + t1 = time.time() + # print(f'batch {i}/{len(loader)}, forward pass on rank {rank} = {round(t1 - t0, 3)}s, for batch with {batch.num_nodes} nodes') + t = t + (t1 - t0) + + # define the target + if target_type == 'gen': + target_p4 = batch.ygen + target_ids = batch.ygen_id + elif target_type == 'cand': + target_p4 = batch.ycand + target_ids = batch.ycand_id + + # revert one hot encoding for the predictions + pred_ids = torch.argmax(pred_ids_one_hot, axis=1) + + # define some useful masks + msk = ((pred_ids != 0) & (target_ids != 0)) + msk2 
= ((pred_ids != 0) & (pred_ids == target_ids)) + + # compute the loss + weights = compute_weights(rank, target_ids, num_classes) # to accomodate class imbalance + loss_clf = torch.nn.functional.cross_entropy(pred_ids_one_hot, target_ids, weight=weights) # for classifying PID + loss_reg = torch.nn.functional.mse_loss(pred_p4[msk2], target_p4[msk2]) # for regressing p4 # TODO: add mse weights for scales to match? huber? + + loss_tot = loss_clf + (alpha * loss_reg) + + if is_train: + for param in model.parameters(): # better than calling optimizer.zero_grad() according to https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html + param.grad = None + loss_tot.backward() + optimizer.step() + + losses_clf = losses_clf + loss_clf.detach() + losses_reg = losses_reg + loss_reg.detach() + losses_tot = losses_tot + loss_tot.detach() + + conf_matrix += sklearn.metrics.confusion_matrix(target_ids.detach().cpu(), pred_ids.detach().cpu(), labels=range(num_classes)) + + # if i == 2: + # break + # if num == 2: + # break + + print(f'Average inference time per batch on rank {rank} is {round((t / len(loader)), 3)}s') + + t0 = time.time() + + print(f'Average time to load a file on rank {rank} is {round((tf / len(file_loader)), 3)}s') + + losses_clf = losses_clf / (len(loader) * len(file_loader)) + losses_reg = losses_reg / (len(loader) * len(file_loader)) + losses_tot = losses_tot / (len(loader) * len(file_loader)) + + losses = {'losses_clf': losses_clf.cpu().item(), 'losses_reg': losses_reg.cpu().item(), 'losses_tot': losses_tot.cpu().item()} + + conf_matrix = conf_matrix / conf_matrix.sum(axis=1)[:, np.newaxis] + + return losses, conf_matrix + + +def training_loop(rank, data, model, train_loader, valid_loader, + batch_size, n_epochs, patience, + optimizer, alpha, target, num_classes, outpath): + """ + Main function to perform training. Will call the train() and validation_run() functions every epoch. 
+ + Args: + rank: int representing the gpu device id, or str=='cpu' (both work, trust me) + data: data sepecification ('cms' or 'delphes') + model: a pytorch model wrapped by DistributedDataParallel (DDP) + dataset: a PFGraphDataset object + train_loader: a pytorch Dataloader that loads .pt files for training when you invoke the get() method + valid_loader: a pytorch Dataloader that loads .pt files for validation when you invoke the get() method + patience: number of stale epochs allowed before stopping the training + optimizer: optimizer to use for training (by default: Adam) + alpha: the hyperparameter controlling the classification vs regression task balance + target: 'gen' or 'cand' training + num_classes: number of particle candidate classes to predict (6 for delphes, 9 for cms) + outpath: path to store the model weights and training plots + """ + + t0_initial = time.time() + + losses_clf_train, losses_reg_train, losses_tot_train = [], [], [] + losses_clf_valid, losses_reg_valid, losses_tot_valid = [], [], [] + + best_val_loss = 99999.9 + stale_epochs = 0 + + for epoch in range(n_epochs): + t0 = time.time() + + if stale_epochs > patience: + print("breaking due to stale epochs") + break + + # training step + model.train() + losses, conf_matrix_train = train(rank, model, train_loader, valid_loader, + batch_size, optimizer, alpha, target, num_classes, outpath) + + losses_clf_train.append(losses['losses_clf']) + losses_reg_train.append(losses['losses_reg']) + losses_tot_train.append(losses['losses_tot']) + + # validation step + model.eval() + losses, conf_matrix_val = validation_run(rank, model, train_loader, valid_loader, + batch_size, alpha, target, num_classes, outpath) + + losses_clf_valid.append(losses['losses_clf']) + losses_reg_valid.append(losses['losses_reg']) + losses_tot_valid.append(losses['losses_tot']) + + # early-stopping + if losses['losses_tot'] < best_val_loss: + best_val_loss = losses['losses_tot'] + stale_epochs = 0 + + try: + state_dict = 
model.module.state_dict() + except AttributeError: + state_dict = model.state_dict() + torch.save(state_dict, f'{outpath}/best_epoch_weights.pth') + + with open(f'{outpath}/best_epoch.json', 'w') as fp: # dump best epoch + json.dump({'best_epoch': epoch}, fp) + else: + stale_epochs += 1 + + t1 = time.time() + + epochs_remaining = n_epochs - (epoch + 1) + time_per_epoch = (t1 - t0_initial) / (epoch + 1) + eta = epochs_remaining * time_per_epoch / 60 + + print(f"Rank {rank}: epoch={epoch + 1} / {n_epochs} train_loss={round(losses_tot_train[epoch], 4)} valid_loss={round(losses_tot_valid[epoch], 4)} stale={stale_epochs} time={round((t1-t0)/60, 2)}m eta={round(eta, 1)}m") + + # save the model's weights + try: + state_dict = model.module.state_dict() + except AttributeError: + state_dict = model.state_dict() + torch.save(state_dict, f'{outpath}/epoch_{epoch}_weights.pth') + + # create directory to hold training plots + if not os.path.exists(outpath + '/training_plots/'): + os.makedirs(outpath + '/training_plots/') + + # make confusion matrix plots + cm_path = outpath + '/training_plots/confusion_matrix_plots/' + if not os.path.exists(cm_path): + os.makedirs(cm_path) + + if data == 'delphes': + target_names = ["none", "ch.had", "n.had", "g", "el", "mu"] + elif data == 'cms': + target_names = CLASS_NAMES_CMS + + plot_confusion_matrix(conf_matrix_train, target_names, epoch + 1, cm_path, f'epoch_{str(epoch)}_cmTrain') + plot_confusion_matrix(conf_matrix_val, target_names, epoch + 1, cm_path, f'epoch_{str(epoch)}_cmValid') + + # make loss plots + make_plot_from_lists('Classification loss', + 'Epochs', 'Loss', 'loss_clf', + [losses_clf_train, losses_clf_valid], + ['training', 'validation'], + ['clf_losses_train', 'clf_losses_valid'], + outpath + '/training_plots/losses/' + ) + make_plot_from_lists('Regression loss', + 'Epochs', 'Loss', 'loss_reg', + [losses_reg_train, losses_reg_valid], + ['training', 'validation'], + ['reg_losses_train', 'reg_losses_valid'], + outpath + 
'/training_plots/losses/' + ) + make_plot_from_lists('Total loss', + 'Epochs', 'Loss', 'loss_tot', + [losses_tot_train, losses_tot_valid], + ['training', 'validation'], + ['tot_losses_train', 'tot_losses_valid'], + outpath + '/training_plots/losses/' + ) + + print('----------------------------------------------------------') + print(f'Done with training. Total training time on rank {rank} is {round((time.time() - t0_initial)/60,3)}min') diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py new file mode 100644 index 000000000..29cdd964e --- /dev/null +++ b/mlpf/pyg/utils.py @@ -0,0 +1,320 @@ +import json +import shutil +import os.path as osp +import sys +from glob import glob + +import torch_geometric +from torch_geometric.loader import DataLoader, DataListLoader +from torch_geometric.data import Data, Batch +from torch.utils.data.dataloader import default_collate +from collections.abc import Mapping, Sequence +from torch_geometric.data.data import BaseData + +import torch +import mplhep as hep +import matplotlib.pyplot as plt +import os +import pickle as pkl +import math +import time +import tqdm +import numpy as np +import pandas as pd +import sklearn +import matplotlib +matplotlib.use("Agg") + +features_delphes = ["Track|cluster", "$p_{T}|E_{T}$", r"$\eta$", r'$Sin(\phi)$', r'$Cos(\phi)$', + "P|E", r"$\eta_\mathrm{out}|E_{em}$", r"$Sin(\(phi)_\mathrm{out}|E_{had}$", r"$Cos(\phi)_\mathrm{out}|E_{had}$", + "charge", "is_gen_mu", "is_gen_el"] + +features_cms = [ + "typ_idx", "pt", "eta", "phi", "e", + "layer", "depth", "charge", "trajpoint", + "eta_ecal", "phi_ecal", "eta_hcal", "phi_hcal", "muon_dt_hits", "muon_csc_hits", "muon_type", + "px", "py", "pz", "deltap", "sigmadeltap", + "gsf_electronseed_trkorecal", + "gsf_electronseed_dnn1", + "gsf_electronseed_dnn2", + "gsf_electronseed_dnn3", + "gsf_electronseed_dnn4", + "gsf_electronseed_dnn5", + "num_hits", "cluster_flags", "corr_energy", + "corr_energy_err", "vx", "vy", "vz", "pterror", "etaerror", "phierror", 
"lambd", "lambdaerror", "theta", "thetaerror" +] + +target_p4 = [ + "charge", + "pt", + "eta", + "sin_phi", + "cos_phi", + "energy", +] + + +def one_hot_embedding(labels, num_classes): + """ + Embedding labels to one-hot form. + + Args: + labels: (LongTensor) class labels, sized [N,]. + num_classes: (int) number of classes. + Returns: + (tensor) encoded labels, sized [N, #classes]. + """ + y = torch.eye(num_classes) + return y[labels] + + +def save_model(args, model_fname, outpath, model_kwargs): + + if not osp.isdir(outpath): + os.makedirs(outpath) + + else: # if directory already exists + if not args.overwrite: # if not overwrite then exit + print(f'model {model_fname} already exists, please delete it') + sys.exit(0) + + print(f'model {model_fname} already exists, deleting it') + + filelist = [f for f in os.listdir(outpath) if not f.endswith(".txt")] # don't remove the newly created logs.txt + for f in filelist: + try: + os.remove(os.path.join(outpath, f)) + except: + shutil.rmtree(os.path.join(outpath, f)) + + with open(f'{outpath}/model_kwargs.pkl', 'wb') as f: # dump model architecture + pkl.dump(model_kwargs, f, protocol=pkl.HIGHEST_PROTOCOL) + + with open(f'{outpath}/hyperparameters.json', 'w') as fp: # dump hyperparameters + json.dump({'data': args.data, + 'target': args.target, + 'n_train': args.n_train, + 'n_valid': args.n_valid, + 'n_test': args.n_test, + 'n_epochs': args.n_epochs, + 'lr': args.lr, + 'batch_size': args.batch_size, + 'alpha': args.alpha, + 'nearest': args.nearest, + 'num_convs': args.num_convs, + 'space_dim': args.space_dim, + 'propagate_dim': args.propagate_dim, + 'embedding_dim': args.embedding_dim, + 'hidden_dim1': args.hidden_dim1, + 'hidden_dim2': args.hidden_dim2, + }, fp) + + +def load_model(device, outpath, model_directory, load_epoch): + if load_epoch == -1: + PATH = outpath + '/best_epoch_weights.pth' + else: + PATH = outpath + '/epoch_' + str(load_epoch) + '_weights.pth' + + print('Loading a previously trained model..') + with 
open(outpath + '/model_kwargs.pkl', 'rb') as f: + model_kwargs = pkl.load(f) + + state_dict = torch.load(PATH, map_location=device) + + # # if the model was trained using DataParallel then we do this + # state_dict = torch.load(PATH, map_location=device) + # from collections import OrderedDict + # new_state_dict = OrderedDict() + # for k, v in state_dict.items(): + # name = k[7:] # remove module. + # new_state_dict[name] = v + # state_dict = new_state_dict + + return state_dict, model_kwargs, outpath + + +def make_plot_from_lists(title, xaxis, yaxis, save_as, X, Xlabel, X_save_as, outpath): + """ + Given a list A of lists B, makes a scatter plot of each list B and saves it. + """ + + if not os.path.exists(outpath): + os.makedirs(outpath) + + fig, ax = plt.subplots() + for i, var in enumerate(X): + ax.plot(range(len(var)), var, label=Xlabel[i]) + ax.set_xlabel(xaxis) + ax.set_ylabel(yaxis) + ax.legend(loc='best') + ax.set_title(title, fontsize=20) + plt.savefig(outpath + save_as + '.pdf') + plt.close(fig) + + for i, var in enumerate(X): + with open(outpath + X_save_as[i] + '.pkl', 'wb') as f: + pkl.dump(var, f) + + +def define_regions(num_eta_regions=10, num_phi_regions=10, max_eta=5, min_eta=-5, max_phi=1.5, min_phi=-1.5): + """ + Defines regions in (eta,phi) space to make bins within an event and build graphs within these bins. 
+ + Returns + regions: a list of tuples ~ (eta_tuples, phi_tuples) where eta_tuples is a tuple ~ (eta_min, eta_max) that defines the limits of a region and equivalenelty phi + """ + eta_step = (max_eta - min_eta) / num_eta_regions + phi_step = (max_phi - min_phi) / num_phi_regions + + tuples_eta = [] + for j in range(num_eta_regions): + tuple = (min_eta + eta_step * (j), min_eta + eta_step * (j + 1)) + tuples_eta.append(tuple) + + tuples_phi = [] + for i in range(num_phi_regions): + tuple = (min_phi + phi_step * (i), min_phi + phi_step * (i + 1)) + tuples_phi.append(tuple) + + # make regions + regions = [] + for i in range(len(tuples_eta)): + for j in range(len(tuples_phi)): + regions.append((tuples_eta[i], tuples_phi[j])) + + return regions + + +def batch_event_into_regions(data, regions): + """ + Given an event and a set of regions in (eta,phi) space, returns a binned version of the event. + + Args + data: a Batch() object containing the event and its information + regions: a tuple of tuples containing the defined regions to bin an event (see define_regions) + + Returns + data: a modified Batch() object of based on data, where data.batch seperates the events in the different bins + """ + + x = None + for region in range(len(regions)): + in_region_msk = (data.x[:, 2] > regions[region][0][0]) & (data.x[:, 2] < regions[region][0][1]) & (torch.arcsin(data.x[:, 3]) > regions[region][1][0]) & (torch.arcsin(data.x[:, 3]) < regions[region][1][1]) + + if in_region_msk.sum() != 0: # if region is not empty + if x == None: # first iteration + x = data.x[in_region_msk] + ygen = data.ygen[in_region_msk] + ygen_id = data.ygen_id[in_region_msk] + ycand = data.ycand[in_region_msk] + ycand_id = data.ycand_id[in_region_msk] + batch = region + torch.zeros([len(data.x[in_region_msk])]) # assumes events were already fed one at a time (i.e. 
batch_size=1) + else: + x = torch.cat([x, data.x[in_region_msk]]) + ygen = torch.cat([ygen, data.ygen[in_region_msk]]) + ygen_id = torch.cat([ygen_id, data.ygen_id[in_region_msk]]) + ycand = torch.cat([ycand, data.ycand[in_region_msk]]) + ycand_id = torch.cat([ycand_id, data.ycand_id[in_region_msk]]) + batch = torch.cat([batch, region + torch.zeros([len(data.x[in_region_msk])])]) # assumes events were already fed one at a time (i.e. batch_size=1) + + data = Batch(x=x, + ygen=ygen, + ygen_id=ygen_id, + ycand=ycand, + ycand_id=ycand_id, + batch=batch.long(), + ) + return data + + +class Collater: + """ + This function was copied from torch_geometric.loader.Dataloader() source code. + Edits were made such that the function can collate samples as a list of tuples of Data() objects instead of Batch() objects. + This is needed becase pyg Dataloaders do not handle num_workers>0 since Batch() objects cannot be directly serialized using pkl. + """ + + def __init__(self): + pass + + def __call__(self, batch): + elem = batch[0] + if isinstance(elem, BaseData): + return batch + + elif isinstance(elem, Sequence) and not isinstance(elem, str): + return [self(s) for s in zip(*batch)] + + raise TypeError(f'DataLoader found invalid type: {type(elem)}') + + +def make_file_loaders(world_size, dataset, num_files=1, num_workers=0, prefetch_factor=2): + """ + This function is only one line, but it's worth explaining why it's needed and what it's doing. + It uses native torch Dataloaders with a custom collate_fn that allows loading Data() objects from pt files in a fast way. + This is needed becase pyg Dataloaders do not handle num_workers>0 since Batch() objects cannot be directly serialized using pkl. 
+ + Args: + world_size: number of gpus available + dataset: custom dataset + num_files: number of files to load with a single get() call + num_workers: number of workers to use for fetching files + prefetch_factor: number of files to fetch in advance + + Returns: + a torch iterable() that returns a list of 100 elements, each element is a tuple of size=num_files containing Data() objects + """ + if world_size > 0: + return torch.utils.data.DataLoader(dataset, num_files, shuffle=False, num_workers=num_workers, prefetch_factor=prefetch_factor, collate_fn=Collater(), pin_memory=True) + else: + return torch.utils.data.DataLoader(dataset, num_files, shuffle=False, num_workers=num_workers, prefetch_factor=prefetch_factor, collate_fn=Collater(), pin_memory=False) + + +def dataloader_ttbar(train_dataset, valid_dataset, batch_size): + """ + Builds training and validation dataloaders from a physics dataset for conveninet ML training + Args: + train_dataset: a PFGraphDataset dataset that is a list of lists that contain Data() objects + valid_dataset: a PFGraphDataset dataset that is a list of lists that contain Data() objects + Returns: + train_loader: a pyg iterable DataLoader() that contains Batch objects for training + valid_loader: a pyg iterable DataLoader() that contains Batch objects for validation + """ + + # preprocessing the train_dataset in a good format for passing correct batches of events to the GNN + train_data = [] + for data in train_dataset: + train_data = train_data + data + + # preprocessing the valid_dataset in a good format for passing correct batches of events to the GNN + valid_data = [] + for data in valid_dataset: + valid_data = valid_data + data + + train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) + valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=True) + + return train_loader, valid_loader + + +def dataloader_qcd(multi_gpu, test_dataset, batch_size): + """ + Builds a testing dataloader from a physics 
dataset for conveninet ML training + Args: + test_dataset: a PFGraphDataset dataset that is a list of lists that contain Data() objects + Returns: + test_loader: a pyg iterable DataLoader() that contains Batch objects for testing + """ + + # preprocessing the test_dataset in a good format for passing correct batches of events to the GNN + test_data = [] + for data in test_dataset: + test_data = test_data + data + + if not multi_gpu: + test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True) + else: + test_loader = DataListLoader(test_data, batch_size=batch_size, shuffle=True) + + return test_loader diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py new file mode 100644 index 000000000..919998630 --- /dev/null +++ b/mlpf/pyg_pipeline.py @@ -0,0 +1,316 @@ +from pyg import parse_args +from pyg import PFGraphDataset +from pyg import MLPF, training_loop, make_predictions, postprocess_predictions, make_plots_cms +from pyg import save_model, load_model +from pyg import features_delphes, features_cms, target_p4 +from pyg import make_file_loaders + +import torch +import torch_geometric +from torch_geometric.loader import DataLoader, DataListLoader +from torch_geometric.data import Data, Batch + +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +import mplhep as hep +import time +import matplotlib.pyplot as plt +from glob import glob +import sys +import os +import os.path as osp +import shutil +import pickle as pkl +import json +import math +import time +import tqdm +import numpy as np +import pandas as pd +import sklearn +import matplotlib +matplotlib.use("Agg") + + +""" +Developing a PyTorch Geometric MLPF pipeline using DistributedDataParallel. 
+ +Author: Farouk Mokhtar +""" + +# Ignore divide by 0 errors +np.seterr(divide='ignore', invalid='ignore') + +# define the global base device +if torch.cuda.device_count(): + device = torch.device('cuda:0') +else: + device = 'cpu' + + +def setup(rank, world_size): + """ + Necessary setup function that sets up environment variables and initializes the process group to perform training & inference using DistributedDataParallel (DDP). + DDP relies on c10d ProcessGroup for communications, hence, applications must create ProcessGroup instances before constructing DDP. + + Args: + rank: the process id (or equivalently the gpu index) + world_size: number of gpus available + """ + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # dist.init_process_group("gloo", rank=rank, world_size=world_size) + dist.init_process_group("nccl", rank=rank, world_size=world_size) # should be faster for DistributedDataParallel on gpus + + +def cleanup(): + """ + Necessary function that destroys the spawned process group at the end. + """ + + dist.destroy_process_group() + + +def run_demo(demo_fn, world_size, args, dataset, model, num_classes, outpath): + """ + Necessary function that spawns a process group of size=world_size processes to run demo_fn() on each gpu device that will be indexed by 'rank'. + + Args: + demo_fn: function you wish to run on each gpu + world_size: number of gpus available + mode: 'train' or 'inference' + """ + + # mp.set_start_method('forkserver') + + mp.spawn(demo_fn, + args=(world_size, args, dataset, model, num_classes, outpath), + nprocs=world_size, + join=True, + ) + + +def train_ddp(rank, world_size, args, dataset, model, num_classes, outpath): + """ + A train_ddp() function that will be passed as a demo_fn to run_demo() to perform training over multiple gpus using DDP. 
+ + It divides and distributes the training dataset appropriately, copies the model, and wraps the model with DDP on each device + to allow synching of gradients, and finally, invokes the training_loop() to run synchronized training among devices. + """ + + setup(rank, world_size) + + print(f"Running training on rank {rank}: {torch.cuda.get_device_name(rank)}") + + # give each gpu a subset of the data + hyper_train = int(args.n_train / world_size) + hyper_valid = int(args.n_valid / world_size) + + train_dataset = torch.utils.data.Subset(dataset, np.arange(start=rank * hyper_train, stop=(rank + 1) * hyper_train)) + valid_dataset = torch.utils.data.Subset(dataset, np.arange(start=args.n_train + rank * hyper_valid, stop=args.n_train + (rank + 1) * hyper_valid)) + + # construct file loaders + file_loader_train = make_file_loaders(world_size, train_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor) + file_loader_valid = make_file_loaders(world_size, valid_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor) + + # copy the model to the GPU with id=rank + print(f'Copying the model on rank {rank}..') + model = model.to(rank) + model.train() + ddp_model = DDP(model, device_ids=[rank]) + + optimizer = torch.optim.Adam(ddp_model.parameters(), lr=args.lr) + + training_loop(rank, args.data, ddp_model, file_loader_train, file_loader_valid, + args.batch_size, args.n_epochs, args.patience, + optimizer, args.alpha, args.target, num_classes, outpath) + + cleanup() + + +def inference_ddp(rank, world_size, args, dataset, model, num_classes, PATH): + """ + An inference_ddp() function that will be passed as a demo_fn to run_demo() to perform inference over multiple gpus using DDP. + + It divides and distributes the testing dataset appropriately, copies the model, and wraps the model with DDP on each device. 
+ """ + + setup(rank, world_size) + + print(f"Running inference on rank {rank}: {torch.cuda.get_device_name(rank)}") + + # give each gpu a subset of the data + hyper_test = int(args.n_test / world_size) + + test_dataset = torch.utils.data.Subset(dataset, np.arange(start=rank * hyper_test, stop=(rank + 1) * hyper_test)) + + # construct data loaders + file_loader_test = make_file_loaders(world_size, test_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor) + + # copy the model to the GPU with id=rank + print(f'Copying the model on rank {rank}..') + model = model.to(rank) + model.eval() + ddp_model = DDP(model, device_ids=[rank]) + + make_predictions(rank, ddp_model, file_loader_test, args.batch_size, num_classes, PATH) + + cleanup() + + +def train(device, world_size, args, dataset, model, num_classes, outpath): + """ + A train() function that will load the training dataset and start a training_loop on a single device (cuda or cpu). + """ + + if device == 'cpu': + print(f"Running training on cpu") + else: + print(f"Running training on: {torch.cuda.get_device_name(device)}") + device = device.index + + train_dataset = torch.utils.data.Subset(dataset, np.arange(start=0, stop=args.n_train)) + valid_dataset = torch.utils.data.Subset(dataset, np.arange(start=args.n_train, stop=args.n_train + args.n_valid)) + + # construct file loaders + file_loader_train = make_file_loaders(world_size, train_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor) + file_loader_valid = make_file_loaders(world_size, valid_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor) + + # move the model to the device (cuda or cpu) + model = model.to(device) + model.train() + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + training_loop(device, args.data, model, file_loader_train, file_loader_valid, + args.batch_size, args.n_epochs, args.patience, + optimizer, args.alpha, args.target, num_classes, outpath) + + +def 
inference(device, world_size, args, dataset, model, num_classes, PATH): + """ + An inference() function that will load the testing dataset and start running inference on a single device (cuda or cpu). + """ + + if device == 'cpu': + print(f"Running inference on cpu") + else: + print(f"Running inference on: {torch.cuda.get_device_name(device)}") + device = device.index + + test_dataset = torch.utils.data.Subset(dataset, np.arange(start=0, stop=args.n_test)) + + # construct data loaders + file_loader_test = make_file_loaders(world_size, test_dataset, num_workers=args.num_workers, prefetch_factor=args.prefetch_factor) + + # copy the model to the GPU with id=rank + model = model.to(device) + model.eval() + + make_predictions(device, model, file_loader_test, args.batch_size, num_classes, PATH) + + +if __name__ == "__main__": + + args = parse_args() + + world_size = torch.cuda.device_count() + + torch.backends.cudnn.benchmark = True + + # retrieve the dimensions of the PF-elements & PF-candidates to set the input/output dimension of the model + if args.data == 'delphes': + input_dim = len(features_delphes) + num_classes = 6 # we have 6 classes/pids for delphes + elif args.data == 'cms': + input_dim = len(features_cms) + num_classes = 9 # we have 9 classes/pids for cms (including taus) + output_dim_p4 = len(target_p4) + + outpath = osp.join(args.outpath, args.model_prefix) + + # load a pre-trained specified model, otherwise, instantiate and train a new model + if args.load: + state_dict, model_kwargs, outpath = load_model(device, outpath, args.model_prefix, args.load_epoch) + + model = MLPF(**model_kwargs) + model.load_state_dict(state_dict) + + else: + model_kwargs = {'input_dim': input_dim, + 'num_classes': num_classes, + 'output_dim_p4': output_dim_p4, + 'embedding_dim': args.embedding_dim, + 'hidden_dim1': args.hidden_dim1, + 'hidden_dim2': args.hidden_dim2, + 'num_convs': args.num_convs, + 'space_dim': args.space_dim, + 'propagate_dim': args.propagate_dim, + 'k': 
args.nearest, + } + + model = MLPF(**model_kwargs) + + # save model_kwargs and hyperparameters + save_model(args, args.model_prefix, outpath, model_kwargs) + + print(model) + print(args.model_prefix) + + print("Training over {} epochs".format(args.n_epochs)) + + # run the training using DDP if more than one gpu is available + dataset = PFGraphDataset(args.dataset, args.data) + + if world_size >= 2: + run_demo(train_ddp, world_size, args, dataset, model, num_classes, outpath) + else: + train(device, world_size, args, dataset, model, num_classes, outpath) + + # load the best epoch state + state_dict = torch.load(outpath + '/best_epoch_weights.pth', map_location=device) + model.load_state_dict(state_dict) + + # specify which epoch/state to load to run the inference and make plots + if args.load and args.load_epoch != -1: + epoch_to_load = args.load_epoch + else: + import json + epoch_to_load = json.load(open(f'{outpath}/best_epoch.json'))['best_epoch'] + + PATH = f'{outpath}/testing_epoch_{epoch_to_load}_{args.sample}/' + pred_path = f'{PATH}/predictions/' + plot_path = f'{PATH}/plots/' + + # run the inference + if args.make_predictions: + + if not os.path.exists(PATH): + os.makedirs(PATH) + if not os.path.exists(f'{PATH}/predictions/'): + os.makedirs(f'{PATH}/predictions/') + if not os.path.exists(f'{PATH}/plots/'): + os.makedirs(f'{PATH}/plots/') + + # run the inference using DDP if more than one gpu is available + dataset_test = PFGraphDataset(args.dataset_test, args.data) + + if world_size >= 2: + run_demo(inference_ddp, world_size, args, dataset_test, model, num_classes, PATH) + else: + inference(device, world_size, args, dataset_test, model, num_classes, PATH) + + postprocess_predictions(pred_path) + + # load the predictions and make plots (must have ran make_predictions before) + if args.make_plots: + + if not osp.isdir(plot_path): + os.makedirs(plot_path) + + if args.data == 'cms': + make_plots_cms(pred_path, plot_path, args.sample) diff --git 
a/mlpf/pytorch_cms/README.md b/mlpf/pytorch_cms/README.md deleted file mode 100644 index 19ab76c6e..000000000 --- a/mlpf/pytorch_cms/README.md +++ /dev/null @@ -1,5 +0,0 @@ -Short instructions to train on cms data: -```bash -cd ../.. -./scripts/local_test_cms.sh -``` diff --git a/mlpf/pytorch_cms/eval_end2end_cms.py b/mlpf/pytorch_cms/eval_end2end_cms.py deleted file mode 100644 index 783933fd1..000000000 --- a/mlpf/pytorch_cms/eval_end2end_cms.py +++ /dev/null @@ -1,67 +0,0 @@ -#import setGPU -import torch -import torch_geometric -import sklearn -import numpy as np -import matplotlib.pyplot as plt -from torch_geometric.data import Data, DataLoader, DataListLoader, Batch -import pandas -import mplhep -import pickle - -import graph_data_cms -import train_end2end_cms -import time - -def collate(items): - l = sum(items, []) - return Batch.from_data_list(l) - - -def parse_args(): - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, choices=sorted(train_end2end_cms.model_classes.keys()), help="type of model to use", default="PFNet6") - parser.add_argument("--path", type=str, help="path to model", default="data/PFNet7_TTbar_14TeV_TuneCUETP8M1_cfi_gen__npar_221073__cfg_ee19d91068__user_jovyan__ntrain_400__lr_0.0001__1588215695") - parser.add_argument("--epoch", type=str, default=0, help="Epoch to use") - parser.add_argument("--dataset", type=str, help="Input dataset", required=True) - parser.add_argument("--start", type=int, default=3800, help="first file index to evaluate") - parser.add_argument("--stop", type=int, default=4000, help="last file index to evaluate") - args = parser.parse_args() - return args - -if __name__ == "__main__": - args = parse_args() - device = torch.device("cpu") - - epoch = args.epoch - model = args.model - path = args.path - weights = torch.load("{}/epoch_{}_weights.pth".format(path, epoch), map_location=device) - weights = {k.replace("module.", ""): v for k, v in weights.items()} - - with 
open('{}/model_kwargs.pkl'.format(path),'rb') as f: - model_kwargs = pickle.load(f) - - model_class = train_end2end_cms.model_classes[args.model] - model = model_class(**model_kwargs) - model.load_state_dict(weights) - model = model.to(device) - model.eval() - - - print(args.dataset) - full_dataset = graph_data_cms.PFGraphDataset(root=args.dataset) - print("full_dataset", len(full_dataset)) - test_dataset = torch.utils.data.Subset(full_dataset, np.arange(start=args.start, stop=args.stop)) - assert(len(test_dataset)>0) - - loader = DataListLoader(test_dataset, batch_size=1, pin_memory=False, shuffle=False) - loader.collate_fn = collate - - big_df = train_end2end_cms.prepare_dataframe(model, loader, False, device) - - big_df.to_pickle("{}/df.pkl.bz2".format(path)) - #edges_df.to_csv("{}/edges.csv".format(path)) - print(big_df) - #print(edges_df) diff --git a/mlpf/pytorch_cms/graph_data_cms.py b/mlpf/pytorch_cms/graph_data_cms.py deleted file mode 100644 index 3770c16fa..000000000 --- a/mlpf/pytorch_cms/graph_data_cms.py +++ /dev/null @@ -1,202 +0,0 @@ -import numpy as np -import os -import os.path as osp -import torch -import torch_geometric -import torch_geometric.utils -from torch_geometric.data import Dataset, Data, Batch -import itertools -from glob import glob -import numba -from numpy.lib.recfunctions import append_fields - -import pickle -import scipy -import scipy.sparse -import math -import multiprocessing - -elem_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] -class_labels = [0, 1, 2, 11, 13, 22, 130, 211] - -#map these to ids 0...Nclass -class_to_id = {r: class_labels[r] for r in range(len(class_labels))} - -# map these to ids 0...Nclass -elem_to_id = {r: elem_labels[r] for r in range(len(elem_labels))} - -# Data normalization constants for faster convergence. 
-# These are just estimated with a printout and rounding, don't need to be super accurate -# x_means = torch.tensor([ 0.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).to(device) -# x_stds = torch.tensor([ 1.0, 22.0, 2.6, 1.8, 1.3, 1.9, 1.3, 1.0]).to(device) -# y_candidates_means = torch.tensor([0.0, 0.0, 0.0]).to(device) -# y_candidates_stds = torch.tensor([1.8, 2.0, 1.5]).to(device) -def process_func(args): - self, fns, idx_file = args - return self.process_multiple_files(fns, idx_file) - -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i:i + n] - -#Do any in-memory transformations to data -def data_prep(data, device=torch.device('cpu')): - #Create a one-hot encoded vector of the class labels - data.y_candidates_id = data.ycand[:, 0].to(dtype=torch.long) - data.y_gen_id = data.ygen[:, 0].to(dtype=torch.long) - - #one-hot encode the input categorical of the input - elem_id_onehot = torch.nn.functional.one_hot(data.x[:, 0].to(dtype=torch.long), num_classes=len(elem_to_id)) - data.x = torch.cat([elem_id_onehot.to(dtype=torch.float), data.x[:, 1:]], axis=-1) - - data.y_candidates_weights = torch.ones(len(class_to_id)).to(device=device, dtype=torch.float) - data.y_gen_weights = torch.ones(len(class_to_id)).to(device=device, dtype=torch.float) - - data.ycand = data.ycand[:, 1:] - data.ygen = data.ygen[:, 1:] - - data.x[torch.isnan(data.x)] = 0.0 - data.ycand[torch.isnan(data.ycand)] = 0.0 - data.ygen[torch.isnan(data.ygen)] = 0.0 - data.ygen[data.ygen.abs()>1e4] = 0 - #print("x=", data.x) - #print("y_candidates_id=", data.y_candidates_id) - #print("y_gen_id=", data.y_gen_id) - #print("ycand=", data.ycand) - #print("ygen=", data.ygen) - -class PFGraphDataset(Dataset): - def __init__(self, root, transform=None, pre_transform=None): - super(PFGraphDataset, self).__init__(root, transform, pre_transform) - self._processed_dir = Dataset.processed_dir.fget(self) - - @property - def raw_file_names(self): - raw_list = 
glob(osp.join(self.raw_dir, '*.pkl')) - print("PFGraphDataset nfiles={}".format(len(raw_list))) - return sorted([l.replace(self.raw_dir, '.') for l in raw_list]) - - def _download(self): - pass - - def _process(self): - pass - - @property - def processed_dir(self): - return self._processed_dir - - @property - def processed_file_names(self): - proc_list = glob(osp.join(self.processed_dir, '*.pt')) - return sorted([l.replace(self.processed_dir, '.') for l in proc_list]) - - def __len__(self): - return len(self.processed_file_names) - - def download(self): - # Download to `self.raw_dir`. - pass - - def process_single_file(self, raw_file_name): - with open(osp.join(self.raw_dir, raw_file_name), "rb") as fi: - all_data = pickle.load(fi, encoding='iso-8859-1') - - batch_data = [] - - # all_data is a list of only one element.. this element is a dictionary with keys: ["Xelem", "ycan", "ygen", 'dm', 'dm_elem_cand', 'dm_elem_gen'] - data = all_data[0] - mat = data["dm_elem_cand"].copy() - - # Xelem contains all elements in 1 event - # Xelem[i] contains the element #i in the event - Xelem = data["Xelem"] - ygen = data["ygen"] - ycand = data["ycand"] - - # attach to every Xelem[i] (which is one element in the event) an extra elem_label - Xelem = append_fields(Xelem, "typ_idx", np.array([elem_labels.index(int(i)) for i in Xelem["typ"]], dtype=np.float32)) - ygen = append_fields(ygen, "typ_idx", np.array([class_labels.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32)) - ycand = append_fields(ycand, "typ_idx", np.array([class_labels.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32)) - - Xelem_flat = np.stack([Xelem[k].view(np.float32).data for k in [ - 'typ_idx', - 'pt', 'eta', 'phi', 'e', - 'layer', 'depth', 'charge', 'trajpoint', - 'eta_ecal', 'phi_ecal', 'eta_hcal', 'phi_hcal', - 'muon_dt_hits', 'muon_csc_hits']], axis=-1 - ) - ygen_flat = np.stack([ygen[k].view(np.float32).data for k in [ - 'typ_idx', - 'eta', 'phi', 'e', 'charge', - ]], axis=-1 - ) - 
ycand_flat = np.stack([ycand[k].view(np.float32).data for k in [ - 'typ_idx', - 'eta', 'phi', 'e', 'charge', - ]], axis=-1 - ) - r = torch_geometric.utils.from_scipy_sparse_matrix(mat) - - x = torch.tensor(Xelem_flat, dtype=torch.float) - ygen = torch.tensor(ygen_flat, dtype=torch.float) - ycand = torch.tensor(ycand_flat, dtype=torch.float) - - data = Data( - x=x, - edge_index=r[0].to(dtype=torch.long), - #edge_attr=r[1].to(dtype=torch.float), - ygen=ygen, ycand=ycand, - ) - data_prep(data) - batch_data += [data] - - return batch_data - - def process_multiple_files(self, filenames, idx_file): - datas = [self.process_single_file(fn) for fn in filenames] - datas = sum(datas, []) - p = osp.join(self.processed_dir, 'data_{}.pt'.format(idx_file)) - print(p) - torch.save(datas, p) - - def process(self, num_files_to_batch): - idx_file = 0 - for fns in chunks(self.raw_file_names, num_files_to_batch): - self.process_multiple_files(fns, idx_file) - idx_file += 1 - - def process_parallel(self, num_files_to_batch, num_proc): - pars = [] - idx_file = 0 - for fns in chunks(self.raw_file_names, num_files_to_batch): - pars += [(self, fns, idx_file)] - idx_file += 1 - pool = multiprocessing.Pool(num_proc) - pool.map(process_func, pars) - - def get(self, idx): - p = osp.join(self.processed_dir, 'data_{}.pt'.format(idx)) - data = torch.load(p) - return data - - def __getitem__(self, idx): - return self.get(idx) - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--dataset", type=str, help="dataset path", required=True) - parser.add_argument("--processed_dir", type=str, help="processed", required=False, default=None) - parser.add_argument("--num-files-merge", type=int, default=10, help="number of files to merge") - parser.add_argument("--num-proc", type=int, default=24, help="number of processes") - args = parser.parse_args() - - pfgraphdataset = PFGraphDataset(root=args.dataset) - - if args.processed_dir: - 
pfgraphdataset._processed_dir = args.processed_dir - - pfgraphdataset.process_parallel(args.num_files_merge,args.num_proc) - #pfgraphdataset.process(args.num_files_merge) diff --git a/mlpf/pytorch_cms/gravnet.py b/mlpf/pytorch_cms/gravnet.py deleted file mode 100644 index 5259da3c6..000000000 --- a/mlpf/pytorch_cms/gravnet.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -from torch.nn import Linear -from torch_scatter import scatter, segment_csr -from torch_geometric.nn.conv import MessagePassing - -try: - from torch_cluster import knn_graph - from torch_cluster import radius_graph -except ImportError: - knn_graph = None - - -class GravNetConv(MessagePassing): - r"""The GravNet operator from the `"Learning Representations of Irregular - Particle-detector Geometry with Distance-weighted Graph - Networks" `_ paper, where the graph is - dynamically constructed using nearest neighbors. - The neighbors are constructed in a learnable low-dimensional projection of - the feature space. - A second projection of the input feature space is then propagated from the - neighbors to each vertex using distance weights that are derived by - applying a Gaussian function to the distances. - Args: - in_channels (int): The number of input channels. - out_channels (int): The number of output channels. - space_dimensions (int): The dimensionality of the space used to - construct the neighbors; referred to as :math:`S` in the paper. - propagate_dimensions (int): The number of features to be propagated - between the vertices; referred to as :math:`F_{\textrm{LR}}` in the - paper. - k (int): The number of nearest neighbors. - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. 
- """ - - def __init__(self, in_channels, out_channels, space_dimensions, - propagate_dimensions, k, neighbor_algo="knn", radius=0.1, **kwargs): - super(GravNetConv, self).__init__(**kwargs) - - if knn_graph is None: - raise ImportError('`GravNetConv` requires `torch-cluster`.') - - self.in_channels = in_channels - self.out_channels = out_channels - self.k = k - self.radius = radius - - self.lin_s = Linear(in_channels, space_dimensions) - self.lin_flr = Linear(in_channels, propagate_dimensions) - self.lin_fout = Linear(in_channels + 2 * propagate_dimensions, - out_channels) - - self.neighbor_algo = neighbor_algo - self.reset_parameters() - - def reset_parameters(self): - self.lin_s.reset_parameters() - self.lin_flr.reset_parameters() - self.lin_fout.reset_parameters() - - def forward(self, x, batch=None): - spatial = self.lin_s(x) - to_propagate = self.lin_flr(x) - - if self.neighbor_algo == "knn": - edge_index = knn_graph(spatial, self.k, batch, loop=False, - flow=self.flow, cosine=False) - elif self.neighbor_algo == "radius": - edge_index = radius_graph(spatial, self.radius, batch, loop=False, - flow=self.flow, max_num_neighbors=self.k) - else: - raise Exception("Unknown neighbor algo {}".format(self.neighbor_algo)) - - reference = spatial.index_select(0, edge_index[1]) - neighbors = spatial.index_select(0, edge_index[0]) - - distancessq = torch.sum((reference - neighbors)**2, dim=-1) - # Factor 10 gives a better initial spread - distance_weight = torch.exp(-10. 
* distancessq) - - prop_feat = self.propagate(edge_index, x=to_propagate, - edge_weight=distance_weight) - - return edge_index, self.lin_fout(torch.cat([prop_feat, x], dim=-1)) - - def message(self, x_j, edge_weight): - return x_j * edge_weight.unsqueeze(1) - - def aggregate(self, inputs, index, ptr=None, dim_size=None): - if ptr is not None: - for _ in range(self.node_dim): - ptr = ptr.unsqueeze(0) - aggr_mean = segment_csr(inputs, ptr, reduce='mean') - aggr_max = segment_csr(inputs, ptr, reduce='max') - else: - aggr_mean = scatter(inputs, index, dim=self.node_dim, - dim_size=dim_size, reduce='mean') - aggr_max = scatter(inputs, index, dim=self.node_dim, - dim_size=dim_size, reduce='max') - - return torch.cat([aggr_mean, aggr_max], dim=-1) - - def __repr__(self): - return '{}({}, {}, k={})'.format(self.__class__.__name__, - self.in_channels, self.out_channels, - self.k) diff --git a/mlpf/pytorch_cms/train_end2end_cms.py b/mlpf/pytorch_cms/train_end2end_cms.py deleted file mode 100644 index 2b0802344..000000000 --- a/mlpf/pytorch_cms/train_end2end_cms.py +++ /dev/null @@ -1,670 +0,0 @@ -import sys -import os -import math - -from comet_ml import Experiment - -#Check if the GPU configuration has been provided -try: - if not ("CUDA_VISIBLE_DEVICES" in os.environ): - import setGPU -except Exception as e: - print("Could not import setGPU, running CPU-only") - -import torch -use_gpu = torch.cuda.device_count()>0 -multi_gpu = torch.cuda.device_count()>1 - -#define the global base device -if use_gpu: - device = torch.device('cuda:0') -else: - device = torch.device('cpu') - -import torch_geometric - -import torch.nn as nn -import torch.nn.functional as F -import torch_geometric.transforms as T -from torch_geometric.nn import EdgeConv, MessagePassing, EdgePooling, GATConv, GCNConv, JumpingKnowledge, GraphUNet, DynamicEdgeConv, DenseGCNConv -from torch_geometric.nn import TopKPooling, SAGPooling, SGConv -from torch.nn import Sequential as Seq, Linear as Lin, ReLU -from 
torch_scatter import scatter_mean -from torch_geometric.nn.inits import reset -from torch_geometric.data import Data, DataLoader, DataListLoader, Batch -from gravnet import GravNetConv - -import torch_cluster - -from glob import glob -import numpy as np -import os.path as osp -import pickle - -import math -import time -import numba -import tqdm -import sklearn -import pandas - -import matplotlib -matplotlib.use("Agg") -import matplotlib.pyplot as plt -import mplhep - -from sklearn.metrics import accuracy_score -from sklearn.metrics import confusion_matrix - -import graph_data_cms -from graph_data_cms import PFGraphDataset, elem_to_id, class_to_id, class_labels - -#Ignore divide by 0 errors -np.seterr(divide='ignore', invalid='ignore') - -def onehot(a): - b = np.zeros((a.size, len(class_labels))) - b[np.arange(a.size),a] = 1 - return b - -#Creates the dataframe of predictions given a trained model and a data loader -def prepare_dataframe(model, loader, multi_gpu, device): - model.eval() - dfs = [] - dfs_edges = [] - eval_time = 0 - - for i, data in enumerate(loader): - if not multi_gpu: - data = data.to(device) - - pred_id_onehot, pred_momentum, new_edges = model(data) - _, pred_id = torch.max(pred_id_onehot, -1) - pred_momentum[pred_id==0] = 0 - if not multi_gpu: - data = [data] - - x = torch.cat([d.x.to("cpu") for d in data]) - gen_id = torch.cat([d.y_gen_id.to("cpu") for d in data]) - gen_p4 = torch.cat([d.ygen[:, :4].to("cpu") for d in data]) - cand_id = torch.cat([d.y_candidates_id.to("cpu") for d in data]) - cand_p4 = torch.cat([d.ycand[:, :4].to("cpu") for d in data]) - - df = pandas.DataFrame() - - df["elem_type"] = [int(graph_data_cms.elem_labels[i]) for i in torch.argmax(x[:, :len(graph_data_cms.elem_labels)], axis=-1).numpy()] - for ifeat, feat in enumerate([ - 'pt', 'eta', 'phi', 'e', - 'layer', 'depth', 'charge', 'trajpoint', - 'eta_ecal', 'phi_ecal', 'eta_hcal', 'phi_hcal', - 'muon_dt_hits', 'muon_csc_hits']): - df["elem_{}".format(feat)] = x[:, 
len(graph_data_cms.elem_labels)+ifeat].numpy() - - df["elem_type"] = [int(graph_data_cms.elem_labels[i]) for i in torch.argmax(x[:, :len(graph_data_cms.elem_labels)], axis=-1).numpy()] - df["gen_pid"] = [int(graph_data_cms.class_labels[i]) for i in gen_id.numpy()] - df["gen_eta"] = gen_p4[:, 0].numpy() - df["gen_phi"] = gen_p4[:, 1].numpy() - df["gen_e"] = gen_p4[:, 2].numpy() - df["gen_charge"] = gen_p4[:, 3].numpy() - - df["cand_pid"] = [int(graph_data_cms.class_labels[i]) for i in cand_id.numpy()] - df["cand_eta"] = cand_p4[:, 0].numpy() - df["cand_phi"] = cand_p4[:, 1].numpy() - df["cand_e"] = cand_p4[:, 2].numpy() - df["cand_charge"] = cand_p4[:, 3].numpy() - df["pred_pid"] = [int(graph_data_cms.class_labels[i]) for i in pred_id.detach().cpu().numpy()] - - df["pred_eta"] = pred_momentum[:, 0].detach().cpu().numpy() - df["pred_phi"] = pred_momentum[:, 1].detach().cpu().numpy() - df["pred_e"] = pred_momentum[:, 2].detach().cpu().numpy() - df["pred_charge"] = pred_momentum[:, 3].detach().cpu().numpy() - - dfs.append(df) - #df_edges = pandas.DataFrame() - #df_edges["edge0"] = edges[0].to("cpu") - #df_edges["edge1"] = edges[1].to("cpu") - #dfs_edges += [df_edges] - - df = pandas.concat(dfs, ignore_index=True) - #df_edges = pandas.concat(dfs_edges, ignore_index=True) - return df#, df_edges - -#Get a unique directory name for the model -def get_model_fname(dataset, model, n_train, lr, target_type): - model_name = type(model).__name__ - model_params = sum(p.numel() for p in model.parameters()) - import hashlib - model_cfghash = hashlib.blake2b(repr(model).encode()).hexdigest()[:10] - model_user = os.environ['USER'] - - model_fname = '{}_{}_{}__npar_{}__cfg_{}__user_{}__ntrain_{}__lr_{}__{}'.format( - model_name, - dataset.split("/")[-1], - target_type, - model_params, - model_cfghash, - model_user, - n_train, - lr, int(time.time())) - return model_fname - -#Model with gravnet clustering -class PFNet7(nn.Module): - def __init__(self, - input_dim=3, hidden_dim=32, 
encoding_dim=256, - output_dim_id=len(class_to_id), - output_dim_p4=4, - convlayer="gravnet-radius", - convlayer2="none", - space_dim=2, nearest=3, dropout_rate=0.0, activation="leaky_relu", return_edges=False, radius=0.1, input_encoding=0): - - super(PFNet7, self).__init__() - - self.input_dim = input_dim - self.hidden_dim = hidden_dim - self.return_edges = return_edges - self.convlayer = convlayer - self.input_encoding = input_encoding - - if activation == "leaky_relu": - self.act = nn.LeakyReLU - self.act_f = torch.nn.functional.leaky_relu - elif activation == "selu": - self.act = nn.SELU - self.act_f = torch.nn.functional.selu - elif activation == "relu": - self.act = nn.ReLU - self.act_f = torch.nn.functional.relu - - # if you want to add an initial encoding of the input - conv_in_dim = input_dim - if self.input_encoding>0: - self.nn1 = nn.Sequential( - nn.Linear(input_dim, hidden_dim), - self.act(), - nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), - nn.Linear(hidden_dim, hidden_dim), - self.act(), - nn.Linear(hidden_dim, encoding_dim), - ) - conv_in_dim = encoding_dim - - # (1) GNN layer - if convlayer == "gravnet-knn": - self.conv1 = GravNetConv(conv_in_dim, encoding_dim, space_dim, hidden_dim, nearest, neighbor_algo="knn") - elif convlayer == "gravnet-radius": - self.conv1 = GravNetConv(conv_in_dim, encoding_dim, space_dim, hidden_dim, nearest, neighbor_algo="radius", radius=radius) - else: - raise Exception("Unknown convolution layer: {}".format(convlayer)) - - #decoding layer receives the raw inputs and the gravnet output - num_decode_in = input_dim + encoding_dim - - # (2) another GNN layer if you want - self.convlayer2 = convlayer2 - if convlayer2 == "none": - self.conv2_1 = None - self.conv2_2 = None - elif convlayer2 == "sgconv": - self.conv2_1 = SGConv(num_decode_in, hidden_dim, K=1) - self.conv2_2 = SGConv(num_decode_in, hidden_dim, K=1) - num_decode_in += hidden_dim - elif convlayer2 == "graphunet": - self.conv2_1 = 
GraphUNet(num_decode_in, hidden_dim, hidden_dim, 2, pool_ratios=0.1) - self.conv2_2 = GraphUNet(num_decode_in, hidden_dim, hidden_dim, 2, pool_ratios=0.1) - num_decode_in += hidden_dim - elif convlayer2 == "gatconv": - self.conv2_1 = GATConv(num_decode_in, hidden_dim, 4, concat=False, dropout=dropout_rate) - self.conv2_2 = GATConv(num_decode_in, hidden_dim, 4, concat=False, dropout=dropout_rate) - num_decode_in += hidden_dim - else: - raise Exception("Unknown convolution layer: {}".format(convlayer2)) - - # (3) dropout layer if you want - self.dropout1 = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity() - - # (4) DNN layer: classifying PID - self.nn2 = nn.Sequential( - nn.Linear(num_decode_in, hidden_dim), - self.act(), - nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), - nn.Linear(hidden_dim, hidden_dim), - self.act(), - nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), - nn.Linear(hidden_dim, hidden_dim), - self.act(), - nn.Linear(hidden_dim, output_dim_id), - ) - - # (5) DNN layer: regressing p4 - self.nn3 = nn.Sequential( - nn.Linear(num_decode_in + output_dim_id, hidden_dim), - self.act(), - nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), - nn.Linear(hidden_dim, hidden_dim), - self.act(), - nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(), - nn.Linear(hidden_dim, hidden_dim), - self.act(), - nn.Linear(hidden_dim, output_dim_p4), - ) - - - def forward(self, data): - - #encode the inputs - x = data.x - - if self.input_encoding: - x = self.nn1(x) - - #Run a clustering of the inputs that returns the new_edge_index - new_edge_index, x = self.conv1(x) - x1 = self.act_f(x) - - #run a second convolution - if self.convlayer2 != "none": - conv2_input = torch.cat([data.x, x1], axis=-1) - x2_1 = self.act_f(self.conv2_1(conv2_input, new_edge_index)) - x2_2 = self.act_f(self.conv2_2(conv2_input, new_edge_index)) - nn2_input = torch.cat([data.x, x1, x2_1], axis=-1) - else: - nn2_input = 
torch.cat([data.x, x1], axis=-1) - - #Decode convolved graph nodes to pdgid and p4 - cand_ids = self.nn2(self.dropout1(nn2_input)) - - if self.convlayer2 != "none": - nn3_input = torch.cat([data.x, x1, x2_2, cand_ids], axis=-1) - else: - nn3_input = torch.cat([data.x, x1, cand_ids], axis=-1) - - cand_p4 = data.x[:, len(elem_to_id):len(elem_to_id)+4] + self.nn3(self.dropout1(nn3_input)) - return cand_ids, cand_p4, new_edge_index - -model_classes = { - "PFNet7": PFNet7, -} - -def mse_loss(input, target): - return torch.sum((input - target) ** 2) - -def weighted_mse_loss(input, target, weight): - return torch.sum(weight * (input - target).sum(axis=1) ** 2) - -@torch.no_grad() -def test(model, loader, epoch, l1m, l2m, l3m, target_type): - with torch.no_grad(): - ret = train(model, loader, epoch, None, l1m, l2m, l3m, target_type, None) - return ret - -def compute_weights(target_ids, device): - vs, cs = torch.unique(target_ids, return_counts=True) - weights = torch.zeros(len(class_to_id)).to(device=device) - for k, v in zip(vs, cs): - weights[k] = 1.0/math.sqrt(float(v)) - return weights - - -def train(model, loader, epoch, optimizer, l1m, l2m, l3m, target_type, scheduler): - - is_train = not (optimizer is None) - - if is_train: - model.train() - else: - model.eval() - - #loss values for each batch: classification, regression - losses = np.zeros((len(loader), 3)) - - #accuracy values for each batch (monitor classification performance) - accuracies_batch = np.zeros(len(loader)) - - #correlation values for each batch (monitor regression performance) - corrs_batch = np.zeros(len(loader)) - - #epoch confusion matrix - conf_matrix = np.zeros((len(class_labels), len(class_labels))) - - #keep track of how many data points were processed - num_samples = 0 - for i, data in enumerate(loader): - t0 = time.time() - - if not multi_gpu: - data = data.to(device) - - if is_train: - optimizer.zero_grad() - - cand_id_onehot, cand_momentum, new_edge_index = model(data) - - nelem = 
data.x.shape[0] - assert(len(new_edge_index[0])>0) - mat1 = torch_geometric.utils.to_dense_adj(data.edge_index, max_num_nodes=nelem).flatten() - mat2 = torch_geometric.utils.to_dense_adj(new_edge_index, max_num_nodes=nelem).flatten() - - _dev = cand_id_onehot.device - _, indices = torch.max(cand_id_onehot, -1) - if not multi_gpu: - data = [data] - num_samples += len(cand_id_onehot) - - if args.target == "gen": - target_ids = torch.cat([d.y_gen_id for d in data]).to(_dev) - target_p4 = torch.cat([d.ygen[:, :4] for d in data]).to(_dev) - elif args.target == "cand": - target_ids = torch.cat([d.y_candidates_id for d in data]).to(_dev) - target_p4 = torch.cat([d.ycand[:, :4] for d in data]).to(_dev) - - #Predictions where both the predicted and true class label was nonzero - #In these cases, the true candidate existed and a candidate was predicted - msk = ((indices != 0) & (target_ids != 0)).detach().cpu() - msk2 = ((indices != 0) & (indices == target_ids)) - - accuracies_batch[i] = accuracy_score(target_ids[msk].detach().cpu().numpy(), indices[msk].detach().cpu().numpy()) - - weights = compute_weights(target_ids, _dev) - - #Loss for output candidate id (multiclass) - if l1m > 0.0: - l1 = l1m * torch.nn.functional.cross_entropy(cand_id_onehot, target_ids, weight=weights) - else: - l1 = torch.tensor(0.0).to(device=_dev) - - #Loss for candidate p4 properties (regression) - l2 = torch.tensor(0.0).to(device=_dev) - if l2m > 0.0: - l2 = l2m*torch.nn.functional.mse_loss(cand_momentum[msk2], target_p4[msk2]) - else: - l2 = torch.tensor(0.0).to(device=_dev) - - l3 = l3m*torch.nn.functional.binary_cross_entropy(mat2, mat1) - - batch_loss = l1 + l2 - losses[i, 0] = l1.item() - losses[i, 1] = l2.item() - losses[i, 2] = l3.item() - - if is_train: - batch_loss.backward() - - batch_loss_item = batch_loss.item() - t1 = time.time() - - print('{}/{} batch_loss={:.2f} dt={:.1f}s'.format(i, len(loader), batch_loss_item, t1-t0), end='\r', flush=True) - if is_train: - optimizer.step() - if 
not scheduler is None: - scheduler.step() - - #Compute correlation of predicted and true pt values for monitoring - corr_pt = 0.0 - if msk.sum()>0: - corr_pt = np.corrcoef( - cand_momentum[msk, 0].detach().cpu().numpy(), - target_p4[msk, 0].detach().cpu().numpy())[0,1] - - corrs_batch[i] = corr_pt - - conf_matrix += confusion_matrix(target_ids.detach().cpu().numpy(), - np.argmax(cand_id_onehot.detach().cpu().numpy(),axis=1), labels=range(8)) - - i += 1 - - corr = np.mean(corrs_batch) - acc = np.mean(accuracies_batch) - losses = losses.mean(axis=0) - return num_samples, losses, corr, acc, conf_matrix - -def train_loop(): - t0_initial = time.time() - - losses_train = np.zeros((args.n_epochs, 3)) - losses_val = np.zeros((args.n_epochs, 3)) - - corrs = [] - corrs_v = [] - accuracies = [] - accuracies_v = [] - best_val_loss = 99999.9 - stale_epochs = 0 - - print("Training over {} epochs".format(args.n_epochs)) - for j in range(args.n_epochs): - t0 = time.time() - - if stale_epochs > patience: - print("breaking due to stale epochs") - break - - with experiment.train(): - model.train() - num_samples_train, losses, c, acc, conf_matrix = train(model, train_loader, j, optimizer, - args.l1, args.l2, args.l3, args.target, scheduler) - experiment.log_metric('lr', optimizer.param_groups[0]['lr'], step=j) - l = sum(losses) - losses_train[j] = losses - corrs += [c] - accuracies += [acc] - experiment.log_metric('loss',l, step=j) - experiment.log_metric('loss1',losses[0], step=j) - experiment.log_metric('loss2',losses[1], step=j) - experiment.log_metric('loss3',losses[2], step=j) - experiment.log_metric('corrs',c, step=j) - experiment.log_metric('accuracy',acc, step=j) - experiment.log_confusion_matrix(matrix=conf_matrix, step=j, - title='Confusion Matrix Full', - file_name='confusion-matrix-full-train-%03d.json' % j, - labels = [str(c) for c in class_labels]) - - - with experiment.validate(): - model.eval() - num_samples_val, losses_v, c_v, acc_v, conf_matrix_v = test(model, 
val_loader, j, - args.l1, args.l2, args.l3, args.target) - l_v = sum(losses_v) - losses_val[j] = losses_v - corrs_v += [c_v] - accuracies_v += [acc_v] - experiment.log_metric('loss',l_v, step=j) - experiment.log_metric('loss1',losses_v[0], step=j) - experiment.log_metric('loss2',losses_v[1], step=j) - experiment.log_metric('loss3',losses_v[2], step=j) - experiment.log_metric('corrs',c_v, step=j) - experiment.log_metric('accuracy',acc_v, step=j) - experiment.log_confusion_matrix(matrix=conf_matrix_v, step=j, - title='Confusion Matrix Full', - file_name='confusion-matrix-full-val-%03d.json' % j, - labels = [str(c) for c in class_labels]) - - if l_v < best_val_loss: - best_val_loss = l_v - stale_epochs = 0 - else: - stale_epochs += 1 - - t1 = time.time() - epochs_remaining = args.n_epochs - j - time_per_epoch = (t1 - t0_initial)/(j + 1) - experiment.log_metric('time_per_epoch', time_per_epoch, step=j) - eta = epochs_remaining*time_per_epoch/60 - - spd = (num_samples_val+num_samples_train)/time_per_epoch - losses_str = "[" + ",".join(["{:.4f}".format(x) for x in losses_v]) + "]" - - torch.save(model.state_dict(), "{0}/epoch_{1}_weights.pth".format(outpath, j)) - - print("epoch={}/{} dt={:.2f}s l={:.5f}/{:.5f} c={:.2f}/{:.2f} a={:.6f}/{:.6f} partial_losses={} stale={} eta={:.1f}m spd={:.2f} samples/s lr={}".format( - j, args.n_epochs, - t1 - t0, l, l_v, c, c_v, acc, acc_v, - losses_str, stale_epochs, eta, spd, optimizer.param_groups[0]['lr'])) - -def parse_args(): - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--n_train", type=int, default=80, help="number of training events") - parser.add_argument("--n_val", type=int, default=20, help="number of validation events") - parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs") - parser.add_argument("--patience", type=int, default=100, help="patience before early stopping") - parser.add_argument("--hidden_dim", type=int, default=64, help="hidden dimension") - 
parser.add_argument("--encoding_dim", type=int, default=256, help="encoded element dimension") - parser.add_argument("--batch_size", type=int, default=1, help="Number of .pt files to load in parallel") - parser.add_argument("--model", type=str, choices=sorted(model_classes.keys()), help="type of model to use", default="PFNet6") - parser.add_argument("--target", type=str, choices=["cand", "gen"], help="Regress to PFCandidates or GenParticles", default="cand") - parser.add_argument("--dataset", type=str, help="Input dataset", required=True) - parser.add_argument("--outpath", type=str, default = 'data/', help="Output folder") - parser.add_argument("--activation", type=str, default='leaky_relu', choices=["selu", "leaky_relu", "relu"], help="activation function") - parser.add_argument("--optimizer", type=str, default='adam', choices=["adam", "adamw"], help="optimizer to use") - parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") - parser.add_argument("--l1", type=float, default=1.0, help="Loss multiplier for pdg-id classification") - parser.add_argument("--l2", type=float, default=1.0, help="Loss multiplier for momentum regression") - parser.add_argument("--l3", type=float, default=1.0, help="Loss multiplier for clustering") - parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate") - parser.add_argument("--radius", type=float, default=0.1, help="Radius-graph radius") - parser.add_argument("--convlayer", type=str, choices=["gravnet-knn", "gravnet-radius", "sgconv", "gatconv"], help="Convolutional layer", default="gravnet") - parser.add_argument("--convlayer2", type=str, choices=["sgconv", "graphunet", "gatconv", "none"], help="Convolutional layer", default="none") - parser.add_argument("--space_dim", type=int, default=2, help="Spatial dimension for clustering in gravnet layer") - parser.add_argument("--nearest", type=int, default=3, help="k nearest neighbors in gravnet layer") - parser.add_argument("--overwrite", 
action='store_true', help="overwrite if model output exists") - parser.add_argument("--disable_comet", action='store_true', help="disable comet-ml") - parser.add_argument("--input_encoding", type=int, help="use an input encoding layer", default=0) - parser.add_argument("--load", type=str, help="Load the weight file", required=False, default=None) - parser.add_argument("--scheduler", type=str, help="LR scheduler", required=False, default="none", choices=["none", "onecycle"]) - args = parser.parse_args() - return args - -if __name__ == "__main__": - - args = parse_args() - - # # the next part initializes some args values (to run the script not from terminal) - # class objectview(object): - # def __init__(self, d): - # self.__dict__ = d - # - # args = objectview({'n_train': 3, 'n_val': 2, 'n_epochs': 2, 'patience': 100, 'hidden_dim':64, 'encoding_dim': 256, - # 'batch_size': 1, 'model': 'PFNet7', 'target': 'cand', 'dataset': '../../test_tmp/data/TTbar_14TeV_TuneCUETP8M1_cfi', - # 'outpath': 'data/', 'activation': 'leaky_relu', 'optimizer': 'adam', 'lr': 1e-4, 'l1': 1, 'l2': 1, 'l3': 1, 'dropout': 0.5, - # 'radius': 0.1, 'convlayer': 'gravnet-radius', 'convlayer2': 'none', 'space_dim': 2, 'nearest': 3, 'overwrite': True, - # 'disable_comet': True, 'input_encoding': 0, 'load': None, 'scheduler': 'none'}) - - # define the dataset - full_dataset = PFGraphDataset(args.dataset) - - #one-hot encoded element ID + element parameters - input_dim = 26 - - #one-hot particle ID and momentum - output_dim_id = len(class_to_id) - output_dim_p4 = 4 - - edge_dim = 1 - - patience = args.patience - - train_dataset = torch.utils.data.Subset(full_dataset, np.arange(start=0, stop=args.n_train)) - val_dataset = torch.utils.data.Subset(full_dataset, np.arange(start=args.n_train, stop=args.n_train+args.n_val)) - print("train_dataset", len(train_dataset)) - print("val_dataset", len(val_dataset)) - - #hack for multi-gpu training - if not multi_gpu: - def collate(items): - l = sum(items, []) - 
return Batch.from_data_list(l) - else: - def collate(items): - l = sum(items, []) - return l - - train_loader = DataListLoader(train_dataset, batch_size=args.batch_size, pin_memory=False, shuffle=False) - train_loader.collate_fn = collate - val_loader = DataListLoader(val_dataset, batch_size=args.batch_size, pin_memory=False, shuffle=False) - val_loader.collate_fn = collate - - model_class = model_classes[args.model] - model_kwargs = {'input_dim': input_dim, - 'hidden_dim': args.hidden_dim, - 'encoding_dim': args.encoding_dim, - 'output_dim_id': output_dim_id, - 'output_dim_p4': output_dim_p4, - 'dropout_rate': args.dropout, - 'convlayer': args.convlayer, - 'convlayer2': args.convlayer2, - 'radius': args.radius, - 'space_dim': args.space_dim, - 'activation': args.activation, - 'nearest': args.nearest, - 'input_encoding': args.input_encoding} - - - #instantiate the model - model = model_class(**model_kwargs) - if args.load: - s1 = torch.load(args.load, map_location=torch.device('cpu')) - s2 = {k.replace("module.", ""): v for k, v in s1.items()} - model.load_state_dict(s2) - - if multi_gpu: - model = torch_geometric.nn.DataParallel(model) - - model.to(device) - - model_fname = get_model_fname(args.dataset, model, args.n_train, args.lr, args.target) - - # need your api key in a .comet.config file: see https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables - experiment = Experiment(project_name="particleflow", disabled=args.disable_comet) - experiment.set_model_graph(repr(model)) - experiment.log_parameters(dict(model_kwargs, **{'model': args.model, 'lr':args.lr, 'model_fname': model_fname, - 'l1': args.l1, 'l2':args.l2, - 'n_train':args.n_train, 'target':args.target, 'optimizer': args.optimizer})) - outpath = osp.join(args.outpath, model_fname) - if osp.isdir(outpath): - if args.overwrite: - print("model output {} already exists, deleting it".format(outpath)) - import shutil - shutil.rmtree(outpath) - else: - print("model output {} already 
exists, please delete it".format(outpath)) - sys.exit(0) - try: - os.makedirs(outpath) - except Exception as e: - pass - - with open('{}/model_kwargs.pkl'.format(outpath), 'wb') as f: - pickle.dump(model_kwargs, f, protocol=pickle.HIGHEST_PROTOCOL) - - if args.optimizer == "adam": - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - elif args.optimizer == "adamw": - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) - - scheduler = None - if args.scheduler == "onecycle": - scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer, - max_lr=args.lr, - steps_per_epoch=int(len(train_loader)), - epochs=args.n_epochs + 1, - anneal_strategy='linear', - ) - - loss = torch.nn.MSELoss() - loss2 = torch.nn.BCELoss() - - print(model) - print(model_fname) - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) - params = sum([np.prod(p.size()) for p in model_parameters]) - print("params", params) - - model.train() - - train_loop() - # with torch.autograd.profiler.profile(use_cuda=True) as prof: - # train_loop() - - # print(prof.key_averages().table(sort_by="cuda_time_total")) diff --git a/mlpf/pytorch_delphes/README.md b/mlpf/pytorch_delphes/README.md deleted file mode 100644 index 3a25d6f4b..000000000 --- a/mlpf/pytorch_delphes/README.md +++ /dev/null @@ -1,23 +0,0 @@ -Short instructions to do a quick training on delphes data: -```bash -cd ../.. -./scripts/local_test_delphes_pytorch.sh -``` - -### Delphes dataset -The dataset is available from zenodo: https://doi.org/10.5281/zenodo.4452283. - -Instructions to download and process the full Delphes dataset: -```bash -cd ../../scripts/ -./get_all_data_delphes.sh -``` - -This script will download and process the data under a directory called "test_tmp_delphes/" in particleflow. 
There are will be two subdirectories under test_tmp_delphes/ (1) data/: which contains the data (2) experiments/: which will contain any trained model - - -Instructions to explain using LRP (you must have an already trained model in test_tmp_delphes/experiments): -```bash -cd LRP/ -python -u main_reg.py --LRP_load_model= --LRP_load_epoch= -``` diff --git a/mlpf/pytorch_delphes/__init__.py b/mlpf/pytorch_delphes/__init__.py deleted file mode 100644 index efc573206..000000000 --- a/mlpf/pytorch_delphes/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from pytorch_delphes.args import parse_args -from pytorch_delphes.graph_data_delphes import PFGraphDataset, one_hot_embedding -from pytorch_delphes.utils import dataloader_ttbar, dataloader_qcd -from pytorch_delphes.utils import get_model_fname, save_model, load_model -from pytorch_delphes.utils import make_plot, make_directories_for_plots - -from pytorch_delphes.model import MLPF - -from pytorch_delphes.training import training_loop -from pytorch_delphes.evaluate import make_predictions, make_plots diff --git a/mlpf/pytorch_delphes/model.py b/mlpf/pytorch_delphes/model.py deleted file mode 100644 index 3146eb722..000000000 --- a/mlpf/pytorch_delphes/model.py +++ /dev/null @@ -1,225 +0,0 @@ -import pickle as pkl -import os.path as osp -import os -import sys -from glob import glob - -import torch -from torch import Tensor -import torch.nn as nn -from torch.nn import Linear -from torch_scatter import scatter -from torch_geometric.nn.conv import MessagePassing -from torch_geometric.utils import to_dense_adj -import torch.nn.functional as F - -from typing import Optional, Union -from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor - -try: - from torch_cluster import knn -except ImportError: - knn = None -from torch_cluster import knn_graph - - -class MLPF(nn.Module): - """ - GNN model based on Gravnet... 
- Forward pass returns - preds: tensor of predictions containing a concatenated representation of the pids and p4 - target: dict() object containing gen and cand target information - """ - - def __init__(self, - input_dim=12, output_dim_id=6, output_dim_p4=6, - embedding_dim=64, hidden_dim1=64, hidden_dim2=60, - num_convs=2, space_dim=4, propagate_dim=30, k=8): - super(MLPF, self).__init__() - - # self.act = nn.ReLU - self.act = nn.ELU - - # (1) embedding - self.nn1 = nn.Sequential( - nn.Linear(input_dim, hidden_dim1), - self.act(), - nn.Linear(hidden_dim1, hidden_dim1), - self.act(), - nn.Linear(hidden_dim1, hidden_dim1), - self.act(), - nn.Linear(hidden_dim1, embedding_dim), - ) - - self.conv = nn.ModuleList() - for i in range(num_convs): - self.conv.append(GravNetConv_LRP(embedding_dim, embedding_dim, space_dim, propagate_dim, k)) - - # (3) DNN layer: classifiying pid - self.nn2 = nn.Sequential( - nn.Linear(input_dim + embedding_dim, hidden_dim2), - self.act(), - nn.Linear(hidden_dim2, hidden_dim2), - self.act(), - nn.Linear(hidden_dim2, hidden_dim2), - self.act(), - nn.Linear(hidden_dim2, output_dim_id), - ) - - # (4) DNN layer: regressing p4 - self.nn3 = nn.Sequential( - nn.Linear(input_dim + output_dim_id, hidden_dim2), - self.act(), - nn.Linear(hidden_dim2, hidden_dim2), - self.act(), - nn.Linear(hidden_dim2, hidden_dim2), - self.act(), - nn.Linear(hidden_dim2, output_dim_p4), - ) - - def forward(self, batch): - - # unfold the Batch object - input = batch.x - target = {'ygen_id': batch.ygen_id, - 'ygen': batch.ygen, - 'ycand_id': batch.ycand_id, - 'ycand': batch.ycand, - } - - # embed the inputs - embedding = self.nn1(input) - - # preform a series of graph convolutions - A = {} - msg_activations = {} - for num, conv in enumerate(self.conv): - embedding = conv(embedding) - - # predict the pid's - preds_id = self.nn2(torch.cat([input, embedding], axis=-1)) - - # predict the p4's - preds_p4 = self.nn3(torch.cat([input, preds_id], axis=-1)) - - return 
torch.cat([preds_id, preds_p4], axis=-1), target - - -class GravNetConv_LRP(MessagePassing): - """ - Copied from pytorch_geometric source code, with the following edits - - used reduce='sum' instead of reduce='mean' in the message passing - - removed skip connection - - The GravNet operator from the `"Learning Representations of Irregular - Particle-detector Geometry with Distance-weighted Graph - Networks" `_ paper, where the graph is - dynamically constructed using nearest neighbors. - The neighbors are constructed in a learnable low-dimensional projection of - the feature space. - A second projection of the input feature space is then propagated from the - neighbors to each vertex using distance weights that are derived by - applying a Gaussian function to the distances. - Args: - in_channels (int): Size of each input sample, or :obj:`-1` to derive - the size from the first input(s) to the forward method. - out_channels (int): The number of output channels. - space_dimensions (int): The dimensionality of the space used to - construct the neighbors; referred to as :math:`S` in the paper. - propagate_dimensions (int): The number of features to be propagated - between the vertices; referred to as :math:`F_{\textrm{LR}}` in the - paper. - k (int): The number of nearest neighbors. - num_workers (int): Number of workers to use for k-NN computation. - Has no effect in case :obj:`batch` is not :obj:`None`, or the input - lies on the GPU. (default: :obj:`1`) - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. 
- Shapes: - - **input:** - node features :math:`(|\mathcal{V}|, F_{in})` or - :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` - if bipartite, - batch vector :math:`(|\mathcal{V}|)` or - :math:`((|\mathcal{V}_s|), (|\mathcal{V}_t|))` if bipartite - *(optional)* - - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or - :math:`(|\mathcal{V}_t|, F_{out})` if bipartite - """ - - def __init__(self, in_channels: int, out_channels: int, - space_dimensions: int, propagate_dimensions: int, k: int, - num_workers: int = 1, **kwargs): - super().__init__(flow='source_to_target', **kwargs) - - if knn is None: - raise ImportError('`GravNetConv` requires `torch-cluster`.') - - self.in_channels = in_channels - self.out_channels = out_channels - self.k = k - self.num_workers = num_workers - - self.lin_p = Linear(in_channels, propagate_dimensions) - self.lin_s = Linear(in_channels, space_dimensions) - self.lin_out = Linear(propagate_dimensions, out_channels) - - self.reset_parameters() - - def reset_parameters(self): - self.lin_s.reset_parameters() - self.lin_p.reset_parameters() - self.lin_out.reset_parameters() - - def forward( - self, x: Union[Tensor, PairTensor], - batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor: - """""" - - is_bipartite: bool = True - if isinstance(x, Tensor): - x: PairTensor = (x, x) - is_bipartite = False - - if x[0].dim() != 2: - raise ValueError("Static graphs not supported in 'GravNetConv'") - - b: PairOptTensor = (None, None) - if isinstance(batch, Tensor): - b = (batch, batch) - elif isinstance(batch, tuple): - assert batch is not None - b = (batch[0], batch[1]) - - # embed the inputs before message passing - msg_activations = self.lin_p(x[0]) - - # transform to the space dimension to build the graph - s_l: Tensor = self.lin_s(x[0]) - s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l - - edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0]) - # edge_index = knn_graph(s_l, self.k, b[0], b[1]).flip([0]) - - 
edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1) - edge_weight = torch.exp(-10. * edge_weight) # 10 gives a better spread - - # message passing - out = self.propagate(edge_index, x=(msg_activations, None), - edge_weight=edge_weight, - size=(s_l.size(0), s_r.size(0))) - - return self.lin_out(out) - - def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: - return x_j * edge_weight.unsqueeze(1) - - def aggregate(self, inputs: Tensor, index: Tensor, - dim_size: Optional[int] = None) -> Tensor: - out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, - reduce='sum') - return out_mean - - def __repr__(self) -> str: - return (f'{self.__class__.__name__}({self.in_channels}, ' - f'{self.out_channels}, k={self.k})') diff --git a/mlpf/pytorch_delphes/training.py b/mlpf/pytorch_delphes/training.py deleted file mode 100644 index 8ee8a2c44..000000000 --- a/mlpf/pytorch_delphes/training.py +++ /dev/null @@ -1,220 +0,0 @@ -from pytorch_delphes import make_plot -from pytorch_delphes.utils_plots import plot_confusion_matrix - -import torch -import mplhep as hep -import matplotlib.pyplot as plt -import os -import pickle as pkl -import math -import time -import tqdm -import numpy as np -import pandas as pd -import sklearn -import sklearn.metrics -import matplotlib -matplotlib.use("Agg") - - -# Ignore divide by 0 errors -np.seterr(divide='ignore', invalid='ignore') - - -def compute_weights(device, target_ids, output_dim_id): - """ - computes necessary weights to accomodate class imbalance in the loss function - """ - vs, cs = torch.unique(target_ids, return_counts=True) - weights = torch.zeros(output_dim_id).to(device=device) - for k, v in zip(vs, cs): - weights[k] = 1.0 / math.sqrt(float(v)) - return weights - - -@torch.no_grad() -def validation_run(device, model, multi_gpu, loader, epoch, alpha, target_type, output_dim_id, outpath): - with torch.no_grad(): - optimizer = None - ret = train(device, model, multi_gpu, loader, epoch, 
optimizer, alpha, target_type, output_dim_id, outpath) - return ret - - -def train(device, model, multi_gpu, loader, epoch, optimizer, alpha, target_type, output_dim_id, outpath): - """ - a training block over a given epoch... - if optimizer is set to None, it freezes the model for a validation_run - """ - - is_train = not (optimizer is None) - - if is_train: - model.train() - else: - model.eval() - - # initialize placeholders for loss and accuracy - losses_clf, losses_reg, losses_tot, accuracies = [], [], [], [] - - # setup confusion matrix - conf_matrix = np.zeros((output_dim_id, output_dim_id)) - - # to compute average inference time - t = [] - - for i, batch in enumerate(loader): - - if multi_gpu: - X = batch # a list (not torch) instance so can't be passed to device - else: - X = batch.to(device) - - # run forward pass - t0 = time.time() - pred, target = model(X) - t1 = time.time() - t.append(t1 - t0) - - pred_ids_one_hot = pred[:, :6] - pred_p4 = pred[:, 6:] - - # define target - if target_type == 'gen': - target_ids_one_hot = target['ygen_id'] - target_p4 = target['ygen'] - elif target_type == 'cand': - target_ids_one_hot = target['ycand_id'] - target_p4 = target['ycand'] - - # revert one hot encoding - _, target_ids = torch.max(target_ids_one_hot, -1) - _, pred_ids = torch.max(pred_ids_one_hot, -1) - - # define some useful masks - msk = ((pred_ids != 0) & (target_ids != 0)) - msk2 = ((pred_ids != 0) & (pred_ids == target_ids)) - - # computing loss - weights = compute_weights(device, target_ids, output_dim_id) # to accomodate class imbalance - loss_clf = torch.nn.functional.cross_entropy(pred_ids_one_hot, target_ids, weight=weights) # for classifying PID - loss_reg = torch.nn.functional.mse_loss(pred_p4[msk2], target_p4[msk2]) # for regressing p4 - - loss_tot = loss_clf + alpha * loss_reg - - if is_train: - optimizer.zero_grad() - loss_tot.backward() - optimizer.step() - - losses_clf.append(loss_clf.detach().cpu().item()) - 
losses_reg.append(loss_reg.detach().cpu().item()) - losses_tot.append(loss_tot.detach().cpu().item()) - - accuracies.append(sklearn.metrics.accuracy_score(target_ids[msk].detach().cpu().numpy(), pred_ids[msk].detach().cpu().numpy())) - - conf_matrix += sklearn.metrics.confusion_matrix(target_ids.detach().cpu().numpy(), - pred_ids.detach().cpu().numpy(), - labels=range(6)) - - losses_clf = np.mean(losses_clf) - losses_reg = np.mean(losses_reg) - losses_tot = np.mean(losses_tot) - - accuracies = np.mean(accuracies) - - conf_matrix_norm = conf_matrix / conf_matrix.sum(axis=1)[:, np.newaxis] - - avg_inference_time = sum(t) / len(t) - print(f'Average inference time per event is {round(avg_inference_time,3)}s') - - return losses_clf, losses_reg, losses_tot, accuracies, conf_matrix_norm - - -def training_loop(device, model, multi_gpu, train_loader, valid_loader, n_epochs, patience, optimizer, alpha, target, output_dim_id, outpath): - """ - Main function for training a model - """ - - t0_initial = time.time() - - losses_clf_train, losses_reg_train, losses_tot_train = [], [], [] - losses_clf_valid, losses_reg_valid, losses_tot_valid = [], [], [] - - accuracies_train, accuracies_valid = [], [] - - best_val_loss = 99999.9 - stale_epochs = 0 - - print("Training over {} epochs".format(n_epochs)) - for epoch in range(n_epochs): - t0 = time.time() - - if stale_epochs > patience: - print("breaking due to stale epochs") - break - - # training step - model.train() - losses_clf, losses_reg, losses_tot, accuracies, conf_matrix_train = train(device, model, multi_gpu, train_loader, epoch, optimizer, alpha, target, output_dim_id, outpath) - - losses_clf_train.append(losses_clf) - losses_reg_train.append(losses_reg) - losses_tot_train.append(losses_tot) - - accuracies_train.append(accuracies) - - # validation step - model.eval() - losses_clf, losses_reg, losses_tot, accuracies, conf_matrix_val = validation_run(device, model, multi_gpu, valid_loader, epoch, alpha, target, output_dim_id, 
outpath) - - losses_clf_valid.append(losses_clf) - losses_reg_valid.append(losses_reg) - losses_tot_valid.append(losses_tot) - - accuracies_valid.append(accuracies) - - # early-stopping - if losses_tot < best_val_loss: - best_val_loss = losses_tot - stale_epochs = 0 - else: - stale_epochs += 1 - - t1 = time.time() - - epochs_remaining = n_epochs - (epoch + 1) - time_per_epoch = (t1 - t0_initial) / (epoch + 1) - eta = epochs_remaining * time_per_epoch / 60 - - print(f"epoch={epoch + 1} / {n_epochs} train_loss={round(losses_tot_train[epoch], 4)} valid_loss={round(losses_tot_valid[epoch], 4)} train_acc={round(accuracies_train[epoch], 4)} valid_acc={round(accuracies_valid[epoch], 4)} stale={stale_epochs} eta={round(eta, 1)}m") - - # save the model's weights - torch.save(model.state_dict(), f'{outpath}/epoch_{epoch}_weights.pth') - - # create directory to hold training plots - if not os.path.exists(outpath + '/training_plots/'): - os.makedirs(outpath + '/training_plots/') - - # make confusion matrix plots - cm_path = outpath + '/training_plots/confusion_matrix_plots/' - if not os.path.exists(cm_path): - os.makedirs(cm_path) - target_names = ["none", "ch.had", "n.had", "g", "el", "mu"] - plot_confusion_matrix(conf_matrix_train, target_names, epoch, cm_path, f'cmT_epoch_{str(epoch)}') - plot_confusion_matrix(conf_matrix_val, target_names, epoch, cm_path, f'cmV_epoch_{str(epoch)}') - - # make loss plots - make_plot(losses_clf_train, 'train loss_clf', 'Epochs', 'Loss', outpath + '/training_plots/losses/', 'losses_clf_train') - make_plot(losses_reg_train, 'train loss_reg', 'Epochs', 'Loss', outpath + '/training_plots/losses/', 'losses_reg_train') - make_plot(losses_tot_train, 'train loss_tot', 'Epochs', 'Loss', outpath + '/training_plots/losses/', 'losses_tot_train') - - make_plot(losses_clf_valid, 'valid loss_clf', 'Epochs', 'Loss', outpath + '/training_plots/losses/', 'losses_clf_valid') - make_plot(losses_reg_valid, 'valid loss_reg', 'Epochs', 'Loss', outpath + 
'/training_plots/losses/', 'losses_reg_valid') - make_plot(losses_tot_valid, 'valid loss_tot', 'Epochs', 'Loss', outpath + '/training_plots/losses/', 'losses_tot_valid') - - # make accuracy plots - make_plot(accuracies_train, 'train accuracy', 'Epochs', 'Accuracy', outpath + '/training_plots/accuracies/', 'accuracies_train') - make_plot(accuracies_valid, 'valid accuracy', 'Epochs', 'Accuracy', outpath + '/training_plots/accuracies/', 'accuracies_valid') - - print('Done with training.') - return diff --git a/mlpf/pytorch_delphes/utils.py b/mlpf/pytorch_delphes/utils.py deleted file mode 100644 index a09a2affa..000000000 --- a/mlpf/pytorch_delphes/utils.py +++ /dev/null @@ -1,185 +0,0 @@ -import json -import shutil -import os.path as osp -import sys -from glob import glob -import torch_geometric -from torch_geometric.loader import DataLoader, DataListLoader -from torch_geometric.data import Data, Batch -import torch -import mplhep as hep -import matplotlib.pyplot as plt -import os -import pickle as pkl -import math -import time -import tqdm -import numpy as np -import pandas as pd -import sklearn -import matplotlib -matplotlib.use("Agg") - - -def dataloader_ttbar(full_dataset, multi_gpu, n_train, n_valid, batch_size): - """ - Builds training and validation dataloaders from a physics dataset for conveninet ML training - - Args: - full_dataset: a delphes dataset that is a list of lists that contain Data() objects - multi_gpu: boolean for multigpu batching - n_train: number of files to use for training - n_valid: number of files to use for validation - - Returns: - train_loader: a pyg iterable DataLoader() that contains Batch objects for training - valid_loader: a pyg iterable DataLoader() that contains Batch objects for validation - """ - - train_dataset = torch.utils.data.Subset(full_dataset, np.arange(start=0, stop=n_train)) - valid_dataset = torch.utils.data.Subset(full_dataset, np.arange(start=n_train, stop=n_train + n_valid)) - - # preprocessing the train_dataset 
in a good format for passing correct batches of events to the GNN - train_data = [] - for i in range(len(train_dataset)): - train_data = train_data + train_dataset[i] - - # preprocessing the valid_dataset in a good format for passing correct batches of events to the GNN - valid_data = [] - for i in range(len(valid_dataset)): - valid_data = valid_data + valid_dataset[i] - - if not multi_gpu: - train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) - valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=True) - else: - train_loader = DataListLoader(train_data, batch_size=batch_size, shuffle=True) - valid_loader = DataListLoader(valid_data, batch_size=batch_size, shuffle=True) - - return train_loader, valid_loader - - -def dataloader_qcd(full_dataset, multi_gpu, n_test, batch_size): - """ - Builds a testing dataloader from a physics dataset for conveninet ML training - - Args: - full_dataset: a delphes dataset that is a list of lists that contain Data() objects - multi_gpu: boolean for multigpu batching - n_test: number of files to use for testing - - Returns: - test_loader: a pyg iterable DataLoader() that contains Batch objects for testing - """ - - test_dataset = torch.utils.data.Subset(full_dataset, np.arange(start=0, stop=n_test)) - - # preprocessing the test_dataset in a good format for passing correct batches of events to the GNN - test_data = [] - for i in range(len(test_dataset)): - test_data = test_data + test_dataset[i] - - if not multi_gpu: - test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True) - else: - test_loader = DataListLoader(test_data, batch_size=batch_size, shuffle=True) - - return test_loader - - -def get_model_fname(dataset, model, n_train, n_epochs, target, alpha, title): - """ - Get a unique directory name for the model - """ - if alpha == 0: - task = "clf" - else: - task = "clf_reg" - - model_name = type(model).__name__ - model_fname = '{}_{}_ntrain_{}_nepochs_{}_{}'.format( - model_name, 
- target, - n_train, - n_epochs, - task) - - if title: - model_fname = model_fname + '_' + title - - return model_fname - - -def save_model(args, model_fname, outpath, model_kwargs): - if osp.isdir(outpath): - print(args.load_model) - if args.overwrite: - print("model {} already exists, deleting it".format(model_fname)) - shutil.rmtree(outpath) - else: - print("model {} already exists, please delete it".format(model_fname)) - sys.exit(0) - os.makedirs(outpath) - - with open(f'{outpath}/model_kwargs.pkl', 'wb') as f: # dump model architecture - pkl.dump(model_kwargs, f, protocol=pkl.HIGHEST_PROTOCOL) - - with open(f'{outpath}/hyperparameters.json', 'w') as fp: # dump hyperparameters - json.dump({'lr': args.lr, 'batch_size': args.batch_size, 'alpha': args.alpha, 'nearest': args.nearest}, fp) - - -def load_model(device, outpath, model_directory, load_epoch): - PATH = outpath + '/epoch_' + str(load_epoch) + '_weights.pth' - - print('Loading a previously trained model..') - with open(outpath + '/model_kwargs.pkl', 'rb') as f: - model_kwargs = pkl.load(f) - - state_dict = torch.load(PATH, map_location=device) - - if "DataParallel" in model_directory: # if the model was trained using DataParallel then we do this - state_dict = torch.load(PATH, map_location=device) - from collections import OrderedDict - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] # remove module. 
- new_state_dict[name] = v - state_dict = new_state_dict - - return state_dict, model_kwargs, outpath - - -def make_plot(X, label, xlabel, ylabel, outpath, save_as): - """ - Given a list X, makes a scatter plot of it and saves it - """ - plt.style.use(hep.style.ROOT) - - if not os.path.exists(outpath): - os.makedirs(outpath) - - fig, ax = plt.subplots() - ax.plot(range(len(X)), X, label=label) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.legend(loc='best') - plt.savefig(outpath + save_as + '.png') - plt.close(fig) - - with open(outpath + save_as + '.pkl', 'wb') as f: - pkl.dump(X, f) - - -def make_directories_for_plots(outpath, tag): - if not osp.isdir(f'{outpath}/{tag}_plots'): - os.makedirs(f'{outpath}/{tag}_plots') - if not osp.isdir(f'{outpath}/{tag}_plots/resolution_plots'): - os.makedirs(f'{outpath}/{tag}_plots/resolution_plots') - if not osp.isdir(f'{outpath}/{tag}_plots/distribution_plots'): - os.makedirs(f'{outpath}/{tag}_plots/distribution_plots') - if not osp.isdir(f'{outpath}/{tag}_plots/multiplicity_plots'): - os.makedirs(f'{outpath}/{tag}_plots/multiplicity_plots') - if not osp.isdir(f'{outpath}/{tag}_plots/efficiency_plots'): - os.makedirs(f'{outpath}/{tag}_plots/efficiency_plots') - if not osp.isdir(f'{outpath}/{tag}_plots/confusion_matrix_plots'): - os.makedirs(f'{outpath}/{tag}_plots/confusion_matrix_plots') diff --git a/mlpf/pytorch_pipeline.py b/mlpf/pytorch_pipeline.py deleted file mode 100644 index e1c39f5fc..000000000 --- a/mlpf/pytorch_pipeline.py +++ /dev/null @@ -1,149 +0,0 @@ -from pytorch_delphes import parse_args -from pytorch_delphes import PFGraphDataset, dataloader_ttbar, dataloader_qcd -from pytorch_delphes import MLPF, training_loop, make_predictions, make_plots -from pytorch_delphes import make_directories_for_plots -from pytorch_delphes import get_model_fname, save_model, load_model - -import torch -import torch_geometric - -import mplhep as hep -import matplotlib.pyplot as plt -from glob import glob -import sys -import 
os -import os.path as osp -import shutil -import pickle as pkl -import json -import math -import time -import tqdm -import numpy as np -import pandas as pd -import sklearn -import matplotlib - -matplotlib.use("Agg") - -# Ignore divide by 0 errors -np.seterr(divide='ignore', invalid='ignore') - -# Check if the GPU configuration has been provided -use_gpu = torch.cuda.device_count() > 0 -multi_gpu = torch.cuda.device_count() > 1 - -if multi_gpu or use_gpu: - print(f'Will use {torch.cuda.device_count()} gpu(s)') -else: - print('Will use cpu') - -# define the global base device -if use_gpu: - device = torch.device('cuda:0') - print("GPU model:", torch.cuda.get_device_name(0)) -else: - device = torch.device('cpu') - - -if __name__ == "__main__": - - """ - e.g. to train locally run as: - python -u pytorch_pipeline.py --title='' --overwrite --target='gen' --n_epochs=1 --n_train=1 --n_valid=1 --n_test=1 --batch_size=1 - - e.g. to load and evaluate run as: - python -u pytorch_pipeline.py --load --load_model='MLPF_gen_ntrain_1_nepochs_1_clf_reg' --load_epoch=0 --target='gen' --n_test=1 --batch_size=2 - """ - - args = parse_args() - - # load the dataset (assumes the data exists as .pt files under args.dataset/processed) - print('Loading the data..') - full_dataset_ttbar = PFGraphDataset(args.dataset) - full_dataset_qcd = PFGraphDataset(args.dataset_qcd) - - # construct Dataloaders to facilitate looping over batches - print('Building dataloaders..') - train_loader, valid_loader = dataloader_ttbar(full_dataset_ttbar, multi_gpu, args.n_train, args.n_valid, batch_size=args.batch_size) - test_loader = dataloader_qcd(full_dataset_qcd, multi_gpu, args.n_test, batch_size=args.batch_size) - - # PF-elements - input_dim = 12 - - # PF-candidates - output_dim_id = 6 - output_dim_p4 = 6 - - if args.load: - outpath = args.outpath + args.load_model - state_dict, model_kwargs, outpath = load_model(device, outpath, args.load_model, args.load_epoch) - - model = MLPF(**model_kwargs) - 
model.load_state_dict(state_dict) - - if multi_gpu: - model = torch_geometric.nn.DataParallel(model) - - model.to(device) - - else: - print('Instantiating a model..') - model_kwargs = {'input_dim': input_dim, - 'output_dim_id': output_dim_id, - 'output_dim_p4': output_dim_p4, - 'embedding_dim': args.embedding_dim, - 'hidden_dim1': args.hidden_dim1, - 'hidden_dim2': args.hidden_dim2, - 'num_convs': args.num_convs, - 'space_dim': args.space_dim, - 'propagate_dim': args.propagate_dim, - 'k': args.nearest, - } - - model = MLPF(**model_kwargs) - - # get a directory name for the model to store the model's weights and plots - model_fname = get_model_fname(args.dataset, model, args.n_train, args.n_epochs, args.target, args.alpha, args.title) - outpath = osp.join(args.outpath, model_fname) - - if multi_gpu: - print("Parallelizing the training..") - model = torch_geometric.nn.DataParallel(model) - - model.to(device) - - save_model(args, model_fname, outpath, model_kwargs) - - print(model) - print(model_fname) - - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - - model.train() - training_loop(device, model, multi_gpu, - train_loader, valid_loader, - args.n_epochs, args.patience, - optimizer, args.alpha, args.target, - output_dim_id, outpath) - - model.eval() - - # evaluate on testing data.. - make_directories_for_plots(outpath, 'test_data') - if args.load: - make_predictions(model, multi_gpu, test_loader, outpath + '/test_data_plots/', device, args.load_epoch) - make_plots(model, test_loader, outpath + '/test_data_plots/', args.target, device, args.load_epoch, 'QCD') - else: - make_predictions(model, multi_gpu, test_loader, outpath + '/test_data_plots/', device, args.n_epochs) - make_plots(model, test_loader, outpath + '/test_data_plots/', args.target, device, args.n_epochs, 'QCD') - - # # evaluate on training data.. 
- # make_directories_for_plots(outpath, 'train_data') - # make_predictions(model, multi_gpu, train_loader, outpath + '/train_data_plots', args.target, device, args.n_epochs) - # make_plots(model, train_loader, outpath + '/train_data_plots', args.target, device, args.n_epochs, 'TTbar') - # - # # evaluate on validation data.. - # make_directories_for_plots(outpath, 'valid_data') - # make_predictions(model, multi_gpu, valid_loader, outpath + '/valid_data_plots', args.target, device, args.n_epochs) - # make_plots(model, valid_loader, outpath + '/valid_data_plots', args.target, device, args.n_epochs, 'TTbar') diff --git a/notebooks/cms-mlpf.ipynb b/notebooks/cms-mlpf.ipynb index 9b7481166..52410ed76 100644 --- a/notebooks/cms-mlpf.ipynb +++ b/notebooks/cms-mlpf.ipynb @@ -3,7 +3,6 @@ { "cell_type": "code", "execution_count": null, - "id": "b79c90e5", "metadata": {}, "outputs": [], "source": [ @@ -13,7 +12,6 @@ { "cell_type": "code", "execution_count": null, - "id": "impressive-ethiopia", "metadata": {}, "outputs": [], "source": [ @@ -39,7 +37,6 @@ { "cell_type": "code", "execution_count": null, - "id": "statistical-ordering", "metadata": {}, "outputs": [], "source": [ @@ -51,7 +48,6 @@ { "cell_type": "code", "execution_count": null, - "id": "visible-destruction", "metadata": {}, "outputs": [], "source": [ @@ -73,7 +69,6 @@ { "cell_type": "code", "execution_count": null, - "id": "undefined-judges", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +100,6 @@ { "cell_type": "code", "execution_count": null, - "id": "respective-theater", "metadata": {}, "outputs": [], "source": [ @@ -123,70 +117,98 @@ { "cell_type": "code", "execution_count": null, - "id": "026e4082", "metadata": {}, "outputs": [], "source": [ "!ls -lrt ../experiments/" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load the predictions" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "stone-spanking", "metadata": {}, "outputs": [], "source": [ - "path = 
\"../experiments/cms-gen_20220509_152833_824459.gpu0.local/evaluation/\"" + "backend = 'pyg'" ] }, { "cell_type": "code", "execution_count": null, - "id": "blind-promotion", "metadata": {}, "outputs": [], "source": [ - "Xs = []\n", - "yvals = {}\n", - "for fi in list(glob.glob(path + \"/pred_batch*.npz\")):\n", - " dd = np.load(fi)\n", - " Xs.append(dd[\"X\"])\n", - " \n", - " keys_in_file = list(dd.keys())\n", - " for k in keys_in_file:\n", - " if k==\"X\":\n", - " continue\n", - " if not (k in yvals):\n", - " yvals[k] = []\n", - " yvals[k].append(dd[k])\n", - "\n", - "X = np.concatenate(Xs)\n", - "X_f = flatten(X)\n", - "\n", - "msk_X_f = X_f[:, 0] != 0\n", - "\n", - "yvals = {k: np.concatenate(v) for k, v in yvals.items()}\n", - "\n", - "for val in [\"gen\", \"cand\", \"pred\"]:\n", - " yvals[\"{}_phi\".format(val)] = np.arctan2(yvals[\"{}_sin_phi\".format(val)], yvals[\"{}_cos_phi\".format(val)])\n", - " yvals[\"{}_cls_id\".format(val)] = np.expand_dims(np.argmax(yvals[\"{}_cls\".format(val)], axis=-1), axis=-1)\n", - "\n", - " yvals[\"{}_px\".format(val)] = np.sin(yvals[\"{}_phi\".format(val)])*yvals[\"{}_pt\".format(val)]\n", - " yvals[\"{}_py\".format(val)] = np.cos(yvals[\"{}_phi\".format(val)])*yvals[\"{}_pt\".format(val)]\n", - " \n", - "yvals_f = {k: flatten(v) for k, v in yvals.items()}\n", - "\n", - "#remove the last dim\n", - "for k in yvals_f.keys():\n", - " if yvals_f[k].shape[-1] == 1:\n", - " yvals_f[k] = yvals_f[k][..., -1]" + "if backend == 'tf':\n", + " path = \"../experiments/cms-gen_20220509_152833_824459.gpu0.local/evaluation/\" \n", + " Xs = []\n", + " yvals = {}\n", + " for fi in list(glob.glob(path + \"/pred_batch*.npz\")):\n", + " dd = np.load(fi)\n", + " Xs.append(dd[\"X\"])\n", + "\n", + " keys_in_file = list(dd.keys())\n", + " for k in keys_in_file:\n", + " if k==\"X\":\n", + " continue\n", + " if not (k in yvals):\n", + " yvals[k] = []\n", + " yvals[k].append(dd[k])\n", + "\n", + " X = np.concatenate(Xs)\n", + " X_f = flatten(X)\n", 
+ "\n", + " msk_X_f = X_f[:, 0] != 0\n", + "\n", + " yvals = {k: np.concatenate(v) for k, v in yvals.items()}\n", + "\n", + " for val in [\"gen\", \"cand\", \"pred\"]:\n", + " yvals[\"{}_phi\".format(val)] = np.arctan2(yvals[\"{}_sin_phi\".format(val)], yvals[\"{}_cos_phi\".format(val)])\n", + " yvals[\"{}_cls_id\".format(val)] = np.expand_dims(np.argmax(yvals[\"{}_cls\".format(val)], axis=-1), axis=-1)\n", + "\n", + " yvals[\"{}_px\".format(val)] = np.sin(yvals[\"{}_phi\".format(val)])*yvals[\"{}_pt\".format(val)]\n", + " yvals[\"{}_py\".format(val)] = np.cos(yvals[\"{}_phi\".format(val)])*yvals[\"{}_pt\".format(val)]\n", + "\n", + " yvals_f = {k: flatten(v) for k, v in yvals.items()}\n", + "\n", + " #remove the last dim\n", + " for k in yvals_f.keys():\n", + " if yvals_f[k].shape[-1] == 1:\n", + " yvals_f[k] = yvals_f[k][..., -1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if backend == 'pyg':\n", + " import torch\n", + " path = './preds/'\n", + " X = torch.load(f'{path}/post_processed_Xs.pt')\n", + " X_f = torch.load(f'{path}/post_processed_X_f.pt')\n", + " msk_X_f = torch.load(f'{path}/post_processed_msk_X_f.pt')\n", + " yvals = torch.load(f'{path}/post_processed_yvals.pt')\n", + " yvals_f = torch.load(f'{path}/post_processed_yvals_f.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Make plots" ] }, { "cell_type": "code", "execution_count": null, - "id": "4de09104", "metadata": {}, "outputs": [], "source": [ @@ -203,7 +225,6 @@ { "cell_type": "code", "execution_count": null, - "id": "b4ff2631", "metadata": {}, "outputs": [], "source": [ @@ -223,7 +244,6 @@ { "cell_type": "code", "execution_count": null, - "id": "3d4912ca", "metadata": {}, "outputs": [], "source": [ @@ -247,7 +267,6 @@ { "cell_type": "code", "execution_count": null, - "id": "b5056249", "metadata": {}, "outputs": [], "source": [ @@ -280,15 +299,6 @@ { "cell_type": "code", "execution_count": 
null, - "id": "7fda3560", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21217680", "metadata": {}, "outputs": [], "source": [ @@ -321,7 +331,6 @@ { "cell_type": "code", "execution_count": null, - "id": "0141c01a", "metadata": {}, "outputs": [], "source": [ @@ -355,7 +364,6 @@ { "cell_type": "code", "execution_count": null, - "id": "46cfb473", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +397,6 @@ { "cell_type": "code", "execution_count": null, - "id": "dacf1c4b", "metadata": {}, "outputs": [], "source": [ @@ -423,7 +430,6 @@ { "cell_type": "code", "execution_count": null, - "id": "c04c999b", "metadata": {}, "outputs": [], "source": [ @@ -457,7 +463,6 @@ { "cell_type": "code", "execution_count": null, - "id": "cf045b8d", "metadata": {}, "outputs": [], "source": [ @@ -491,7 +496,6 @@ { "cell_type": "code", "execution_count": null, - "id": "4348f4d5", "metadata": {}, "outputs": [], "source": [ @@ -525,7 +529,6 @@ { "cell_type": "code", "execution_count": null, - "id": "ef355d9b", "metadata": {}, "outputs": [], "source": [ @@ -559,7 +562,6 @@ { "cell_type": "code", "execution_count": null, - "id": "e8ef722e", "metadata": {}, "outputs": [], "source": [ @@ -593,7 +595,6 @@ { "cell_type": "code", "execution_count": null, - "id": "750078b2", "metadata": {}, "outputs": [], "source": [ @@ -627,7 +628,6 @@ { "cell_type": "code", "execution_count": null, - "id": "9ddfc59e", "metadata": {}, "outputs": [], "source": [ @@ -661,7 +661,6 @@ { "cell_type": "code", "execution_count": null, - "id": "d5b04426", "metadata": {}, "outputs": [], "source": [ @@ -690,7 +689,6 @@ { "cell_type": "code", "execution_count": null, - "id": "d41ecf82", "metadata": {}, "outputs": [], "source": [ @@ -718,7 +716,6 @@ { "cell_type": "code", "execution_count": null, - "id": "03a0dd55", "metadata": {}, "outputs": [], "source": [ @@ -745,7 +742,6 @@ { "cell_type": "code", "execution_count": null, - "id": "virgin-nicaragua", 
"metadata": { "scrolled": false }, @@ -799,7 +795,6 @@ { "cell_type": "code", "execution_count": null, - "id": "funny-batch", "metadata": {}, "outputs": [], "source": [ @@ -822,7 +817,6 @@ { "cell_type": "code", "execution_count": null, - "id": "micro-saying", "metadata": {}, "outputs": [], "source": [ @@ -900,7 +894,6 @@ { "cell_type": "code", "execution_count": null, - "id": "automated-quarter", "metadata": { "scrolled": false }, @@ -912,7 +905,6 @@ { "cell_type": "code", "execution_count": null, - "id": "83fb675e", "metadata": { "scrolled": false }, @@ -924,7 +916,6 @@ { "cell_type": "code", "execution_count": null, - "id": "military-professor", "metadata": { "scrolled": false }, @@ -936,7 +927,6 @@ { "cell_type": "code", "execution_count": null, - "id": "5ec54d38", "metadata": { "scrolled": false }, @@ -948,7 +938,6 @@ { "cell_type": "code", "execution_count": null, - "id": "810d7935", "metadata": { "scrolled": false }, @@ -960,7 +949,6 @@ { "cell_type": "code", "execution_count": null, - "id": "258c5cb7", "metadata": { "scrolled": false }, @@ -972,7 +960,6 @@ { "cell_type": "code", "execution_count": null, - "id": "c20045e4", "metadata": {}, "outputs": [], "source": [ @@ -986,7 +973,6 @@ { "cell_type": "code", "execution_count": null, - "id": "characteristic-colleague", "metadata": {}, "outputs": [], "source": [ @@ -1000,7 +986,6 @@ { "cell_type": "code", "execution_count": null, - "id": "ready-macedonia", "metadata": {}, "outputs": [], "source": [ @@ -1022,7 +1007,6 @@ { "cell_type": "code", "execution_count": null, - "id": "formal-county", "metadata": {}, "outputs": [], "source": [ @@ -1032,7 +1016,6 @@ { "cell_type": "code", "execution_count": null, - "id": "neural-witch", "metadata": {}, "outputs": [], "source": [ @@ -1061,7 +1044,6 @@ { "cell_type": "code", "execution_count": null, - "id": "formal-maryland", "metadata": {}, "outputs": [], "source": [ @@ -1074,7 +1056,6 @@ { "cell_type": "code", "execution_count": null, - "id": "committed-clothing", 
"metadata": {}, "outputs": [], "source": [ @@ -1087,7 +1068,6 @@ { "cell_type": "code", "execution_count": null, - "id": "recreational-enhancement", "metadata": {}, "outputs": [], "source": [ @@ -1102,7 +1082,6 @@ { "cell_type": "code", "execution_count": null, - "id": "empirical-network", "metadata": {}, "outputs": [], "source": [ @@ -1141,7 +1120,6 @@ { "cell_type": "code", "execution_count": null, - "id": "a79af186", "metadata": {}, "outputs": [], "source": [ @@ -1180,7 +1158,6 @@ { "cell_type": "code", "execution_count": null, - "id": "415cf286", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1177,6 @@ { "cell_type": "code", "execution_count": null, - "id": "9117de9c", "metadata": {}, "outputs": [], "source": [ @@ -1219,7 +1195,6 @@ { "cell_type": "code", "execution_count": null, - "id": "272b67b8", "metadata": {}, "outputs": [], "source": [ @@ -1238,7 +1213,6 @@ { "cell_type": "code", "execution_count": null, - "id": "expressed-samba", "metadata": { "scrolled": false }, @@ -1296,7 +1270,6 @@ { "cell_type": "code", "execution_count": null, - "id": "paperback-timeline", "metadata": {}, "outputs": [], "source": [ @@ -1340,7 +1313,6 @@ { "cell_type": "code", "execution_count": null, - "id": "ecological-toner", "metadata": {}, "outputs": [], "source": [ @@ -1354,7 +1326,6 @@ { "cell_type": "code", "execution_count": null, - "id": "transparent-remedy", "metadata": {}, "outputs": [], "source": [ @@ -1368,7 +1339,6 @@ { "cell_type": "code", "execution_count": null, - "id": "promotional-checklist", "metadata": {}, "outputs": [], "source": [ @@ -1382,7 +1352,6 @@ { "cell_type": "code", "execution_count": null, - "id": "suitable-kansas", "metadata": {}, "outputs": [], "source": [ @@ -1396,7 +1365,6 @@ { "cell_type": "code", "execution_count": null, - "id": "restricted-million", "metadata": {}, "outputs": [], "source": [ @@ -1410,7 +1378,6 @@ { "cell_type": "code", "execution_count": null, - "id": "raising-first", "metadata": {}, "outputs": [], "source": [ @@ 
-1424,7 +1391,6 @@ { "cell_type": "code", "execution_count": null, - "id": "eb29cd12", "metadata": {}, "outputs": [], "source": [ @@ -1438,7 +1404,6 @@ { "cell_type": "code", "execution_count": null, - "id": "56fc8943", "metadata": {}, "outputs": [], "source": [ @@ -1452,7 +1417,6 @@ { "cell_type": "code", "execution_count": null, - "id": "51617bae", "metadata": {}, "outputs": [], "source": [ @@ -1480,7 +1444,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.3" } }, "nbformat": 4, diff --git a/notebooks/delphes-lrp-playground.ipynb b/notebooks/delphes-lrp-playground.ipynb new file mode 100644 index 000000000..339665516 --- /dev/null +++ b/notebooks/delphes-lrp-playground.ipynb @@ -0,0 +1,949 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This quickstart notebook lets you test and experiment with the MLPF GNN model in a standalone way. For actual training, we don't use a notebook; please refer to `README.md`.\n", + "\n", + "\n", + "```bash\n", + "git clone https://github.com/jpata/particleflow/\n", + "```\n", + "\n", + "Run the notebook from `notebooks/delphes-lrp-playground.ipynb`."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import bz2, pickle\n", + "import numpy as np\n", + "# import tensorflow as tf\n", + "import sklearn\n", + "import sklearn.metrics\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import pickle as pkl\n", + "import os.path as osp\n", + "import os\n", + "import sys\n", + "from glob import glob\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from typing import Optional, Union" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path += [\"../mlpf\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tfmodel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget --no-check-certificate -nc https://zenodo.org/record/4452283/files/tev14_pythia8_ttbar_0_0.pkl.bz2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = pickle.load(bz2.BZ2File(\"tev14_pythia8_ttbar_0_0.pkl.bz2\", \"r\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#100 events in one file\n", + "len(data[\"X\"]), len(data[\"ygen\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Pad the number of elements to a size that's divisible by the bin size\n", + "Xs = []\n", + "ys = []\n", + "\n", + "max_size = 50*128\n", + "for i in range(len(data[\"X\"])):\n", + " X = data[\"X\"][i][:max_size, :]\n", + " y = data[\"ygen\"][i][:max_size, :]\n", + " Xpad = np.pad(X, [(0, max_size - X.shape[0]), (0, 0)])\n", + " ypad = np.pad(y, [(0, max_size - y.shape[0]), (0, 0)])\n", + " Xpad = Xpad.astype(np.float32)\n", + " ypad = ypad.astype(np.float32)\n", + " 
Xs.append(Xpad)\n", + " ys.append(ypad)\n", + " \n", + "X = np.stack(Xs)\n", + "y = np.stack(ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test the pytorch setup for the input X" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# defines a pytorch FCN class for particleflow \n", + "\n", + "class MLPF_FCN(nn.Module):\n", + " \"\"\"\n", + " Showcase an example of an fully connected network pytorch model, with a skip connection, that can be explained by LRP\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim=12, hidden_dim=2, embedding_dim=2, output_dim=2):\n", + " super(MLPF_FCN, self).__init__()\n", + "\n", + " self.act = nn.ReLU\n", + "\n", + " self.nn1 = nn.Sequential(\n", + " nn.Linear(input_dim, hidden_dim),\n", + " nn.BatchNorm1d(hidden_dim),\n", + " self.act(),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " nn.BatchNorm1d(hidden_dim),\n", + " self.act(),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " nn.BatchNorm1d(hidden_dim),\n", + " self.act(),\n", + " nn.Linear(hidden_dim, embedding_dim),\n", + " )\n", + " self.nn2 = nn.Sequential(\n", + " nn.Linear(input_dim + embedding_dim, hidden_dim),\n", + " self.act(),\n", + " nn.Linear(hidden_dim, output_dim),\n", + " )\n", + "\n", + " def forward(self, X):\n", + " embedding = self.nn1(X)\n", + " return self.nn2(torch.cat([X, embedding], axis=1)), _, _\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# defines a pytorch GNN class for particleflow \n", + "\n", + "import pickle as pkl\n", + "import os.path as osp\n", + "import os\n", + "import sys\n", + "from glob import glob\n", + "\n", + "import torch\n", + "from torch import Tensor\n", + "import torch.nn as nn\n", + "from torch.nn import Linear\n", + "from torch_scatter import scatter\n", + "from torch_geometric.nn.conv import MessagePassing\n", + "from torch_geometric.utils import 
to_dense_adj\n", + "import torch.nn.functional as F\n", + "\n", + "from typing import Optional, Union\n", + "from torch_geometric.typing import OptTensor, PairTensor, PairOptTensor\n", + "from torch_geometric.data import Data, DataLoader, DataListLoader, Batch\n", + "\n", + "try:\n", + " from torch_cluster import knn\n", + "except ImportError:\n", + " knn = None\n", + "from torch_cluster import knn_graph\n", + "\n", + "import numpy as np\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "class MLPF_GNN(nn.Module):\n", + " \"\"\"\n", + " GNN model based on Gravnet...\n", + "\n", + " Forward pass returns\n", + " preds: tensor of predictions containing a concatenated representation of the pids and p4\n", + " A: dict() object containing adjacency matrices for each message passing\n", + " msg_activations: dict() object containing activations before each message passing\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " input_dim=12, output_dim_id=6, output_dim_p4=6,\n", + " embedding_dim=2, hidden_dim1=2, hidden_dim2=2,\n", + " num_convs=2, space_dim=4, propagate_dim=2, k=8):\n", + " super(MLPF_GNN, self).__init__()\n", + "\n", + " # self.act = nn.ReLU\n", + " self.act = nn.ELU\n", + "\n", + " # (1) embedding\n", + " self.nn1 = nn.Sequential(\n", + " nn.Linear(input_dim, hidden_dim1),\n", + " self.act(),\n", + " nn.Linear(hidden_dim1, hidden_dim1),\n", + " self.act(),\n", + " nn.Linear(hidden_dim1, hidden_dim1),\n", + " self.act(),\n", + " nn.Linear(hidden_dim1, embedding_dim),\n", + " )\n", + "\n", + " self.conv = nn.ModuleList()\n", + " for i in range(num_convs):\n", + " self.conv.append(GravNetConv_LRP(embedding_dim, embedding_dim, space_dim, propagate_dim, k))\n", + "\n", + " # (3) DNN layer: classifiying pid\n", + " self.nn2 = nn.Sequential(\n", + " nn.Linear(input_dim + embedding_dim, hidden_dim2),\n", + " self.act(),\n", + " nn.Linear(hidden_dim2, hidden_dim2),\n", + " self.act(),\n", + " nn.Linear(hidden_dim2, 
hidden_dim2),\n", + " self.act(),\n", + " nn.Linear(hidden_dim2, output_dim_id),\n", + " )\n", + "\n", + " # (4) DNN layer: regressing p4\n", + " self.nn3 = nn.Sequential(\n", + " nn.Linear(input_dim + output_dim_id, hidden_dim2),\n", + " self.act(),\n", + " nn.Linear(hidden_dim2, hidden_dim2),\n", + " self.act(),\n", + " nn.Linear(hidden_dim2, hidden_dim2),\n", + " self.act(),\n", + " nn.Linear(hidden_dim2, output_dim_p4),\n", + " )\n", + "\n", + " def forward(self, batch):\n", + "\n", + " x0 = batch.x\n", + "\n", + " # embed the inputs\n", + " embedding = self.nn1(x0)\n", + "\n", + " # preform a series of graph convolutions\n", + " A = {}\n", + " msg_activations = {}\n", + " for num, conv in enumerate(self.conv):\n", + " embedding, A[f'conv.{num}'], msg_activations[f'conv.{num}'] = conv(embedding)\n", + "\n", + " # predict the pid's\n", + " preds_id = self.nn2(torch.cat([x0, embedding], axis=-1))\n", + "\n", + " # predict the p4's\n", + " preds_p4 = self.nn3(torch.cat([x0, preds_id], axis=-1))\n", + "\n", + " return torch.cat([preds_id, preds_p4], axis=-1), A, msg_activations\n", + "\n", + "\n", + "class GravNetConv_LRP(MessagePassing):\n", + " \"\"\"\n", + " Copied from pytorch_geometric source code, with the following edits\n", + " a. retrieve adjacency matrix (we call A), and the activations before the message passing step (we call msg_activations)\n", + " b. switched the execution of self.lin_s & self.lin_p so that the message passing step can substitute out of the box self.lin_s for lrp purposes\n", + " c. used reduce='sum' instead of reduce='mean' in the message passing\n", + " d. 
removed skip connection\n", + " \"\"\"\n", + "\n", + " def __init__(self, in_channels: int, out_channels: int,\n", + " space_dimensions: int, propagate_dimensions: int, k: int,\n", + " num_workers: int = 1, **kwargs):\n", + " super().__init__(flow='source_to_target', **kwargs)\n", + "\n", + " if knn is None:\n", + " raise ImportError('`GravNetConv` requires `torch-cluster`.')\n", + "\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.k = k\n", + " self.num_workers = num_workers\n", + "\n", + " self.lin_p = Linear(in_channels, propagate_dimensions)\n", + " self.lin_s = Linear(in_channels, space_dimensions)\n", + " self.lin_out = Linear(propagate_dimensions, out_channels)\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " self.lin_s.reset_parameters()\n", + " self.lin_p.reset_parameters()\n", + " self.lin_out.reset_parameters()\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, PairTensor],\n", + " batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:\n", + " \"\"\"\"\"\"\n", + "\n", + " is_bipartite: bool = True\n", + " if isinstance(x, Tensor):\n", + " x: PairTensor = (x, x)\n", + " is_bipartite = False\n", + "\n", + " if x[0].dim() != 2:\n", + " raise ValueError(\"Static graphs not supported in 'GravNetConv'\")\n", + "\n", + " b: PairOptTensor = (None, None)\n", + " if isinstance(batch, Tensor):\n", + " b = (batch, batch)\n", + " elif isinstance(batch, tuple):\n", + " assert batch is not None\n", + " b = (batch[0], batch[1])\n", + "\n", + " # embed the inputs before message passing\n", + " msg_activations = self.lin_p(x[0])\n", + "\n", + " # transform to the space dimension to build the graph\n", + " s_l: Tensor = self.lin_s(x[0])\n", + " s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l\n", + "\n", + " edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0])\n", + "\n", + " edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1)\n", + " 
edge_weight = torch.exp(-10. * edge_weight) # 10 gives a better spread\n", + "\n", + " # return the adjacency matrix of the graph for lrp purposes\n", + " A = to_dense_adj(edge_index.to('cpu'), edge_attr=edge_weight.to('cpu'))[0] # adjacency matrix\n", + "\n", + " # message passing\n", + " out = self.propagate(edge_index, x=(msg_activations, None),\n", + " edge_weight=edge_weight,\n", + " size=(s_l.size(0), s_r.size(0)))\n", + "\n", + " return self.lin_out(out), A, msg_activations\n", + "\n", + " def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n", + " return x_j * edge_weight.unsqueeze(1)\n", + "\n", + " def aggregate(self, inputs: Tensor, index: Tensor,\n", + " dim_size: Optional[int] = None) -> Tensor:\n", + " out_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,\n", + " reduce='sum')\n", + " return out_mean\n", + "\n", + " def __repr__(self) -> str:\n", + " return (f'{self.__class__.__name__}({self.in_channels}, '\n", + " f'{self.out_channels}, k={self.k})')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from X (the input to the tensorflow model, reshape and recast as conveninet pytorch format)\n", + "pytorch_X = torch.tensor(X[:1].reshape(-1,12)) # the slice [:1] picks up the first event" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# build a simple FCN and perform forward pass\n", + "model = MLPF_FCN()\n", + "model(pytorch_X);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# build a GNN and perform forward pass\n", + "batch = Batch(x = pytorch_X) # recall GNN takes a batch object\n", + "model = MLPF_GNN()\n", + "model(batch);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# defines lrp\n", + "\n", + "import pickle as pkl\n", + "import os.path as osp\n", 
+ "import os\n", + "import sys\n", + "from glob import glob\n", + "\n", + "import torch\n", + "\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.nn import Sequential as Seq, Linear as Lin, ReLU\n", + "\n", + "\n", + "class LRP_MLPF():\n", + "\n", + " \"\"\"\n", + " Extends the LRP class to act on graph datasets and GNNs based on the Gravnet layer (e.g. the MLPF model, see models.MLPF)\n", + " The main trick is to realize that the \".lin_s\" layers in Gravnet are irrelevant for explanations so shall be skipped\n", + " The hack, however, is to substitute them precisely with the message_passing step\n", + "\n", + " Differences from standard LRP\n", + " a. Rscores become tensors/graphs of input features per output neuron instead of vectors\n", + " b. accomodates message passing steps by using the adjacency matrix as the weight matrix in standard LRP,\n", + " and redistributing Rscores over the other dimension (over nodes instead of features)\n", + " \"\"\"\n", + "\n", + " def __init__(self, device, model, epsilon):\n", + "\n", + " self.device = device\n", + " self.model = model.to(device)\n", + " self.epsilon = epsilon # for stability reasons in the lrp-epsilon rule (by default: a very small number)\n", + "\n", + " # check if the model has any skip connections to accomodate them\n", + " self.skip_connections = self.find_skip_connections()\n", + " self.msg_passing_layers = self.find_msg_passing_layers()\n", + "\n", + " \"\"\"\n", + " explanation functions\n", + " \"\"\"\n", + "\n", + " def explain(self, input, neuron_to_explain):\n", + " \"\"\"\n", + " Primary function to call on an LRP instance to start explaining predictions.\n", + " First, it registers hooks and runs a forward pass on the input.\n", + " Then, it attempts to explain the whole model by looping over the layers in the model and invoking the explain_single_layer function.\n", + "\n", + " Args:\n", + " input: tensor containing the input sample you wish to explain\n", + 
" neuron_to_explain: the index for a particular neuron in the output layer you wish to explain\n", + "\n", + " Returns:\n", + " R_tensor: a tensor/graph containing the relevance scores of the input graph for a particular output neuron\n", + " preds: the model predictions of the input (for further plotting/processing purposes only)\n", + " input: the input that was explained (for further plotting/processing purposes only)\n", + " \"\"\"\n", + "\n", + " # register forward hooks to retrieve intermediate activations\n", + " # in simple words, when the forward pass is called, the following dict() will be filled with (key, value) = (\"layer_name\", activations)\n", + " activations = {}\n", + "\n", + " def get_activation(name):\n", + " def hook(model, input, output):\n", + " activations[name] = input[0]\n", + " return hook\n", + "\n", + " for name, module in self.model.named_modules():\n", + " # unfold any containers so as to register hooks only for their child modules (equivalently we are demanding type(module) != nn.Sequential))\n", + " if ('Linear' in str(type(module))) or ('activation' in str(type(module))) or ('BatchNorm1d' in str(type(module))):\n", + " module.register_forward_hook(get_activation(name))\n", + "\n", + " # run a forward pass\n", + " self.model.eval()\n", + " preds, self.A, self.msg_activations = self.model(input.to(self.device))\n", + "\n", + " # get the activations\n", + " self.activations = activations\n", + " self.num_layers = len(activations.keys())\n", + " self.in_features_dim = self.name2layer(list(activations.keys())[0]).in_features\n", + "\n", + " print(f'Total number of layers: {self.num_layers}')\n", + "\n", + " # initialize Rscores for skip connections (in case there are any)\n", + " if len(self.skip_connections) != 0:\n", + " self.skip_connections_relevance = 0\n", + "\n", + " # initialize the Rscores tensor using the output predictions\n", + " Rscores = preds[:, neuron_to_explain].reshape(-1, 1).detach()\n", + "\n", + " # build the 
Rtensor which is going to be a whole graph of Rscores per node\n", + " R_tensor = torch.zeros([Rscores.shape[0], Rscores.shape[0], Rscores.shape[1]]).to(self.device)\n", + " for node in range(R_tensor.shape[0]):\n", + " R_tensor[node][node] = Rscores[node]\n", + "\n", + " # loop over layers in the model to propagate Rscores backward\n", + " for layer_index in range(self.num_layers, 0, -1):\n", + " R_tensor = self.explain_single_layer(R_tensor, layer_index, neuron_to_explain)\n", + "\n", + " print(\"Finished explaining all layers.\")\n", + "\n", + " if len(self.skip_connections) != 0:\n", + " return R_tensor + self.skip_connections_relevance, preds, input\n", + "\n", + " return R_tensor, preds, input\n", + "\n", + " def explain_single_layer(self, R_tensor_old, layer_index, neuron_to_explain):\n", + " \"\"\"\n", + " Attempts to explain a single layer in the model by propagating Rscores backwards using the lrp-epsilon rule.\n", + "\n", + " Args:\n", + " R_tensor_old: a tensor/graph containing the Rscores, of the current layer, to be propagated backwards\n", + " layer_index: index that corresponds to the position of the layer in the model (see helper functions)\n", + " neuron_to_explain: the index for a particular neuron in the output layer to explain\n", + "\n", + " Returns:\n", + " R_tensor_new: a tensor/graph containing the computed Rscores of the previous layer\n", + " \"\"\"\n", + "\n", + " # get layer information\n", + " layer_name = self.index2name(layer_index)\n", + " layer = self.name2layer(layer_name)\n", + "\n", + " # get layer activations (depends wether it's a message passing step)\n", + " if layer_name in self.msg_passing_layers.keys():\n", + " print(f\"Explaining layer {self.num_layers+1-layer_index}/{self.num_layers}: MessagePassing layer\")\n", + " input = self.msg_activations[layer_name[:-6]].to(self.device).detach()\n", + " msg_passing_layer = True\n", + " else:\n", + " print(f\"Explaining layer {self.num_layers+1-layer_index}/{self.num_layers}: 
{layer}\")\n", + " input = self.activations[layer_name].to(self.device).detach()\n", + " msg_passing_layer = False\n", + "\n", + " # run lrp\n", + " if 'Linear' in str(layer):\n", + " R_tensor_new = self.eps_rule(self, layer, layer_name, input, R_tensor_old, neuron_to_explain, msg_passing_layer)\n", + " print('- Finished computing Rscores')\n", + " return R_tensor_new\n", + " else:\n", + " if 'activation' in str(layer):\n", + " print(f\"- skipping layer because it's an activation layer\")\n", + " elif 'BatchNorm1d' in str(layer):\n", + " print(f\"- skipping layer because it's a BatchNorm layer\")\n", + " print(f\"- Rscores do not need to be computed\")\n", + " return R_tensor_old\n", + "\n", + " \"\"\"\n", + " lrp-epsilon rule\n", + " \"\"\"\n", + "\n", + " @staticmethod\n", + " def eps_rule(self, layer, layer_name, x, R_tensor_old, neuron_to_explain, msg_passing_layer):\n", + " \"\"\"\n", + " Implements the lrp-epsilon rule presented in the following reference: https://doi.org/10.1007/978-3-030-28954-6_10.\n", + "\n", + " Can accomodate message_passing layers if the adjacency matrix and the activations before the message_passing are provided.\n", + " The trick (or as we like to call it, the message_passing hack) is in\n", + " a. using the adjacency matrix as the weight matrix in the standard lrp rule\n", + " b. 
transposing the activations to distribute the Rscores over the other dimension (over nodes instead of features)\n", + "\n", + " Args:\n", + " layer: a torch.nn module with a corresponding weight matrix W\n", + " x: vector containing the activations of the previous layer\n", + " R_tensor_old: a tensor/graph containing the Rscores, of the current layer, to be propagated backwards\n", + " neuron_to_explain: the index for a particular neuron in the output layer to explain\n", + "\n", + " Returns:\n", + " R_tensor_new: a tensor/graph containing the computed Rscores of the previous layer\n", + " \"\"\"\n", + "\n", + " torch.cuda.empty_cache()\n", + "\n", + " if msg_passing_layer: # message_passing hack\n", + " x = torch.transpose(x, 0, 1) # transpose the activations to distribute the Rscores over the other dimension (over nodes instead of features)\n", + " W = self.A[layer_name[:-6]].detach().to(self.device) # use the adjacency matrix as the weight matrix\n", + " else:\n", + " W = layer.weight.detach() # get weight matrix\n", + " W = torch.transpose(W, 0, 1) # sanity check of forward pass: (torch.matmul(x, W) + layer.bias) == layer(x)\n", + "\n", + " # for the output layer, pick the part of the weight matrix connecting only to the neuron you're attempting to explain\n", + " if layer == list(self.model.modules())[-1]:\n", + " W = W[:, neuron_to_explain].reshape(-1, 1)\n", + "\n", + " # (1) compute the denominator\n", + " denominator = torch.matmul(x, W) + self.epsilon\n", + " # (2) scale the Rscores\n", + " if msg_passing_layer: # message_passing hack\n", + " R_tensor_old = torch.transpose(R_tensor_old, 1, 2)\n", + " scaledR = R_tensor_old / denominator\n", + " # (3) compute the new Rscores\n", + " R_tensor_new = torch.matmul(scaledR, torch.transpose(W, 0, 1)) * x\n", + "\n", + " # checking conservation of Rscores for a given random node (# 17)\n", + " rtol = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]\n", + " for tol in rtol:\n", + " if (torch.allclose(R_tensor_new[17].sum(), 
R_tensor_old[17].sum(), rtol=tol)):\n", + " print(f'- Rscores are conserved up to relative tolerance {str(tol)}')\n", + " break\n", + "\n", + " if layer in self.skip_connections:\n", + " # set aside the relevance of the input_features in the skip connection\n", + " # recall: it is assumed that the skip connections are defined in the following order torch.cat[(input_features, ...)] )\n", + " self.skip_connections_relevance = self.skip_connections_relevance + R_tensor_new[:, :, :self.in_features_dim]\n", + " return R_tensor_new[:, :, self.in_features_dim:]\n", + "\n", + " if msg_passing_layer: # message_passing hack\n", + " return torch.transpose(R_tensor_new, 1, 2)\n", + "\n", + " return R_tensor_new\n", + "\n", + " \"\"\"\n", + " helper functions\n", + " \"\"\"\n", + "\n", + " def index2name(self, layer_index):\n", + " \"\"\"\n", + " Given the index of a layer (e.g. 3) returns the name of the layer (e.g. .nn1.3)\n", + " \"\"\"\n", + " layer_name = list(self.activations.keys())[layer_index - 1]\n", + " return layer_name\n", + "\n", + " def name2layer(self, layer_name):\n", + " \"\"\"\n", + " Given the name of a layer (e.g. .nn1.3) returns the corresponding torch module (e.g. Linear(...))\n", + " \"\"\"\n", + " for name, module in self.model.named_modules():\n", + " if layer_name == name:\n", + " return module\n", + "\n", + " def find_skip_connections(self):\n", + " \"\"\"\n", + " Given a torch model, retuns a list of layers with skip connections... the elements are torch modules (e.g. 
Linear(...))\n", + " \"\"\"\n", + " explainable_layers = []\n", + " for name, module in self.model.named_modules():\n", + " if 'lin_s' in name: # for models that are based on Gravnet, skip the lin_s layers\n", + " continue\n", + " if ('Linear' in str(type(module))):\n", + " explainable_layers.append(module)\n", + "\n", + " skip_connections = []\n", + " for layer_index in range(len(explainable_layers) - 1):\n", + " if explainable_layers[layer_index].out_features != explainable_layers[layer_index + 1].in_features:\n", + " skip_connections.append(explainable_layers[layer_index + 1])\n", + "\n", + " return skip_connections\n", + "\n", + " def find_msg_passing_layers(self):\n", + " \"\"\"\n", + " Returns a list of \".lin_s\" layers from model.named_modules() that shall be substituted with message passing\n", + " \"\"\"\n", + " msg_passing_layers = {}\n", + " for name, module in self.model.named_modules():\n", + " if 'lin_s' in name: # for models that are based on Gravnet, replace the .lin_s layers with message_passing\n", + " msg_passing_layers[name] = {}\n", + "\n", + " return msg_passing_layers\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test lrp for the FCN model\n", + "model = MLPF_FCN()\n", + "\n", + "lrp_instance = LRP_MLPF('cpu', model, epsilon=1e-9)\n", + "Rtensor, pred, inputt = lrp_instance.explain(pytorch_X, neuron_to_explain=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test lrp for the GNN model\n", + "model = MLPF_GNN()\n", + "\n", + "lrp_instance = LRP_MLPF('cpu', model, epsilon=1e-9)\n", + "Rtensor, pred, inputt = lrp_instance.explain(batch, neuron_to_explain=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Back to the tensorflow setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Collect the distinct input and output class labels across all events\n",
+ "input_classes = np.unique(X[:, :, 0].flatten())\n", + "output_classes = np.unique(y[:, :, 0].flatten())\n", + "num_output_classes = len(output_classes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def transform_target(y):\n", + " return {\n", + " \"cls\": tf.one_hot(tf.cast(y[:, :, 0], tf.int32), num_output_classes),\n", + " \"charge\": y[:, :, 1:2],\n", + " \"pt\": y[:, :, 2:3],\n", + " \"eta\": y[:, :, 3:4],\n", + " \"sin_phi\": y[:, :, 4:5],\n", + " \"cos_phi\": y[:, :, 5:6],\n", + " \"energy\": y[:, :, 6:7],\n", + " }\n", + "yt = transform_target(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tfmodel.model import PFNetDense" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "msk_true_particle = y[:, :, 0]!=0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.unique(y[msk_true_particle][:, 0], return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(yt[\"pt\"][msk_true_particle].flatten(), bins=100);\n", + "plt.xlabel(\"pt\")\n", + "plt.yscale(\"log\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(yt[\"eta\"][msk_true_particle].flatten(), bins=100);\n", + "plt.xlabel(\"eta\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(yt[\"sin_phi\"][msk_true_particle].flatten(), bins=100);\n", + 
"plt.xlabel(\"sin phi\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(yt[\"cos_phi\"][msk_true_particle].flatten(), bins=100);\n", + "plt.xlabel(\"cos phi\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(yt[\"energy\"][msk_true_particle].flatten(), bins=100);\n", + "plt.xlabel(\"energy\")\n", + "plt.yscale(\"log\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = PFNetDense(\n", + " num_input_classes=len(input_classes),\n", + " num_output_classes=len(output_classes),\n", + " activation=tf.nn.elu,\n", + " hidden_dim=128,\n", + " bin_size=128,\n", + " input_encoding=\"default\",\n", + " multi_output=True\n", + ")\n", + "\n", + "# #temporal weight mode means each input element in the event can get a separate weight\n", + "model.compile(\n", + " loss={\n", + " \"cls\": tf.keras.losses.CategoricalCrossentropy(from_logits=False),\n", + " \"charge\": tf.keras.losses.MeanSquaredError(),\n", + " \"pt\": tf.keras.losses.MeanSquaredError(),\n", + " \"energy\": tf.keras.losses.MeanSquaredError(),\n", + " \"eta\": tf.keras.losses.MeanSquaredError(),\n", + " \"sin_phi\": tf.keras.losses.MeanSquaredError(),\n", + " \"cos_phi\": tf.keras.losses.MeanSquaredError()\n", + " },\n", + " optimizer=\"adam\",\n", + " sample_weight_mode=\"temporal\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model(X[:1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(X, yt, epochs=2, batch_size=5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ypred = model.predict(X, batch_size=5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "#index of the class prediction output values\n", + "pred_id_offset = len(output_classes)\n", + "ypred_ids_raw = ypred[\"cls\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sklearn.metrics.confusion_matrix(\n", + " np.argmax(ypred_ids_raw, axis=-1).flatten(),\n", + " np.argmax(yt[\"cls\"], axis=-1).flatten(), labels=output_classes\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "msk_particles = (X[:, :, 0]!=0)\n", + "plt.scatter(\n", + " ypred[\"eta\"][msk_particles].flatten(),\n", + " yt[\"eta\"][msk_particles].flatten(), marker=\".\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}