Skip to content

Commit

Permalink
Add Ray Train training to GitHub actions CI/CD test (#314)
Browse files Browse the repository at this point in the history
* add Ray Train training to GitHub actions CI/CD test

* fix: merge issues

* fix: gpu resource allocation when not using gpus

* fix: path to datasets in Ray Train test

* fix: set use_cuda to False if gpus < 1

* fix: explicitly set rank="cpu" if use_cuda==False in Ray Train training

* update count parameters script

* save stats to file after every epoch

* chore: formatting

* fix: raytrain multi-gpu dataloader error
  • Loading branch information
erwulff authored Apr 25, 2024
1 parent 58a52b3 commit 53bfd54
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 35 deletions.
118 changes: 95 additions & 23 deletions mlpf/count_parameters.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,110 @@
import sys
import yaml
import argparse
import csv

import matplotlib.pyplot as plt

sys.path.append("../mlpf")

from pyg.mlpf import MLPF
from pyg.training import override_config
from pyg.utils import (
CLASS_LABELS,
X_FEATURES,
count_parameters,
ELEM_TYPES_NONZERO,
)

parser = argparse.ArgumentParser()
parser.add_argument("--config", "-c", type=str, default=None, help="yaml config")
parser.add_argument(
"--attention-type",
type=str,
default=None,
help="attention type for self-attention layer",
choices=["math", "efficient", "flash"],
)
args = parser.parse_args()

with open(sys.argv[1], "r") as stream: # load config (includes: which physics samples, model params)
with open(args.config, "r") as stream: # load config (includes: which physics samples, model params)
config = yaml.safe_load(stream)

model_kwargs = {
"input_dim": len(X_FEATURES[config["dataset"]]),
"num_classes": len(CLASS_LABELS[config["dataset"]]),
"pt_mode": config["model"]["pt_mode"],
"eta_mode": config["model"]["eta_mode"],
"sin_phi_mode": config["model"]["sin_phi_mode"],
"cos_phi_mode": config["model"]["cos_phi_mode"],
"energy_mode": config["model"]["energy_mode"],
"attention_type": config["model"]["attention"]["attention_type"],
**config["model"][config["conv_type"]],
}
model = MLPF(**model_kwargs)

trainable_params, nontrainable_params, table = count_parameters(model)

print(table)

print("Model conv type:", model.conv_type)
print("conv_type HPs", config["model"][config["conv_type"]])
print("Trainable parameters:", trainable_params)
print("Non-trainable parameters:", nontrainable_params)
print("Total parameters:", trainable_params + nontrainable_params)

nconvs_width_list = [
(1, 32),
(1, 64),
(1, 128),
(1, 256),
(2, 32),
(2, 64),
(2, 128),
(2, 256),
(4, 32),
(4, 64),
(4, 128),
(4, 256),
]
summary = [["num_convs", "width", "Trainable parameters", "Non-trainable parameters", "Total parameters"]]

for nconvs, width in nconvs_width_list:
args.num_convs = nconvs
args.width = width
args.embedding_dim = width
args.test_datasets = []

override_config(config, args)

model_kwargs = {
"input_dim": len(X_FEATURES[config["dataset"]]),
"num_classes": len(CLASS_LABELS[config["dataset"]]),
"input_encoding": config["model"]["input_encoding"],
"pt_mode": config["model"]["pt_mode"],
"eta_mode": config["model"]["eta_mode"],
"sin_phi_mode": config["model"]["sin_phi_mode"],
"cos_phi_mode": config["model"]["cos_phi_mode"],
"energy_mode": config["model"]["energy_mode"],
"elemtypes_nonzero": ELEM_TYPES_NONZERO[config["dataset"]],
"learned_representation_mode": config["model"]["learned_representation_mode"],
**config["model"][config["conv_type"]],
}
model = MLPF(**model_kwargs)

trainable_params, nontrainable_params, table = count_parameters(model)

summary.append([nconvs, width, trainable_params, nontrainable_params, trainable_params + nontrainable_params])

# print(table)

print("Model conv type:", model.conv_type)
print("conv_type HPs", config["model"][config["conv_type"]])
print("Trainable parameters:", trainable_params)
print("Non-trainable parameters:", nontrainable_params)
print("Total parameters:", trainable_params + nontrainable_params)

# File path
file_path = "count_summary.csv"

# Write all summary rows (header first) to the CSV file in a single call
with open(file_path, mode="w", newline="") as file:
writer = csv.writer(file)
writer.writerows(summary)

nconvs_width_array = [(summary[ii][0], summary[ii][1]) for ii in range(1, len(summary))]
total_array = [summary[ii][4] for ii in range(1, len(summary))]
print(total_array)
plt.figure()
plt.scatter(total_array, total_array)

for ii, label in enumerate(nconvs_width_array):
plt.annotate(
"{}, {}".format(label[0], label[1]),
(total_array[ii], total_array[ii]),
textcoords="offset points",
xytext=(5, 5),
ha="center",
)

plt.yscale("log")
plt.xscale("log")
plt.savefig("count_plot.png")
9 changes: 5 additions & 4 deletions mlpf/pyg/PFDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,12 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray):
drop_last=True,
)

if use_ray:
import ray
# This doesn't seem to be needed anymore. 2024.04.17
# if use_ray:
# import ray

# prepare loader for distributed training, adds distributed sampler
loader = ray.train.torch.prepare_data_loader(loader)
# # prepare loader for distributed training, adds distributed sampler
# loader = ray.train.torch.prepare_data_loader(loader)

loaders[split].append(loader)

Expand Down
28 changes: 23 additions & 5 deletions mlpf/pyg/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tqdm
import yaml
import csv
import json

import numpy as np

Expand Down Expand Up @@ -597,6 +598,14 @@ def train_mlpf(
with open(f"{outdir}/mlpf_losses.pkl", "wb") as f:
pkl.dump(losses, f)

# save a separate json file with stats for each epoch; this is robust to crashed-then-resumed trainings
history_path = Path(outdir) / "history"
history_path.mkdir(parents=True, exist_ok=True)
with open("{}/epoch_{}.json".format(str(history_path), epoch), "w") as fi:
stats = {"train": losses_t, "valid": losses_v}
stats["epoch_time"] = t1 - t0
json.dump(stats, fi)

if tensorboard_writer_train:
tensorboard_writer_train.flush()
if tensorboard_writer_valid:
Expand Down Expand Up @@ -903,15 +912,23 @@ def train_ray_trial(config, args, outdir=None):
if outdir is None:
outdir = ray.train.get_context().get_trial_dir()

use_cuda = True
use_cuda = args.gpus > 0

rank = ray.train.get_context().get_local_rank()
rank = ray.train.get_context().get_local_rank() if use_cuda else "cpu"
world_rank = ray.train.get_context().get_world_rank()
world_size = ray.train.get_context().get_world_size()

model_kwargs = {
"input_dim": len(X_FEATURES[config["dataset"]]),
"num_classes": len(CLASS_LABELS[config["dataset"]]),
"input_encoding": config["model"]["input_encoding"],
"pt_mode": config["model"]["pt_mode"],
"eta_mode": config["model"]["eta_mode"],
"sin_phi_mode": config["model"]["sin_phi_mode"],
"cos_phi_mode": config["model"]["cos_phi_mode"],
"energy_mode": config["model"]["energy_mode"],
"elemtypes_nonzero": ELEM_TYPES_NONZERO[config["dataset"]],
"learned_representation_mode": config["model"]["learned_representation_mode"],
**config["model"][config["conv_type"]],
}
model = MLPF(**model_kwargs)
Expand Down Expand Up @@ -1026,11 +1043,12 @@ def run_ray_training(config, args, outdir):

_configLogger("mlpf", filename=f"{outdir}/train.log")

num_workers = args.gpus
use_gpu = args.gpus > 0
num_workers = args.gpus if use_gpu else 1
scaling_config = ray.train.ScalingConfig(
num_workers=num_workers,
use_gpu=True,
resources_per_worker={"CPU": max(1, args.ray_cpus // num_workers - 1), "GPU": 1}, # -1 to avoid blocking
use_gpu=use_gpu,
resources_per_worker={"CPU": max(1, args.ray_cpus // num_workers - 1), "GPU": int(use_gpu)}, # -1 to avoid blocking
)
storage_path = Path(args.experiments_dir if args.experiments_dir else "experiments").resolve()
run_config = ray.train.RunConfig(
Expand Down
2 changes: 1 addition & 1 deletion parameters/pytorch/pyg-clic-hits.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ model:
embedding_dim: 512
width: 512
num_convs: 2
dropout: 0.0
dropout_ff: 0.0
activation: "elu"
# gnn-lsh specific parameters
bin_size: 256
Expand Down
14 changes: 12 additions & 2 deletions scripts/local_test_pyg.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/bin/bash
set -e
export TFDS_DATA_DIR=`pwd`/tensorflow_datasets
export PWD=`pwd`

rm -Rf local_test_data/TTbar_14TeV_TuneCUETP8M1_cfi

Expand Down Expand Up @@ -28,7 +29,16 @@ mkdir -p experiments
tfds build mlpf/heptfds/cms_pf/ttbar --manual_dir ./local_test_data

#test transformer with onnx export
python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type attention --export-onnx --pipeline --dtype float32 --attention-type math --num-convs 1
python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ \
--prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type attention \
--export-onnx --pipeline --dtype float32 --attention-type math --num-convs 1

#test GNN-LSH with onnx export
python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type gnn_lsh --export-onnx --pipeline --dtype float32 --num-convs 1
python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ./tensorflow_datasets/ \
--prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --test --make-plots --conv-type gnn_lsh \
--export-onnx --pipeline --dtype float32 --num-convs 1

#test Ray Train training
python mlpf/pyg_pipeline.py --config parameters/pytorch/pyg-cms.yaml --dataset cms --data-dir ${PWD}/tensorflow_datasets/ \
--prefix MLPF_test_ --num-epochs 2 --nvalid 1 --gpus 0 --train --ray-train --ray-cpus 2 --local --conv-type attention \
--pipeline --dtype float32 --attention-type math --num-convs 1 --experiments-dir ${PWD}/experiments

0 comments on commit 53bfd54

Please sign in to comment.