diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index de307b937..be90228c9 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -192,6 +192,9 @@ def run(rank, world_size, config, args, outdir, logfile): checkpoint_freq=config["checkpoint_freq"], ) + checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank)) + model, optimizer = load_checkpoint(checkpoint, model, optimizer) + if args.test: if config["load"] is None: # if we don't load, we must have a newly trained model @@ -200,9 +203,6 @@ def run(rank, world_size, config, args, outdir, logfile): else: outdir = config["load"] - checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank)) - model, optimizer = load_checkpoint(checkpoint, model, optimizer) - for type_ in config["test_dataset"][config["dataset"]]: # will be "physical", "gun" batch_size = config["test_dataset"][config["dataset"]][type_]["batch_size"] * config["gpu_batch_multiplier"] for sample in config["test_dataset"][config["dataset"]][type_]["samples"]: diff --git a/parameters/pyg-cms-physical.yaml b/parameters/pyg-cms-physical.yaml new file mode 100644 index 000000000..02ab7fca0 --- /dev/null +++ b/parameters/pyg-cms-physical.yaml @@ -0,0 +1,74 @@ +backend: pytorch + +dataset: cms +data_dir: +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +nvalid: 500 +num_workers: 0 +prefetch_factor: +checkpoint_freq: + +model: + gnn_lsh: + conv_type: gnn_lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 + + gravnet: + conv_type: gravnet + embedding_dim: 512 + width: 512 + num_convs: 3 + k: 16 + propagate_dimensions: 32 + space_dimensions: 4 + dropout: 0.0 + + attention: + conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + +train_dataset: + cms: + physical: + batch_size: 1 + samples: + cms_pf_ttbar: + version: 1.6.0 + cms_pf_qcd: + version: 1.6.0 + cms_pf_ztt: + version: 1.6.0 + cms_pf_qcd_high_pt: + version: 1.6.0 + cms_pf_sms_t1tttt: + version: 1.6.0 + +valid_dataset: + cms: + physical: + batch_size: 1 + samples: + cms_pf_qcd_high_pt: + version: 1.6.0 + +test_dataset: + cms: + physical: + batch_size: 1 + samples: + cms_pf_qcd_high_pt: + version: 1.6.0 diff --git a/parameters/pyg-cms-small-highqcd.yaml b/parameters/pyg-cms-small-highqcd.yaml new file mode 100644 index 000000000..8ea9b3708 --- /dev/null +++ b/parameters/pyg-cms-small-highqcd.yaml @@ -0,0 +1,70 @@ +backend: pytorch + +dataset: cms +data_dir: +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +nvalid: 500 +num_workers: 0 +prefetch_factor: +checkpoint_freq: + +model: + gnn_lsh: + conv_type: gnn_lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 + + gravnet: + conv_type: gravnet + embedding_dim: 512 + width: 512 + num_convs: 3 + k: 16 + propagate_dimensions: 32 + space_dimensions: 4 + dropout: 0.0 + + attention: + conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + +train_dataset: + cms: + physical: + batch_size: 1 + samples: + cms_pf_ttbar: + version: 1.6.0 + cms_pf_qcd: + version: 1.6.0 + cms_pf_qcd_high_pt: + version: 1.6.0 + +valid_dataset: + cms: + physical: + batch_size: 1 + samples: + cms_pf_qcd_high_pt: + version: 1.6.0 + +test_dataset: + cms: + physical: + batch_size: 1 + samples: + cms_pf_qcd_high_pt: + version: 1.6.0