making the 3d-padded models more efficient in pytorch (jpata#256)
* initial wip

* move padding to collate

* avoid compile

* specify default conv type

* train and valid separately

* add saving weight

* update submission script

* valid only on rank0

* use_cuda
jpata authored Oct 26, 2023
1 parent 27d9798 commit ae04258
Showing 8 changed files with 248 additions and 250 deletions.
35 changes: 25 additions & 10 deletions mlpf/pyg/PFDataset.py
@@ -1,17 +1,18 @@
from typing import List, Optional
from types import SimpleNamespace

import tensorflow_datasets as tfds
import torch
import torch.utils.data
from torch import Tensor
import torch_geometric
from torch_geometric.data import Batch, Data
from torch_geometric.data.data import BaseData


class PFDataset:
"""Builds a DataSource from tensorflow datasets."""

def __init__(self, data_dir, name, split, keys_to_get, num_samples=None):
def __init__(self, data_dir, name, split, keys_to_get, pad_3d=True, num_samples=None):
"""
Args
data_dir: path to tensorflow_datasets (e.g. `../data/tensorflow_datasets/`)
@@ -29,7 +30,6 @@ def __init__(self, data_dir, name, split, keys_to_get, num_samples=None):

# to make dataset_info pickable
tmp = self.ds.dataset_info
from types import SimpleNamespace

self.ds.dataset_info = SimpleNamespace()
self.ds.dataset_info.name = tmp.name
@@ -39,6 +39,8 @@ def __init__(self, data_dir, name, split, keys_to_get, num_samples=None):
# any selection of ["X", "ygen", "ycand"] to retrieve
self.keys_to_get = keys_to_get

self.pad_3d = pad_3d

if num_samples:
self.ds = torch.utils.data.Subset(self.ds, range(num_samples))

@@ -50,7 +52,7 @@ def get_distributed_sampler(self):
sampler = torch.utils.data.distributed.DistributedSampler(self.ds)
return sampler

def get_loader(self, batch_size, world_size, num_workers=0, prefetch_factor=None):
def get_loader(self, batch_size, world_size, rank, use_cuda=False, num_workers=0, prefetch_factor=None):
if (num_workers > 0) and (prefetch_factor is None):
prefetch_factor = 2 # default prefetch_factor when num_workers>0

@@ -62,10 +64,12 @@ def get_loader(self, batch_size, world_size, num_workers=0, prefetch_factor=None
return DataLoader(
self.ds,
batch_size=batch_size,
collate_fn=Collater(self.keys_to_get),
collate_fn=Collater(self.keys_to_get, pad_3d=self.pad_3d),
sampler=sampler,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=use_cuda,
pin_memory_device="cuda:{}".format(rank) if use_cuda else "",
)

def __len__(self):
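
For orientation, here is a minimal sketch (not part of the commit) of how the updated get_loader signature might be used; the dataset name and the argument values below are hypothetical examples, while the directory layout and key names follow the docstring and comments in this file:

import torch
from mlpf.pyg.PFDataset import PFDataset

# hypothetical dataset name; pad_3d=True requests the padded 3D collation
dataset = PFDataset(
    data_dir="../data/tensorflow_datasets/",
    name="clic_edm_ttbar_pf",
    split="train",
    keys_to_get=["X", "ygen"],
    pad_3d=True,
)

use_cuda = torch.cuda.is_available()
loader = dataset.get_loader(
    batch_size=16,
    world_size=1,
    rank=0,
    use_cuda=use_cuda,  # with use_cuda, batches are pinned to "cuda:{rank}"
    num_workers=2,
)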
@@ -109,10 +113,12 @@ def __init__(
class Collater:
"""Based on the Collater found on torch_geometric docs we build our own."""

def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None):
def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_bin_size=640, pad_3d=True):
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
self.keys_to_get = keys_to_get
self.pad_bin_size = pad_bin_size
self.pad_3d = pad_3d

def __call__(self, inputs):
num_samples_in_batch = len(inputs)
@@ -125,12 +131,21 @@ def __call__(self, inputs):
batch[ev][elem_key] = Tensor(inputs[ev][elem_key])
batch[ev]["batch"] = torch.tensor([ev] * len(inputs[ev][elem_key]))

elem = batch[0]
ret = Batch.from_data_list(batch, self.follow_batch, self.exclude_keys)

if not self.pad_3d:
return ret
else:
ret = {k: torch_geometric.utils.to_dense_batch(getattr(ret, k), ret.batch) for k in elem_keys}

ret["mask"] = ret["X"][1]

if isinstance(elem, BaseData):
return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys)
# remove the mask from each element
for k in elem_keys:
ret[k] = ret[k][0]

raise TypeError(f"DataLoader found invalid type: {type(elem)}")
ret = Batch(**ret)
return ret


def my_getitem(self, vals):
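
The key step in the new Collater is torch_geometric.utils.to_dense_batch, which turns the concatenated (ragged) batch into a dense, zero-padded tensor plus a boolean mask. A small self-contained illustration of its behavior, with toy shapes:

import torch
import torch_geometric

x = torch.randn(7, 4)                        # 7 elements with 4 features each
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # 3 events with 3, 2 and 2 elements
x_padded, mask = torch_geometric.utils.to_dense_batch(x, batch)

print(x_padded.shape)  # torch.Size([3, 3, 4]) -> [events, max elements, features]
print(mask.shape)      # torch.Size([3, 3]), True where an element is real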
41 changes: 17 additions & 24 deletions mlpf/pyg/gnn_lsh.py
@@ -33,11 +33,6 @@ def point_wise_feed_forward_network(
return nn.Sequential(*layers)


# @torch.compile
def index_dim(a, b):
return a[b]


def split_indices_to_bins_batch(cmul, nbins, bin_size, msk):
a = torch.argmax(cmul, axis=-1)

@@ -143,7 +138,6 @@ def forward(self, x_msg_binned, msk, training=False):
return dm


@torch.compile
def split_msk_and_msg(bins_split, cmul, x_msg, x_node, msk, n_bins, bin_size):
bins_split_2 = torch.reshape(bins_split, (bins_split.shape[0], bins_split.shape[1] * bins_split.shape[2]))

@@ -164,6 +158,23 @@ def split_msk_and_msg(bins_split, cmul, x_msg, x_node, msk, n_bins, bin_size):
return x_msg_binned, x_features_binned, msk_f_binned


def reverse_lsh(bins_split, points_binned_enc):
shp = points_binned_enc.shape
batch_dim = shp[0]
n_points = shp[1] * shp[2]
n_features = shp[-1]

bins_split_flat = torch.reshape(bins_split, (batch_dim, n_points))
points_binned_enc_flat = torch.reshape(points_binned_enc, (batch_dim, n_points, n_features))

ret = torch.zeros(batch_dim, n_points, n_features, device=points_binned_enc.device)
for ibatch in range(batch_dim):
# torch._assert(torch.min(bins_split_flat[ibatch]) >= 0, "reverse_lsh n_points min")
# torch._assert(torch.max(bins_split_flat[ibatch]) < n_points, "reverse_lsh n_points max")
ret[ibatch][bins_split_flat[ibatch]] = points_binned_enc_flat[ibatch]
return ret
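
reverse_lsh undoes the binning by index-assignment: for each event, row i of the flattened binned tensor is written back to position bins_split_flat[i]. A small illustration of that scatter semantics (toy values, not the model's shapes):

import torch

perm = torch.tensor([2, 0, 1])            # bin assignment for one event
binned = torch.arange(6.0).reshape(3, 2)  # 3 points, 2 features, in bin order
restored = torch.zeros(3, 2)
restored[perm] = binned                   # restored[perm[i]] = binned[i]

print(restored)  # rows of `binned` placed back at positions 2, 0 and 1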


class MessageBuildingLayerLSH(nn.Module):
def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=NodePairGaussianKernel(), **kwargs):
self.initializer = kwargs.pop("initializer", "random_normal")
@@ -227,24 +238,6 @@ def forward(self, x_msg, x_node, msk, training=False):
return bins_split, x_features_binned, dm, msk_f_binned


@torch.compile
def reverse_lsh(bins_split, points_binned_enc):
shp = points_binned_enc.shape
batch_dim = shp[0]
n_points = shp[1] * shp[2]
n_features = shp[-1]

bins_split_flat = torch.reshape(bins_split, (batch_dim, n_points))
points_binned_enc_flat = torch.reshape(points_binned_enc, (batch_dim, n_points, n_features))

ret = torch.zeros(batch_dim, n_points, n_features, device=points_binned_enc.device)
for ibatch in range(batch_dim):
torch._assert(torch.min(bins_split_flat[ibatch]) >= 0, "reverse_lsh n_points min")
torch._assert(torch.max(bins_split_flat[ibatch]) < n_points, "reverse_lsh n_points max")
ret[ibatch][bins_split_flat[ibatch]] = points_binned_enc_flat[ibatch]
return ret


class CombinedGraphLayer(nn.Module):
def __init__(self, *args, **kwargs):
self.inout_dim = kwargs.pop("inout_dim")
36 changes: 9 additions & 27 deletions mlpf/pyg/mlpf.py
@@ -1,7 +1,5 @@
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.utils
from torch_geometric.nn.conv import GravNetConv

from .gnn_lsh import CombinedGraphLayer
@@ -53,12 +51,6 @@ def ffn(input_dim, output_dim, width, act, dropout):
)


@torch.compile
def unpad(data_padded, mask):
A = data_padded[mask]
return A


class MLPF(nn.Module):
def __init__(
self,
@@ -132,21 +124,14 @@ def __init__(
def forward(self, event):
# unfold the Batch object
input_ = event.X.float()
batch_idx = event.batch

embeddings_id, embeddings_reg = [], []
if self.num_convs != 0:
embedding = self.nn0(input_)

if self.conv_type != "gravnet":
_, num_nodes = torch.unique(batch_idx, return_counts=True)
max_num_nodes = torch.max(num_nodes).cpu()
max_num_nodes_padded = ((max_num_nodes // self.bin_size) + 1) * self.bin_size
embedding, mask = torch_geometric.utils.to_dense_batch(
embedding, batch_idx, max_num_nodes=max_num_nodes_padded
)

if self.conv_type == "gravnet":
embedding = self.nn0(input_)

batch_idx = event.batch
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
@@ -155,6 +140,8 @@ def forward(self, x_msg, x_node, msk, training=False):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
else:
mask = event.mask
embedding = self.nn0(input_)
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
out_padded = conv(conv_input, ~mask)
Expand All @@ -164,11 +151,6 @@ def forward(self, event):
out_padded = conv(conv_input, ~mask)
embeddings_reg.append(out_padded)

if self.conv_type != "gravnet":
embeddings_id = [unpad(emb, mask) for emb in embeddings_id]
embeddings_reg = [unpad(emb, mask) for emb in embeddings_reg]

# classification
embedding_id = torch.cat([input_] + embeddings_id, axis=-1)
preds_id = self.nn_id(embedding_id)
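
In the gravnet branch above, the model keeps the flat (unpadded) representation and passes event.batch to each convolution. A hedged sketch of how a GravNetConv layer consumes such input; the layer sizes are illustrative rather than the model's configuration, and torch_cluster must be installed for the k-NN step:

import torch
from torch_geometric.nn.conv import GravNetConv

conv = GravNetConv(in_channels=32, out_channels=32,
                   space_dimensions=4, propagate_dimensions=8, k=4)

x = torch.randn(20, 32)  # 20 elements across all events, already embedded
batch_idx = torch.cat([torch.zeros(12), torch.ones(8)]).long()  # event index per element

out = conv(x, batch_idx)
print(out.shape)  # torch.Size([20, 32])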

@@ -183,10 +165,10 @@

# predict the 4-momentum, add it to the (pt, eta, sin phi, cos phi, E) of the input PFelement
# the feature order is defined in fcc/postprocessing.py -> track_feature_order, cluster_feature_order
preds_pt = self.nn_pt(embedding_reg) + input_[:, 1:2]
preds_eta = self.nn_eta(embedding_reg) + input_[:, 2:3]
preds_phi = self.nn_phi(embedding_reg) + input_[:, 3:5]
preds_energy = self.nn_energy(embedding_reg) + input_[:, 5:6]
preds_pt = self.nn_pt(embedding_reg) + input_[..., 1:2]
preds_eta = self.nn_eta(embedding_reg) + input_[..., 2:3]
preds_phi = self.nn_phi(embedding_reg) + input_[..., 3:5]
preds_energy = self.nn_energy(embedding_reg) + input_[..., 5:6]
preds_momentum = torch.cat([preds_pt, preds_eta, preds_phi, preds_energy], axis=-1)
pred_charge = self.nn_charge(embedding_reg)
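
The switch from input_[:, 1:2] to input_[..., 1:2] is what lets the same prediction code handle both the flat 2D input of the gravnet path and the padded 3D input: the ellipsis slices the last (feature) axis regardless of how many leading dimensions there are. For example:

import torch

x2d = torch.randn(10, 17)     # [elements, features]
x3d = torch.randn(2, 10, 17)  # [events, padded elements, features]

print(x2d[..., 1:2].shape)  # torch.Size([10, 1])
print(x3d[..., 1:2].shape)  # torch.Size([2, 10, 1])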
