Better CometML logging + Ray Train vs DDP comparison (jpata#278)
* fix: update parameter files

* fix: better comet-ml logging

* update flatiron Ray Train submissions scripts

* update sbatch script

* log overridden config to comet-ml instead of original
erwulff authored Dec 1, 2023
1 parent a172b3a commit 817aca4
Showing 12 changed files with 133 additions and 100 deletions.
58 changes: 36 additions & 22 deletions mlpf/pyg/training.py
@@ -9,6 +9,7 @@
import shutil
from datetime import datetime
import tqdm
import yaml

import matplotlib.pyplot as plt
import numpy as np
@@ -162,7 +163,7 @@ def train_and_valid(
"""

train_or_valid = "train" if is_train else "valid"
_logger.info(f"Initiating a {train_or_valid} run on device rank={rank}", color="red")
_logger.info(f"Initiating epoch #{epoch} {train_or_valid} run on device rank={rank}", color="red")

# this one will keep accumulating `train_loss` and then return the average
epoch_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0}
@@ -176,7 +177,9 @@
if (world_size > 1) and (rank != 0):
iterator = enumerate(data_loader)
else:
iterator = tqdm.tqdm(enumerate(data_loader), total=len(data_loader), desc=f"{train_or_valid} loop on rank={rank}")
iterator = tqdm.tqdm(
enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch} {train_or_valid} loop on rank={rank}"
)

for itrain, batch in iterator:
batch = batch.to(rank, non_blocking=True)
@@ -207,7 +210,7 @@
for loss_ in epoch_loss:
epoch_loss[loss_] += loss[loss_].detach()

if comet_experiment:
if comet_experiment and is_train:
if itrain % comet_step_freq == 0:
# this loss is not normalized to batch size
comet_experiment.log_metrics(loss, prefix=f"{train_or_valid}", step=(epoch - 1) * len(data_loader) + itrain)
@@ -286,8 +289,6 @@ def train_mlpf(
start_epoch = checkpoint["extra_state"]["epoch"] + 1

for epoch in range(start_epoch, num_epochs + 1):
if (rank == 0) or (rank == "cpu"):
_logger.info(f"Initiating epoch # {epoch}", color="bold")
t0 = time.time()

# training step
@@ -307,6 +308,11 @@
rank, world_size, model, optimizer, valid_loader, False, comet_experiment, comet_step_freq, epoch
)

if comet_experiment:
comet_experiment.log_metrics(losses_t, prefix="epoch_train_loss", epoch=epoch)
comet_experiment.log_metrics(losses_v, prefix="epoch_valid_loss", epoch=epoch)
comet_experiment.log_epoch_end(epoch)

if (rank == 0) or (rank == "cpu"):
extra_state = {"epoch": epoch}
if losses_v["Total"] < best_val_loss:
@@ -381,7 +387,8 @@ def train_mlpf(
+ f"valid_loss={losses_v['Total']:.4f} "
+ f"stale={stale_epochs} "
+ f"time={round((t1-t0)/60, 2)}m "
+ f"eta={round(eta, 1)}m"
+ f"eta={round(eta, 1)}m",
color="bold",
)

for loss in losses_of_interest:
@@ -412,8 +419,6 @@ def train_mlpf(
if tensorboard_writer:
tensorboard_writer.flush()

if comet_experiment:
comet_experiment.log_epoch_end(epoch)
if world_size > 1:
dist.barrier()

@@ -492,14 +497,18 @@ def run(rank, world_size, config, args, outdir, logfile):
config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir
)
comet_experiment.set_name(f"rank_{rank}")
comet_experiment.log_parameter("run_id", outdir)
comet_experiment.log_parameter("run_id", Path(outdir).name)
comet_experiment.log_parameter("world_size", world_size)
comet_experiment.log_parameter("rank", rank)
comet_experiment.log_parameters(config, prefix="config:")
comet_experiment.set_model_graph(model)
comet_experiment.log_code("mlpf/pyg/training.py")
comet_experiment.log_code("mlpf/pyg_pipeline.py")
comet_experiment.log_code(args.config)
# save overridden config then log to comet
config_filename = "overridden_config.yaml"
with open((Path(outdir) / config_filename), "w") as file:
yaml.dump(config, file)
comet_experiment.log_code(str(Path(outdir) / config_filename))
else:
comet_experiment = None

@@ -634,18 +643,14 @@ def override_config(config, args):


def device_agnostic_run(config, args, world_size, outdir):
if args.train: # create a new outdir when training a model to never overwrite
if args.train:
logfile = f"{outdir}/train.log"
_configLogger("mlpf", filename=logfile)

os.system(f"cp {args.config} {outdir}/train-config.yaml")
else:
outdir = args.load
logfile = f"{outdir}/test.log"
_configLogger("mlpf", filename=logfile)

os.system(f"cp {args.config} {outdir}/test-config.yaml")

if config["gpus"]:
assert (
world_size <= torch.cuda.device_count()
@@ -687,6 +692,9 @@ def train_ray_trial(config, args, outdir=None):
world_rank = ray.train.get_context().get_world_rank()
world_size = ray.train.get_context().get_world_size()

# keep writing the logs
_configLogger("mlpf", filename=f"{outdir}/train.log")

model_kwargs = {
"input_dim": len(X_FEATURES[config["dataset"]]),
"num_classes": len(CLASS_LABELS[config["dataset"]]),
@@ -714,7 +722,7 @@
config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir
)
comet_experiment.set_name(f"world_rank_{world_rank}")
comet_experiment.log_parameter("run_id", outdir)
comet_experiment.log_parameter("run_id", Path(outdir).name)
comet_experiment.log_parameter("world_size", world_size)
comet_experiment.log_parameter("rank", rank)
comet_experiment.log_parameter("world_rank", world_rank)
@@ -723,7 +731,11 @@
comet_experiment.log_code(str(Path(outdir).parent.parent / "mlpf/pyg/training.py"))
comet_experiment.log_code(str(Path(outdir).parent.parent / "mlpf/pyg_pipeline.py"))
comet_experiment.log_code(str(Path(outdir).parent.parent / "mlpf/raytune/pt_search_space.py"))
comet_experiment.log_code(args.config)
# save overridden config then log to comet
config_filename = "overridden_config.yaml"
with open((Path(outdir) / config_filename), "w") as file:
yaml.dump(config, file)
comet_experiment.log_code(str(Path(outdir) / config_filename))
else:
comet_experiment = None

@@ -757,6 +769,8 @@ def run_ray_training(config, args, outdir):
if not args.local:
ray.init(address="auto")

_configLogger("mlpf", filename=f"{outdir}/train.log")

num_workers = args.gpus
scaling_config = ray.train.ScalingConfig(
num_workers=num_workers,
@@ -822,14 +836,14 @@ def run_hpo(config, args):

expdir = Path(config["raytune"]["local_dir"]) / name
expdir.mkdir(parents=True, exist_ok=True)
dirname = Path(config["raytune"]["local_dir"]) / name
shutil.copy(
"mlpf/raytune/search_space.py",
str(Path(config["raytune"]["local_dir"]) / name / "search_space.py"),
str(dirname / "search_space.py"),
) # Copy the search space definition file to the train dir for later reference
shutil.copy(
args.config,
str(Path(config["raytune"]["local_dir"]) / name / "config.yaml"),
) # Copy the config file to the train dir for later reference
# Save config for later reference. Note that saving happens after parameters are overwritten by cmd line args.
with open((dirname / "config.yaml"), "w") as file:
yaml.dump(config, file)

if not args.local:
ray.init(address="auto")
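The epoch-level Comet calls added above replace the `log_epoch_end` call that used to sit at the bottom of the epoch loop: the averaged train and valid losses are now logged once per epoch under the `epoch_train_loss`/`epoch_valid_loss` prefixes. A minimal sketch of that pattern (hypothetical helper name; assumes a `comet_ml.Experiment` and the loss dicts returned by `train_and_valid`):

```python
# Sketch of the per-epoch Comet logging added to train_mlpf above.
# The helper name is hypothetical; the calls mirror the diff.
def log_epoch_to_comet(comet_experiment, epoch, losses_t, losses_v):
    """Log averaged train/valid losses once per epoch, then mark the epoch as finished."""
    if comet_experiment is None:
        return
    # losses_t / losses_v are dicts like {"Total": ..., "Classification": ..., "Regression": ..., "Charge": ...}
    comet_experiment.log_metrics(losses_t, prefix="epoch_train_loss", epoch=epoch)
    comet_experiment.log_metrics(losses_v, prefix="epoch_valid_loss", epoch=epoch)
    comet_experiment.log_epoch_end(epoch)
```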
5 changes: 5 additions & 0 deletions mlpf/pyg_pipeline.py
@@ -81,6 +81,11 @@ def main():
prefix=(args.prefix or "") + Path(args.config).stem + "_",
experiments_dir=args.experiments_dir if args.experiments_dir else "experiments",
)
# Save config for later reference. Note that saving happens after parameters are overwritten by cmd line args.
config_filename = "train-config.yaml" if args.train else "test-config.yaml"
with open((Path(outdir) / config_filename), "w") as file:
yaml.dump(config, file)

if args.ray_train:
run_ray_training(config, args, outdir)
else:
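Both entry points now write the effective config, with command-line overrides already applied, to a YAML file in the run directory and log that file to Comet instead of the original `args.config`. A minimal sketch of the pattern (hypothetical helper name; the body mirrors the hunks in `training.py` and `pyg_pipeline.py`):

```python
# Sketch of the "save the overridden config, then log it to Comet" pattern from this commit.
from pathlib import Path

import yaml


def save_and_log_config(config, outdir, comet_experiment=None, filename="overridden_config.yaml"):
    """Persist the effective config (after CLI overrides) and attach it to the Comet experiment."""
    config_path = Path(outdir) / filename
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    if comet_experiment is not None:
        comet_experiment.log_code(str(config_path))
    return config_path
```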
2 changes: 1 addition & 1 deletion parameters/pyg-clic.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: clic
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 2
5 changes: 4 additions & 1 deletion parameters/pyg-cms-physical.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: cms
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 2
@@ -15,6 +15,9 @@ nvalid: 500
num_workers: 0
prefetch_factor:
checkpoint_freq:
comet_name: particleflow-pt
comet_offline: False
comet_step_freq: 100

model:
gnn_lsh:
5 changes: 4 additions & 1 deletion parameters/pyg-cms-small-highqcd.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: cms
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 2
@@ -15,6 +15,9 @@ nvalid: 500
num_workers: 0
prefetch_factor:
checkpoint_freq:
comet_name: particleflow-pt
comet_offline: False
comet_step_freq: 10

model:
gnn_lsh:
2 changes: 1 addition & 1 deletion parameters/pyg-cms-small.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: cms
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 10
2 changes: 1 addition & 1 deletion parameters/pyg-cms-test-qcdhighpt.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: cms
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 2
2 changes: 1 addition & 1 deletion parameters/pyg-cms.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: cms
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 2
2 changes: 1 addition & 1 deletion parameters/pyg-delphes.yaml
@@ -2,7 +2,7 @@ backend: pytorch

dataset: delphes
data_dir:
gpus: "0"
gpus: 1
gpu_batch_multiplier: 1
load:
num_epochs: 2
68 changes: 36 additions & 32 deletions scripts/flatiron/pt_raytrain_a100.slurm
@@ -2,7 +2,7 @@

# Walltime limit
#SBATCH -t 1:00:00
#SBATCH -N 1
#SBATCH -N 2
#SBATCH --exclusive
#SBATCH --tasks-per-node=1
#SBATCH -p gpu
@@ -35,42 +35,44 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
num_gpus=${SLURM_GPUS_PER_TASK} # gpus per compute node


################# DON NOT CHANGE THINGS HERE UNLESS YOU KNOW WHAT YOU ARE DOING ###############
redis_password=$(uuidgen)
export redis_password
echo "Redis password: ${redis_password}"
if [ "$SLURM_JOB_NUM_NODES" -gt 1 ]; then
################# DON NOT CHANGE THINGS HERE UNLESS YOU KNOW WHAT YOU ARE DOING ###############
redis_password=$(uuidgen)
export redis_password
echo "Redis password: ${redis_password}"

nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST) # Getting the node names
nodes_array=( $nodes )
nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST) # Getting the node names
nodes_array=( $nodes )

node_1=${nodes_array[0]}
ip=$(srun --nodes=1 --ntasks=1 -w $node_1 hostname --ip-address) # making redis-address
port=6379
ip_head=$ip:$port
export ip_head
echo "IP Head: $ip_head"
node_1=${nodes_array[0]}
ip=$(srun --nodes=1 --ntasks=1 -w $node_1 hostname --ip-address) # making redis-address
port=6379
ip_head=$ip:$port
export ip_head
echo "IP Head: $ip_head"

echo "STARTING HEAD at $node_1"
srun --nodes=1 --ntasks=1 -w $node_1 \
ray start --head --node-ip-address="$node_1" --port=$port \
--num-cpus $((SLURM_CPUS_PER_TASK)) --num-gpus $num_gpus --block &
echo "STARTING HEAD at $node_1"
srun --nodes=1 --ntasks=1 -w $node_1 \
ray start --head --node-ip-address="$node_1" --port=$port \
--num-cpus $((SLURM_CPUS_PER_TASK)) --num-gpus $num_gpus --block &

sleep 10
sleep 10

worker_num=$(($SLURM_JOB_NUM_NODES - 1)) #number of nodes other than the head node
for (( i=1; i<=$worker_num; i++ ))
do
node_i=${nodes_array[$i]}
echo "STARTING WORKER $i at $node_i"
srun --nodes=1 --ntasks=1 -w $node_i \
ray start --address "$node_1":"$port" \
--num-cpus $((SLURM_CPUS_PER_TASK)) --num-gpus $num_gpus --block &
sleep 5
done
worker_num=$(($SLURM_JOB_NUM_NODES - 1)) #number of nodes other than the head node
for (( i=1; i<=$worker_num; i++ ))
do
node_i=${nodes_array[$i]}
echo "STARTING WORKER $i at $node_i"
srun --nodes=1 --ntasks=1 -w $node_i \
ray start --address "$node_1":"$port" \
--num-cpus $((SLURM_CPUS_PER_TASK)) --num-gpus $num_gpus --block &
sleep 5
done

echo All Ray workers started.
##############################################################################################
# call your code below
echo All Ray workers started.
##############################################################################################
# call your code below
fi


echo 'Starting training.'
@@ -83,6 +85,8 @@ python3 -u mlpf/pyg_pipeline.py --train --ray-train \
--gpu-batch-multiplier 4 \
--num-workers 1 \
--prefetch-factor 2 \
--experiments-dir /mnt/ceph/users/ewulff/particleflow/experiments
--experiments-dir /mnt/ceph/users/ewulff/particleflow/experiments \
--local \
--comet

echo 'Training done.'
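The sbatch script above now starts the Ray head and workers only for multi-node jobs (`SLURM_JOB_NUM_NODES > 1`); on the Python side, `run_ray_training` attaches to that cluster and builds a `ScalingConfig`, as shown in the training.py diff. A minimal sketch of the driver pattern (the `TorchTrainer` wiring is not part of this diff and is assumed from the standard Ray Train API):

```python
# Sketch of the Ray Train driver pattern used by run_ray_training.
# ray.init(address="auto") and ScalingConfig appear in the diff above;
# the TorchTrainer wiring is an assumption based on the standard Ray Train API.
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def run_ray_training_sketch(train_loop, loop_config, num_workers, local=False):
    if not local:
        ray.init(address="auto")  # attach to the cluster started by the sbatch script

    scaling_config = ScalingConfig(
        num_workers=num_workers,          # in the diff: num_workers = args.gpus
        use_gpu=True,
        resources_per_worker={"GPU": 1},  # assumption: one GPU per Ray worker
    )
    trainer = TorchTrainer(
        train_loop,                       # e.g. train_ray_trial from the diff
        train_loop_config=loop_config,
        scaling_config=scaling_config,
    )
    return trainer.fit()
```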