restore onnx export in pytorch (jpata#265)
* restore onnx export

* configurable stable sort

* use DDP correctly
jpata authored Oct 31, 2023
1 parent 4b00985 commit deef771
Showing 8 changed files with 60 additions and 53 deletions.
40 changes: 19 additions & 21 deletions mlpf/pyg/gnn_lsh.py
@@ -33,7 +33,7 @@ def point_wise_feed_forward_network(
return nn.Sequential(*layers)


def split_indices_to_bins_batch(cmul, nbins, bin_size, msk):
def split_indices_to_bins_batch(cmul, nbins, bin_size, msk, stable_sort=False):
a = torch.argmax(cmul, axis=-1)

# This gives a CUDA error for some reason
@@ -45,7 +45,12 @@ def split_indices_to_bins_batch(cmul, nbins, bin_size, msk):
b[~msk] = nbins - 1

bin_idx = a + b
bins_split = torch.reshape(torch.argsort(bin_idx, stable=True), (cmul.shape[0], nbins, bin_size))
if stable_sort:
bins_split = torch.argsort(bin_idx, stable=True)
else:
# for ONNX export to work, stable must not be provided at all as an argument
bins_split = torch.argsort(bin_idx)
bins_split = bins_split.reshape((cmul.shape[0], nbins, bin_size))
return bins_split
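
A note on the stable flag: as the comment in the hunk says, torch.argsort only traces through torch.onnx.export when the stable argument is omitted entirely, so the default (exportable) path gives up deterministic ordering among equal bin indices. A minimal sketch of the trade-off, illustrative and not part of the commit:

    import torch

    bin_idx = torch.randint(0, 4, (2, 256))
    # TF-like deterministic tie-breaking, but incompatible with ONNX export:
    bins_stable = torch.argsort(bin_idx, stable=True)
    # ONNX-exportable; ordering among equal bin indices is unspecified:
    bins_fast = torch.argsort(bin_idx)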


@@ -184,6 +189,7 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node
self.max_num_bins = max_num_bins
self.bin_size = bin_size
self.kernel = kernel
self.stable_sort = False

# generate the LSH codebook for random rotations (num_features, max_num_bins/2)
self.codebook_random_rotations = nn.Parameter(
@@ -193,7 +199,7 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node

def forward(self, x_msg, x_node, msk, training=False):
shp = x_msg.shape
n_points = shp[1]
n_points = torch.tensor(shp[1])

if n_points % self.bin_size != 0:
raise Exception("Number of elements per event must be exactly divisible by the bin size")
@@ -202,25 +208,17 @@ def forward(self, x_msg, x_node, msk, training=False):
# n_points must be divisible by bin_size exactly due to the use of reshape
n_bins = torch.floor_divide(n_points, self.bin_size)

msk_f = torch.unsqueeze(msk, -1)
if n_bins > 1:
mul = torch.linalg.matmul(
x_msg,
self.codebook_random_rotations[:, : torch.maximum(torch.tensor(1), n_bins // 2)],
)
cmul = torch.concatenate([mul, -mul], axis=-1)
bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk)
mul = torch.linalg.matmul(
x_msg,
self.codebook_random_rotations[:, : torch.maximum(torch.tensor(1), n_bins // 2)],
)
cmul = torch.concatenate([mul, -mul], axis=-1)
bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk, self.stable_sort)

# replaced tf.gather with torch.vmap, indexing and reshape
x_msg_binned, x_features_binned, msk_f_binned = split_msk_and_msg(
bins_split, cmul, x_msg, x_node, msk, n_bins, self.bin_size
)
else:
x_msg_binned = torch.unsqueeze(x_msg, axis=1)
x_features_binned = torch.unsqueeze(x_node, axis=1)
msk_f_binned = torch.unsqueeze(msk_f, axis=1)
shp = x_msg_binned.shape
bins_split = torch.zeros([shp[0], shp[1], shp[2]], dtype=torch.int32)
# replaced tf.gather with torch.vmap, indexing and reshape
x_msg_binned, x_features_binned, msk_f_binned = split_msk_and_msg(
bins_split, cmul, x_msg, x_node, msk, n_bins, self.bin_size
)

# Run the node-to-node kernel (distance computation / graph building / attention)
dm = self.kernel(x_msg_binned, msk_f_binned, training=training)
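Two further changes in this hunk serve the exporter: n_points is wrapped in torch.tensor so torch.floor_divide stays a tensor op under tracing, and the data-dependent if n_bins > 1: branch is removed, leaving one static graph. A hedged sketch of the binning shapes, with illustrative sizes that are not from the commit:

    import torch

    batch, n_points, distance_dim, bin_size = 2, 256, 128, 64
    n_bins = torch.floor_divide(torch.tensor(n_points), bin_size)  # tensor(4)

    x_msg = torch.randn(batch, n_points, distance_dim)
    codebook = torch.randn(distance_dim, 100)  # (num_features, max_num_bins // 2)

    # project onto n_bins // 2 random rotations, then mirror to get n_bins codes
    mul = torch.matmul(x_msg, codebook[:, : torch.maximum(torch.tensor(1), n_bins // 2)])
    cmul = torch.concatenate([mul, -mul], axis=-1)  # (batch, n_points, n_bins)
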
7 changes: 4 additions & 3 deletions mlpf/pyg/inference.py
@@ -42,11 +42,12 @@ def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, je
for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)):
if conv_type != "gravnet":
X_pad, mask = torch_geometric.utils.to_dense_batch(batch.X, batch.batch)
batch_pad = Batch(X=X_pad, mask=mask)
ypred = model(batch_pad.to(rank))
batch_pad = Batch(X=X_pad, mask=mask).to(rank)
ypred = model(batch_pad.X, batch_pad.mask)
ypred = ypred[0][mask], ypred[1][mask], ypred[2][mask]
else:
ypred = model(batch.to(rank))
_batch = batch.to(rank)
ypred = model(_batch.X, _batch.batch)

ygen = unpack_target(batch.ygen)
ycand = unpack_target(batch.ycand)
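The call sites above now pass explicit tensors instead of a Batch object, matching the positional signature that torch.onnx.export traces. A hedged sketch of the padded (non-gravnet) path, following the diff, with batch, model, and rank as in run_predictions:

    import torch_geometric
    from torch_geometric.data import Batch

    # dense-pad the ragged PyG batch, then call the model with (features, mask)
    X_pad, mask = torch_geometric.utils.to_dense_batch(batch.X, batch.batch)
    batch_pad = Batch(X=X_pad, mask=mask).to(rank)
    ypred = model(batch_pad.X, batch_pad.mask)
    # predictions come back padded; the mask recovers the real per-node rows
    ids, momentum, charge = ypred[0][mask], ypred[1][mask], ypred[2][mask]
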
29 changes: 14 additions & 15 deletions mlpf/pyg/mlpf.py
@@ -122,17 +122,17 @@ def __init__(
# elementwise DNN for node charge regression, classes (-1, 0, 1)
self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout)

def forward(self, event):
# unfold the Batch object
input_ = event.X.float()
def forward_batch(self, batched_events):
batch_or_mask = batched_events.batch if self.conv_type == "gravnet" else batched_events.mask
return self(batched_events.X, batch_or_mask)

def forward(self, X_features, batch_or_mask):

embeddings_id, embeddings_reg = [], []
if self.num_convs != 0:

embedding = self.nn0(X_features)
if self.conv_type == "gravnet":
embedding = self.nn0(input_)

batch_idx = event.batch
batch_idx = batch_or_mask
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
@@ -141,8 +141,7 @@ def forward(self, event):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
else:
mask = event.mask
embedding = self.nn0(input_)
mask = batch_or_mask
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
out_padded = conv(conv_input, ~mask)
@@ -152,11 +151,11 @@ def forward(self, event):
out_padded = conv(conv_input, ~mask)
embeddings_reg.append(out_padded)

embedding_id = torch.cat([input_] + embeddings_id, axis=-1)
embedding_id = torch.cat([X_features] + embeddings_id, axis=-1)
preds_id = self.nn_id(embedding_id)

# regression
embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1)
embedding_reg = torch.cat([X_features] + embeddings_reg + [preds_id], axis=-1)

# do some sanity checks on the PFElement input data
# assert torch.all(torch.abs(input_[:, 3]) <= 1.0) # sin_phi
@@ -166,10 +165,10 @@

# predict the 4-momentum, add it to the (pt, eta, sin phi, cos phi, E) of the input PFelement
# the feature order is defined in fcc/postprocessing.py -> track_feature_order, cluster_feature_order
preds_pt = self.nn_pt(embedding_reg) + input_[..., 1:2]
preds_eta = self.nn_eta(embedding_reg) + input_[..., 2:3]
preds_phi = self.nn_phi(embedding_reg) + input_[..., 3:5]
preds_energy = self.nn_energy(embedding_reg) + input_[..., 5:6]
preds_pt = self.nn_pt(embedding_reg) + X_features[..., 1:2]
preds_eta = self.nn_eta(embedding_reg) + X_features[..., 2:3]
preds_phi = self.nn_phi(embedding_reg) + X_features[..., 3:5]
preds_energy = self.nn_energy(embedding_reg) + X_features[..., 5:6]
preds_momentum = torch.cat([preds_pt, preds_eta, preds_phi, preds_energy], axis=-1)
pred_charge = self.nn_charge(embedding_reg)

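The refactor above splits the old Batch-consuming forward in two: forward now takes plain tensors, which torch.onnx.export can trace, while forward_batch keeps the Batch-object convenience for training and inference code. A hedged usage sketch; the three-way unpacking follows the id/momentum/charge output names used in the export call in mlpf/pyg_pipeline.py:

    # tensor-only signature, traceable by torch.onnx.export
    preds_id, preds_momentum, pred_charge = model(X_features, mask)

    # Batch-object wrapper; selects .batch or .mask according to conv_type
    preds_id, preds_momentum, pred_charge = model.forward_batch(batched_events)
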
8 changes: 7 additions & 1 deletion mlpf/pyg/training.py
@@ -155,7 +155,13 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):

ygen = unpack_target(batch.ygen)

ypred = model(batch)
if world_size > 1:
conv_type = model.module.conv_type
else:
conv_type = model.conv_type

batchidx_or_mask = batch.batch if conv_type == "gravnet" else batch.mask
ypred = model(batch.X, batchidx_or_mask)
ypred = unpack_predictions(ypred)

if is_train:
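The world_size check is needed because DistributedDataParallel wraps the model, so user-defined attributes such as conv_type live on model.module in multi-GPU runs. An equivalent hedged sketch using an isinstance check instead of world_size:

    from torch.nn.parallel import DistributedDataParallel as DDP

    # unwrap DDP if present; attributes like conv_type live on the inner module
    raw_model = model.module if isinstance(model, DDP) else model
    batchidx_or_mask = batch.batch if raw_model.conv_type == "gravnet" else batch.mask
    ypred = model(batch.X, batchidx_or_mask)
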
18 changes: 9 additions & 9 deletions mlpf/pyg_pipeline.py
@@ -238,21 +238,21 @@ def run(rank, world_size, config, args, outdir, logfile):

if args.export_onnx:
try:
dummy_features = torch.randn(256, model_kwargs["input_dim"], rank=rank)
dummy_batch = torch.zeros(256, dtype=torch.int64, rank=rank)
dummy_features = torch.randn(1, 640, model_kwargs["input_dim"], device=rank)
dummy_mask = torch.zeros(1, 640, dtype=torch.bool, device=rank)
torch.onnx.export(
model,
(dummy_features, dummy_batch),
(dummy_features, dummy_mask),
"test.onnx",
verbose=True,
input_names=["features", "batch"],
input_names=["features", "mask"],
output_names=["id", "momentum", "charge"],
dynamic_axes={
"features": {0: "num_elements"},
"batch": [0],
"id": [0],
"momentum": [0],
"charge": [0],
"features": {0: "num_batch", 1: "num_elements"},
"mask": [0, 1],
"id": [0, 1],
"momentum": [0, 1],
"charge": [0, 1],
},
)
except Exception as e:
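One way to sanity-check the exported graph is to run it under onnxruntime with inputs shaped like the dummies above. A hedged sketch, assuming onnxruntime is installed; the feature dimension 17 is illustrative, standing in for model_kwargs["input_dim"]:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("test.onnx", providers=["CPUExecutionProvider"])
    features = np.random.randn(1, 640, 17).astype(np.float32)  # hypothetical input_dim
    mask = np.ones((1, 640), dtype=bool)
    ids, momentum, charge = sess.run(
        ["id", "momentum", "charge"], {"features": features, "mask": mask}
    )
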
2 changes: 2 additions & 0 deletions scripts/local_test_pyg.sh
@@ -28,3 +28,5 @@ mkdir -p experiments
tfds build mlpf/heptfds/cms_pf/ttbar --manual_dir ./local_test_data

python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots

python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots --conv-type gnn_lsh --export-onnx
7 changes: 3 additions & 4 deletions scripts/tallinn/rtx/pytorch.sh
@@ -1,15 +1,14 @@
#!/bin/bash
#SBATCH --partition gpu
#SBATCH --gres gpu:rtx:2
#SBATCH --gres gpu:rtx:8
#SBATCH --mem-per-gpu 40G
#SBATCH -o logs/slurm-%x-%j-%N.out

IMG=/home/software/singularity/pytorch.simg
cd ~/particleflow

#TF training
singularity exec -B /scratch/persistent --nv \
--env PYTHONPATH=hep_tfds \
$IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1 \
$IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1,2,3,4,5,6,7 \
--data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pyg-cms-small.yaml \
--train --conv-type gnn_lsh --num-epochs 10 --ntrain 500 --ntest 500 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10
--train --test --make-plots --export-onnx --conv-type gnn_lsh --num-epochs 10 --ntrain 1000 --ntest 1000 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10
2 changes: 2 additions & 0 deletions tests/test_torch_and_tf.py
@@ -48,6 +48,8 @@ def test_MessageBuildingLayerLSH(self):
from mlpf.pyg.gnn_lsh import MessageBuildingLayerLSH as MessageBuildingLayerLSHTorch

nn2 = MessageBuildingLayerLSHTorch(distance_dim=128, bin_size=64)
# for testing, use the slower TF-like stable sort
nn2.stable_sort = True

x_dist = np.random.normal(size=(2, 256, 128)).astype(np.float32)
x_node = np.random.normal(size=(2, 256, 32)).astype(np.float32)
