From deef771547a0a5df5f1f1e0e92d9e7315ba81bb6 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Tue, 31 Oct 2023 09:50:06 +0200 Subject: [PATCH] restore onnx export in pytorch (#265) * restore onnx export * configurable stable sort * use DDP correctly --- mlpf/pyg/gnn_lsh.py | 40 ++++++++++++++++------------------ mlpf/pyg/inference.py | 7 +++--- mlpf/pyg/mlpf.py | 29 ++++++++++++------------ mlpf/pyg/training.py | 8 ++++++- mlpf/pyg_pipeline.py | 18 +++++++-------- scripts/local_test_pyg.sh | 2 ++ scripts/tallinn/rtx/pytorch.sh | 7 +++--- tests/test_torch_and_tf.py | 2 ++ 8 files changed, 60 insertions(+), 53 deletions(-) diff --git a/mlpf/pyg/gnn_lsh.py b/mlpf/pyg/gnn_lsh.py index f3a375bda..03cf15498 100644 --- a/mlpf/pyg/gnn_lsh.py +++ b/mlpf/pyg/gnn_lsh.py @@ -33,7 +33,7 @@ def point_wise_feed_forward_network( return nn.Sequential(*layers) -def split_indices_to_bins_batch(cmul, nbins, bin_size, msk): +def split_indices_to_bins_batch(cmul, nbins, bin_size, msk, stable_sort=False): a = torch.argmax(cmul, axis=-1) # This gives a CUDA error for some reason @@ -45,7 +45,12 @@ def split_indices_to_bins_batch(cmul, nbins, bin_size, msk): b[~msk] = nbins - 1 bin_idx = a + b - bins_split = torch.reshape(torch.argsort(bin_idx, stable=True), (cmul.shape[0], nbins, bin_size)) + if stable_sort: + bins_split = torch.argsort(bin_idx, stable=True) + else: + # for ONNX export to work, stable must not be provided at all as an argument + bins_split = torch.argsort(bin_idx) + bins_split = bins_split.reshape((cmul.shape[0], nbins, bin_size)) return bins_split @@ -184,6 +189,7 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node self.max_num_bins = max_num_bins self.bin_size = bin_size self.kernel = kernel + self.stable_sort = False # generate the LSH codebook for random rotations (num_features, max_num_bins/2) self.codebook_random_rotations = nn.Parameter( @@ -193,7 +199,7 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node def forward(self, x_msg, x_node, msk, training=False): shp = x_msg.shape - n_points = shp[1] + n_points = torch.tensor(shp[1]) if n_points % self.bin_size != 0: raise Exception("Number of elements per event must be exactly divisible by the bin size") @@ -202,25 +208,17 @@ def forward(self, x_msg, x_node, msk, training=False): # n_points must be divisible by bin_size exactly due to the use of reshape n_bins = torch.floor_divide(n_points, self.bin_size) - msk_f = torch.unsqueeze(msk, -1) - if n_bins > 1: - mul = torch.linalg.matmul( - x_msg, - self.codebook_random_rotations[:, : torch.maximum(torch.tensor(1), n_bins // 2)], - ) - cmul = torch.concatenate([mul, -mul], axis=-1) - bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk) + mul = torch.linalg.matmul( + x_msg, + self.codebook_random_rotations[:, : torch.maximum(torch.tensor(1), n_bins // 2)], + ) + cmul = torch.concatenate([mul, -mul], axis=-1) + bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk, self.stable_sort) - # replaced tf.gather with torch.vmap, indexing and reshape - x_msg_binned, x_features_binned, msk_f_binned = split_msk_and_msg( - bins_split, cmul, x_msg, x_node, msk, n_bins, self.bin_size - ) - else: - x_msg_binned = torch.unsqueeze(x_msg, axis=1) - x_features_binned = torch.unsqueeze(x_node, axis=1) - msk_f_binned = torch.unsqueeze(msk_f, axis=1) - shp = x_msg_binned.shape - bins_split = torch.zeros([shp[0], shp[1], shp[2]], dtype=torch.int32) + # replaced tf.gather with torch.vmap, indexing and reshape + x_msg_binned, x_features_binned, msk_f_binned = split_msk_and_msg( + bins_split, cmul, x_msg, x_node, msk, n_bins, self.bin_size + ) # Run the node-to-node kernel (distance computation / graph building / attention) dm = self.kernel(x_msg_binned, msk_f_binned, training=training) diff --git a/mlpf/pyg/inference.py b/mlpf/pyg/inference.py index 8621afdf8..3fcec5a23 100644 --- a/mlpf/pyg/inference.py +++ b/mlpf/pyg/inference.py @@ -42,11 +42,12 @@ def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, je for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)): if conv_type != "gravnet": X_pad, mask = torch_geometric.utils.to_dense_batch(batch.X, batch.batch) - batch_pad = Batch(X=X_pad, mask=mask) - ypred = model(batch_pad.to(rank)) + batch_pad = Batch(X=X_pad, mask=mask).to(rank) + ypred = model(batch_pad.X, batch_pad.mask) ypred = ypred[0][mask], ypred[1][mask], ypred[2][mask] else: - ypred = model(batch.to(rank)) + _batch = batch.to(rank) + ypred = model(_batch.X, _batch.batch) ygen = unpack_target(batch.ygen) ycand = unpack_target(batch.ycand) diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index 1d5de74f5..19fa63143 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -122,17 +122,17 @@ def __init__( # elementwise DNN for node charge regression, classes (-1, 0, 1) self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout) - def forward(self, event): - # unfold the Batch object - input_ = event.X.float() + def forward_batch(self, batched_events): + batch_or_mask = batched_events.batch if self.conv_type == "gravnet" else batched_events.mask + return self(batched_events.X, batch_or_mask) + + def forward(self, X_features, batch_or_mask): embeddings_id, embeddings_reg = [], [] if self.num_convs != 0: - + embedding = self.nn0(X_features) if self.conv_type == "gravnet": - embedding = self.nn0(input_) - - batch_idx = event.batch + batch_idx = batch_or_mask # perform a series of graph convolutions for num, conv in enumerate(self.conv_id): conv_input = embedding if num == 0 else embeddings_id[-1] @@ -141,8 +141,7 @@ def forward(self, event): conv_input = embedding if num == 0 else embeddings_reg[-1] embeddings_reg.append(conv(conv_input, batch_idx)) else: - mask = event.mask - embedding = self.nn0(input_) + mask = batch_or_mask for num, conv in enumerate(self.conv_id): conv_input = embedding if num == 0 else embeddings_id[-1] out_padded = conv(conv_input, ~mask) @@ -152,11 +151,11 @@ def forward(self, event): out_padded = conv(conv_input, ~mask) embeddings_reg.append(out_padded) - embedding_id = torch.cat([input_] + embeddings_id, axis=-1) + embedding_id = torch.cat([X_features] + embeddings_id, axis=-1) preds_id = self.nn_id(embedding_id) # regression - embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1) + embedding_reg = torch.cat([X_features] + embeddings_reg + [preds_id], axis=-1) # do some sanity checks on the PFElement input data # assert torch.all(torch.abs(input_[:, 3]) <= 1.0) # sin_phi @@ -166,10 +165,10 @@ def forward(self, event): # predict the 4-momentum, add it to the (pt, eta, sin phi, cos phi, E) of the input PFelement # the feature order is defined in fcc/postprocessing.py -> track_feature_order, cluster_feature_order - preds_pt = self.nn_pt(embedding_reg) + input_[..., 1:2] - preds_eta = self.nn_eta(embedding_reg) + input_[..., 2:3] - preds_phi = self.nn_phi(embedding_reg) + input_[..., 3:5] - preds_energy = self.nn_energy(embedding_reg) + input_[..., 5:6] + preds_pt = self.nn_pt(embedding_reg) + X_features[..., 1:2] + preds_eta = self.nn_eta(embedding_reg) + X_features[..., 2:3] + preds_phi = self.nn_phi(embedding_reg) + X_features[..., 3:5] + preds_energy = self.nn_energy(embedding_reg) + X_features[..., 5:6] preds_momentum = torch.cat([preds_pt, preds_eta, preds_phi, preds_energy], axis=-1) pred_charge = self.nn_charge(embedding_reg) diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index 7bde53574..ac046011f 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -155,7 +155,13 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train): ygen = unpack_target(batch.ygen) - ypred = model(batch) + if world_size > 1: + conv_type = model.module.conv_type + else: + conv_type = model.conv_type + + batchidx_or_mask = batch.batch if conv_type == "gravnet" else batch.mask + ypred = model(batch.X, batchidx_or_mask) ypred = unpack_predictions(ypred) if is_train: diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index 057b5c8cb..98788a322 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -238,21 +238,21 @@ def run(rank, world_size, config, args, outdir, logfile): if args.export_onnx: try: - dummy_features = torch.randn(256, model_kwargs["input_dim"], rank=rank) - dummy_batch = torch.zeros(256, dtype=torch.int64, rank=rank) + dummy_features = torch.randn(1, 640, model_kwargs["input_dim"], device=rank) + dummy_mask = torch.zeros(1, 640, dtype=torch.bool, device=rank) torch.onnx.export( model, - (dummy_features, dummy_batch), + (dummy_features, dummy_mask), "test.onnx", verbose=True, - input_names=["features", "batch"], + input_names=["features", "mask"], output_names=["id", "momentum", "charge"], dynamic_axes={ - "features": {0: "num_elements"}, - "batch": [0], - "id": [0], - "momentum": [0], - "charge": [0], + "features": {0: "num_batch", 1: "num_elements"}, + "mask": [0, 1], + "id": [0, 1], + "momentum": [0, 1], + "charge": [0, 1], }, ) except Exception as e: diff --git a/scripts/local_test_pyg.sh b/scripts/local_test_pyg.sh index 40d65ba04..6fe121651 100755 --- a/scripts/local_test_pyg.sh +++ b/scripts/local_test_pyg.sh @@ -28,3 +28,5 @@ mkdir -p experiments tfds build mlpf/heptfds/cms_pf/ttbar --manual_dir ./local_test_data python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots + +python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots --conv-type gnn_lsh --export-onnx diff --git a/scripts/tallinn/rtx/pytorch.sh b/scripts/tallinn/rtx/pytorch.sh index d8e9027dd..d0abdabce 100755 --- a/scripts/tallinn/rtx/pytorch.sh +++ b/scripts/tallinn/rtx/pytorch.sh @@ -1,15 +1,14 @@ #!/bin/bash #SBATCH --partition gpu -#SBATCH --gres gpu:rtx:2 +#SBATCH --gres gpu:rtx:8 #SBATCH --mem-per-gpu 40G #SBATCH -o logs/slurm-%x-%j-%N.out IMG=/home/software/singularity/pytorch.simg -cd ~/particleflow #TF training singularity exec -B /scratch/persistent --nv \ --env PYTHONPATH=hep_tfds \ - $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1 \ + $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1,2,3,4,5,6,7 \ --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pyg-cms-small.yaml \ - --train --conv-type gnn_lsh --num-epochs 10 --ntrain 500 --ntest 500 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10 + --train --test --make-plots --export-onnx --conv-type gnn_lsh --num-epochs 10 --ntrain 1000 --ntest 1000 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10 diff --git a/tests/test_torch_and_tf.py b/tests/test_torch_and_tf.py index cc174587f..8135f88df 100644 --- a/tests/test_torch_and_tf.py +++ b/tests/test_torch_and_tf.py @@ -48,6 +48,8 @@ def test_MessageBuildingLayerLSH(self): from mlpf.pyg.gnn_lsh import MessageBuildingLayerLSH as MessageBuildingLayerLSHTorch nn2 = MessageBuildingLayerLSHTorch(distance_dim=128, bin_size=64) + # for testing, use the slower TF-like stable sort + nn2.stable_sort = True x_dist = np.random.normal(size=(2, 256, 128)).astype(np.float32) x_node = np.random.normal(size=(2, 256, 32)).astype(np.float32)