diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7fb27e281..e870ef5d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,7 @@ jobs: sudo apt install python3 python3-pip wget sudo python3 -m pip install --upgrade pip sudo python3 -m pip install --upgrade setuptools - sudo python3 -m pip install tensorflow==2.4 setGPU sklearn matplotlib mplhep pandas scipy uproot3 uproot3-methods awkward0 keras-tuner networkx tensorflow-probability==0.12.2 tensorflow-addons==0.13.0 + sudo python3 -m pip install tensorflow==2.4 setGPU sklearn matplotlib mplhep pandas scipy uproot3 uproot3-methods awkward0 keras-tuner networkx tensorflow-probability==0.12.2 tensorflow-addons==0.13.0 tqdm - name: Run delphes TF model run: ./scripts/local_test_delphes_tf.sh @@ -31,10 +31,36 @@ jobs: sudo apt install python3 python3-pip wget sudo python3 -m pip install --upgrade pip sudo python3 -m pip install --upgrade setuptools - sudo python3 -m pip install tensorflow==2.4 setGPU sklearn matplotlib mplhep pandas scipy uproot3 uproot3-methods awkward0 keras-tuner networkx tensorflow-probability==0.12.2 tensorflow-addons==0.13.0 + sudo python3 -m pip install tensorflow==2.4 setGPU sklearn matplotlib mplhep pandas scipy uproot3 uproot3-methods awkward0 keras-tuner networkx tensorflow-probability==0.12.2 tensorflow-addons==0.13.0 tqdm - name: Run CMS TF model run: ./scripts/local_test_cms_tf.sh - + + delphes-pipeline: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install python deps + run: | + sudo apt install python3 python3-pip wget + sudo python3 -m pip install --upgrade pip + sudo python3 -m pip install --upgrade setuptools + sudo python3 -m pip install tensorflow==2.4 setGPU sklearn matplotlib mplhep pandas scipy uproot3 uproot3-methods awkward0 keras-tuner networkx tensorflow-probability==0.12.2 tensorflow-addons==0.13.0 tqdm click + - name: Run delphes TF model + run: ./scripts/local_test_delphes_pipeline.sh + + cms-pipeline: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install python deps + run: | + sudo apt install python3 python3-pip wget + sudo python3 -m pip install --upgrade pip + sudo python3 -m pip install --upgrade setuptools + sudo python3 -m pip install tensorflow==2.4 setGPU sklearn matplotlib mplhep pandas scipy uproot3 uproot3-methods awkward0 keras-tuner networkx tensorflow-probability==0.12.2 tensorflow-addons==0.13.0 tqdm click + - name: Run CMS TF model using the pipeline + run: ./scripts/local_test_cms_pipeline.sh + delphes-pytorch: runs-on: ubuntu-latest steps: diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py new file mode 100644 index 000000000..742e9b99b --- /dev/null +++ b/mlpf/pipeline.py @@ -0,0 +1,312 @@ +import sys +import os +import yaml +import json +import datetime +import glob +import random +import platform +import numpy as np +from pathlib import Path +import click +from tqdm import tqdm +import shutil + +import tensorflow as tf +from tensorflow.keras import mixed_precision +import tensorflow_addons as tfa + +from tfmodel.data import Dataset +from tfmodel.model_setup import ( + make_model, + configure_model_weights, + LearningRateLoggingCallback, + prepare_callbacks, + FlattenedCategoricalAccuracy, + eval_model, + freeze_model, +) + +from tfmodel.utils import ( + get_lr_schedule, + create_experiment_dir, + get_strategy, + make_weight_function, + load_config, + compute_weights_invsqrt, + compute_weights_none, + get_train_val_datasets, + targets_multi_output, + get_dataset_def, 
+ prepare_val_data, + set_config_loss, + get_loss_dict, + parse_config, + get_best_checkpoint, + delete_all_but_best_checkpoint, +) + +from tfmodel.onecycle_scheduler import OneCycleScheduler, MomentumOneCycleScheduler +from tfmodel.lr_finder import LRFinder + + +@click.group() +@click.help_option("-h", "--help") +def main(): + pass + + +@main.command() +@click.help_option("-h", "--help") +@click.option("-c", "--config", help="configuration file", type=click.Path()) +@click.option("-w", "--weights", default=None, help="trained weights to load", type=click.Path()) +@click.option("--ntrain", default=None, help="override the number of training events", type=int) +@click.option("--ntest", default=None, help="override the number of testing events", type=int) +@click.option("-r", "--recreate", help="force creation of new experiment dir", is_flag=True) +@click.option("-p", "--prefix", default="", help="prefix to put at beginning of training dir name", type=str) +def train(config, weights, ntrain, ntest, recreate, prefix): + """Train a model defined by config""" + config_file_path = config + config, config_file_stem, global_batch_size, n_train, n_test, n_epochs, weights = parse_config( + config, ntrain, ntest, weights + ) + + dataset_def = get_dataset_def(config) + ds_train_r, ds_test_r, dataset_transform = get_train_val_datasets(config, global_batch_size, n_train, n_test) + X_val, ygen_val, ycand_val = prepare_val_data(config, dataset_def, single_file=False) + + if recreate or (weights is None): + outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node()) + else: + outdir = str(Path(weights).parent) + shutil.copy(config_file_path, outdir + "/config.yaml") # Copy the config file to the train dir for later reference + + # Decide tf.distribute.strategy depending on number of available GPUs + strategy, maybe_global_batch_size = get_strategy(global_batch_size) + # If using more than 1 GPU, we scale the batch size by the number of GPUs + if maybe_global_batch_size is not None: + global_batch_size = maybe_global_batch_size + total_steps = n_epochs * n_train // global_batch_size + lr = float(config["setup"]["lr"]) + + with strategy.scope(): + lr_schedule, optim_callbacks = get_lr_schedule(config, lr=lr, steps=total_steps) + opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule) + + if config["setup"]["dtype"] == "float16": + model_dtype = tf.dtypes.float16 + policy = mixed_precision.Policy("mixed_float16") + mixed_precision.set_global_policy(policy) + opt = mixed_precision.LossScaleOptimizer(opt) + else: + model_dtype = tf.dtypes.float32 + + model = make_model(config, model_dtype) + + # Run model once to build the layers + print(X_val.shape) + model(tf.cast(X_val[:1], model_dtype)) + + initial_epoch = 0 + if weights: + # We need to load the weights in the same trainable configuration as the model was set up + configure_model_weights(model, config["setup"].get("weights_config", "all")) + model.load_weights(weights, by_name=True) + initial_epoch = int(weights.split("/")[-1].split("-")[1]) + model(tf.cast(X_val[:1], model_dtype)) + + config = set_config_loss(config, config["setup"]["trainable"]) + configure_model_weights(model, config["setup"]["trainable"]) + model(tf.cast(X_val[:1], model_dtype)) + + loss_dict, loss_weights = get_loss_dict(config) + model.compile( + loss=loss_dict, + optimizer=opt, + sample_weight_mode="temporal", + loss_weights=loss_weights, + metrics={ + "cls": [ + FlattenedCategoricalAccuracy(name="acc_unweighted", dtype=tf.float64), + 
FlattenedCategoricalAccuracy(use_weights=True, name="acc_weighted", dtype=tf.float64), + ] + }, + ) + model.summary() + + callbacks = prepare_callbacks( + model, + outdir, + X_val[: config["setup"]["batch_size"]], + ycand_val[: config["setup"]["batch_size"]], + dataset_transform, + config["dataset"]["num_output_classes"], + ) + callbacks.append(optim_callbacks) + + fit_result = model.fit( + ds_train_r, + validation_data=ds_test_r, + epochs=initial_epoch + n_epochs, + callbacks=callbacks, + steps_per_epoch=n_train // global_batch_size, + validation_steps=n_test // global_batch_size, + initial_epoch=initial_epoch, + ) + history_path = Path(outdir) / "history" + history_path = str(history_path) + with open("{}/history.json".format(history_path), "w") as fi: + json.dump(fit_result.history, fi) + model.save(outdir + "/model_full", save_format="tf") + + print("Training done.") + + print("Starting evaluation...") + eval_dir = Path(outdir) / "evaluation" + eval_dir.mkdir() + eval_dir = str(eval_dir) + # TODO: change to use the evaluate() function below instead of eval_model() + eval_model(X_val, ygen_val, ycand_val, model, config, eval_dir, global_batch_size) + print("Evaluation done.") + + freeze_model(model, config, outdir) + + +@main.command() +@click.help_option("-h", "--help") +@click.option("-t", "--train_dir", required=True, help="directory containing a completed training", type=click.Path()) +@click.option("-c", "--config", help="configuration file", type=click.Path()) +@click.option("-w", "--weights", default=None, help="trained weights to load", type=click.Path()) +@click.option("-e", "--evaluation_dir", help="optionally specify evaluation output dir", type=click.Path()) +def evaluate(config, train_dir, weights, evaluation_dir): + """Evaluate the trained model in train_dir""" + if config is None: + config = Path(train_dir) / "config.yaml" + assert config.exists(), "Could not find config file in train_dir, please provide one with -c " + config, _, global_batch_size, _, _, _, weights = parse_config(config, weights=weights) + # Switch off multi-output for the evaluation for backwards compatibility + config["setup"]["multi_output"] = False + + if evaluation_dir is None: + eval_dir = str(Path(train_dir) / "evaluation") + else: + eval_dir = evaluation_dir + Path(eval_dir).mkdir(parents=True, exist_ok=True) + + if config["setup"]["dtype"] == "float16": + model_dtype = tf.dtypes.float16 + policy = mixed_precision.Policy("mixed_float16") + mixed_precision.set_global_policy(policy) + opt = mixed_precision.LossScaleOptimizer(opt) + else: + model_dtype = tf.dtypes.float32 + + dataset_def = get_dataset_def(config) + X_val, ygen_val, ycand_val = prepare_val_data(config, dataset_def, single_file=False) + + strategy, maybe_global_batch_size = get_strategy(global_batch_size) + if maybe_global_batch_size is not None: + global_batch_size = maybe_global_batch_size + + with strategy.scope(): + model = make_model(config, model_dtype) + + # Evaluate model once to build the layers + print(X_val.shape) + model(tf.cast(X_val[:1], model_dtype)) + + # need to load the weights in the same trainable configuration as the model was set up + configure_model_weights(model, config["setup"].get("weights_config", "all")) + if weights: + model.load_weights(weights, by_name=True) + else: + weights = get_best_checkpoint(train_dir) + print("Loading best weights that could be found from {}".format(weights)) + model.load_weights(weights, by_name=True) + model(tf.cast(X_val[:1], model_dtype)) + + model.compile() + 
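+    # eval_model writes one pred_batch<i>.npz prediction file per batch of global_batch_size events into eval_dir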
eval_model(X_val, ygen_val, ycand_val, model, config, eval_dir, global_batch_size) + freeze_model(model, config, train_dir) + + +@main.command() +@click.help_option("-h", "--help") +@click.option("-c", "--config", help="configuration file", type=click.Path()) +@click.option("-o", "--outdir", help="output directory", type=click.Path(), default=".") +@click.option("-n", "--figname", help="name of saved figure", type=click.Path(), default="lr_finder.jpg") +@click.option("-l", "--logscale", help="use log scale on y-axis in figure", default=False, is_flag=True) +def find_lr(config, outdir, figname, logscale): + """Run the Learning Rate Finder to produce a batch loss vs. LR plot from + which an appropriate LR-range can be determined""" + config, _, global_batch_size, n_train, _, _, _ = parse_config(config) + ds_train_r, _, _ = get_train_val_datasets(config, global_batch_size, n_train, n_test=0) + + # Decide tf.distribute.strategy depending on number of available GPUs + strategy, maybe_global_batch_size = get_strategy(global_batch_size) + + # If using more than 1 GPU, we scale the batch size by the number of GPUs + if maybe_global_batch_size is not None: + global_batch_size = maybe_global_batch_size + + dataset_def = get_dataset_def(config) + X_val, _, _ = prepare_val_data(config, dataset_def, single_file=True) + + with strategy.scope(): + opt = tf.keras.optimizers.Adam(learning_rate=1e-7) # This learning rate will be changed by the lr_finder + if config["setup"]["dtype"] == "float16": + model_dtype = tf.dtypes.float16 + policy = mixed_precision.Policy("mixed_float16") + mixed_precision.set_global_policy(policy) + opt = mixed_precision.LossScaleOptimizer(opt) + else: + model_dtype = tf.dtypes.float32 + + model = make_model(config, model_dtype) + config = set_config_loss(config, config["setup"]["trainable"]) + + # Run model once to build the layers + model(tf.cast(X_val[:1], model_dtype)) + + configure_model_weights(model, config["setup"]["trainable"]) + + loss_dict, loss_weights = get_loss_dict(config) + model.compile( + loss=loss_dict, + optimizer=opt, + sample_weight_mode="temporal", + loss_weights=loss_weights, + metrics={ + "cls": [ + FlattenedCategoricalAccuracy(name="acc_unweighted", dtype=tf.float64), + FlattenedCategoricalAccuracy(use_weights=True, name="acc_weighted", dtype=tf.float64), + ] + }, + ) + model.summary() + + max_steps = 200 + lr_finder = LRFinder(max_steps=max_steps) + callbacks = [lr_finder] + + model.fit( + ds_train_r, + epochs=max_steps, + callbacks=callbacks, + steps_per_epoch=1, + ) + + lr_finder.plot(save_dir=outdir, figname=figname, log_scale=logscale) + + +@main.command() +@click.help_option("-h", "--help") +@click.option("-t", "--train_dir", help="training directory", type=click.Path()) +@click.option("-d", "--dry_run", help="do not delete anything", is_flag=True, default=False) +def delete_all_but_best_ckpt(train_dir, dry_run): + """Delete all checkpoint weights in /weights/ except the one with lowest loss in its filename.""" + delete_all_but_best_checkpoint(train_dir, dry_run) + + +if __name__ == "__main__": + main() diff --git a/mlpf/tfmodel/callbacks.py b/mlpf/tfmodel/callbacks.py new file mode 100644 index 000000000..6edfddcda --- /dev/null +++ b/mlpf/tfmodel/callbacks.py @@ -0,0 +1,76 @@ +import pickle +import tensorflow as tf +from tensorflow.keras.callbacks import TensorBoard +from tensorflow.keras.callbacks import ModelCheckpoint +from pathlib import Path +import numpy as np + + +class CustomTensorBoard(TensorBoard): + """ + Extends 
tensorflow.keras.callbacks TensorBoard + + Custom tensorboard class to make logging of learning rate possible when using + keras.optimizers.schedules.LearningRateSchedule. + See https://github.com/tensorflow/tensorflow/pull/37552 + + Also logs momemtum for supported optimizers that use momemtum. + """ + + def _collect_learning_rate(self, logs): + logs = logs or {} + lr_schedule = getattr(self.model.optimizer, "lr", None) + if isinstance(lr_schedule, tf.keras.optimizers.schedules.LearningRateSchedule): + logs["learning_rate"] = np.float64(tf.keras.backend.get_value(lr_schedule(self.model.optimizer.iterations))) + else: + logs.update({"learning_rate": np.float64(tf.keras.backend.eval(self.model.optimizer.lr))}) + + # Log momentum if the optimizer has it + try: + logs.update({"momentum": np.float64(tf.keras.backend.eval(self.model.optimizer.momentum))}) + except AttributeError: + pass + + # In Adam, the momentum parameter is called beta_1 + if isinstance(self.model.optimizer, tf.keras.optimizers.Adam): + logs.update({"adam_beta_1": np.float64(tf.keras.backend.eval(self.model.optimizer.beta_1))}) + + return logs + + def on_epoch_end(self, epoch, logs): + logs = logs or {} + logs.update(self._collect_learning_rate(logs)) + super().on_epoch_end(epoch, logs) + + def on_train_batch_end(self, batch, logs): + logs = logs or {} + if isinstance(self.update_freq, int) and batch % self.update_freq == 0: + logs.update(self._collect_learning_rate(logs)) + super().on_train_batch_end(batch, logs) + + +class CustomModelCheckpoint(ModelCheckpoint): + """Extends tensorflow.keras.callbacks.ModelCheckpoint to also save optimizer""" + + def __init__(self, *args, **kwargs): + # Added arguments + self.optimizer_to_save = kwargs.pop("optimizer_to_save") + self.optimizer_filepath = kwargs.pop("optimizer_save_filepath") + super().__init__(*args, **kwargs) + + Path(self.filepath).parent.mkdir(parents=True, exist_ok=True) + + def on_epoch_end(self, epoch, logs=None): + super().on_epoch_end(epoch, logs) + + # If a checkpoint was saved, also save the optimizer + filepath = str(self.optimizer_filepath).format(epoch=epoch + 1, **logs) + if self.epochs_since_last_save == 0: + if self.save_best_only: + current = logs.get(self.monitor) + if current == self.best: + with open(filepath, "wb") as f: + pickle.dump(self.optimizer_to_save, f) + else: + with open(filepath, "wb") as f: + pickle.dump(self.optimizer_to_save, f) diff --git a/mlpf/tfmodel/lr_finder.py b/mlpf/tfmodel/lr_finder.py new file mode 100644 index 000000000..152b69417 --- /dev/null +++ b/mlpf/tfmodel/lr_finder.py @@ -0,0 +1,71 @@ +import tensorflow as tf +from tensorflow.keras.callbacks import Callback +import matplotlib.pyplot as plt +from pathlib import Path + + +class LRFinder(Callback): + """`Callback` that exponentially adjusts the learning rate after each training batch between `start_lr` and + `end_lr` for a maximum number of batches: `max_step`. The loss and learning rate are recorded at each step allowing + visually finding a good learning rate as per https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html via + the `plot` method. + + A version of this learning rate finder technique is also described under the name 'LR range test' in Leslie Smith's + paper: https://arxiv.org/pdf/1803.09820.pdf. 
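+
+    At training batch `i` (0-indexed) the learning rate is set to
+    `start_lr * (end_lr / start_lr) ** (i / max_steps)`, i.e. it is swept log-uniformly
+    from `start_lr` to `end_lr` over at most `max_steps` batches.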
+ """ + + def __init__(self, start_lr: float = 1e-7, end_lr: float = 3, max_steps: int = 200, smoothing=0.9): + super(LRFinder, self).__init__() + self.start_lr, self.end_lr = start_lr, end_lr + self.max_steps = max_steps + self.smoothing = smoothing + self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0 + self.lrs, self.losses = [], [] + + def on_train_begin(self, logs=None): + self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0 + self.lrs, self.losses = [], [] + + def on_train_batch_begin(self, batch, logs=None): + self.lr = self.exp_annealing(self.step) + tf.keras.backend.set_value(self.model.optimizer.lr, self.lr) + + def on_train_batch_end(self, batch, logs=None): + print("lr:", self.lr) + print("step", self.step) + logs = logs or {} + loss = logs.get("loss") + step = self.step + if loss: + print("loss", loss) + self.avg_loss = self.smoothing * self.avg_loss + (1 - self.smoothing) * loss + smooth_loss = self.avg_loss / (1 - self.smoothing ** (self.step + 1)) + self.losses.append(smooth_loss) + self.lrs.append(self.lr) + + if step == 0 or loss < self.best_loss: + self.best_loss = loss + + if smooth_loss > 4 * self.best_loss or tf.math.is_nan(smooth_loss): + self.model.stop_training = True + print("Loss reached predefined maximum... stopping") + if step >= self.max_steps: + print("STOPPING") + self.model.stop_training = True + self.step += 1 + + def exp_annealing(self, step): + return self.start_lr * (self.end_lr / self.start_lr) ** (step * 1.0 / self.max_steps) + + def plot(self, save_dir=None, figname="lr_finder.jpg", log_scale=False): + fig, ax = plt.subplots(1, 1) + ax.set_ylabel("Loss") + ax.set_xlabel("Learning Rate") + ax.set_xscale("log") + ax.xaxis.set_major_formatter(plt.FormatStrFormatter("%.0e")) + ax.plot(self.lrs, self.losses) + if log_scale: + ax.set_yscale("log") + if save_dir is not None: + Path(save_dir).mkdir(parents=True, exist_ok=True) + plt.savefig(str(Path(save_dir) / Path(figname))) diff --git a/mlpf/tfmodel/model_setup.py b/mlpf/tfmodel/model_setup.py index 0ce852551..e2a7a4af4 100644 --- a/mlpf/tfmodel/model_setup.py +++ b/mlpf/tfmodel/model_setup.py @@ -16,12 +16,17 @@ import matplotlib import matplotlib.pyplot as plt import sklearn -import kerastuner as kt from argparse import Namespace import time import json import random import platform +from tqdm import tqdm +from pathlib import Path +from tfmodel.onecycle_scheduler import OneCycleScheduler, MomentumOneCycleScheduler +from tfmodel.callbacks import CustomTensorBoard +from tfmodel.utils import get_lr_schedule, make_weight_function, targets_multi_output + from tensorflow.keras.metrics import Recall, CategoricalAccuracy @@ -180,8 +185,8 @@ def on_epoch_end(self, epoch, logs=None): def prepare_callbacks(model, outdir, X_val, y_val, dataset_transform, num_output_classes): callbacks = [] - tb = tf.keras.callbacks.TensorBoard( - log_dir=outdir, histogram_freq=1, write_graph=False, write_images=False, + tb = CustomTensorBoard( + log_dir=outdir + "/tensorboard_logs", histogram_freq=1, write_graph=False, write_images=False, update_freq='epoch', #profile_batch=(10,90), profile_batch=0, @@ -192,15 +197,20 @@ def prepare_callbacks(model, outdir, X_val, y_val, dataset_transform, num_output terminate_cb = tf.keras.callbacks.TerminateOnNaN() callbacks += [terminate_cb] + cp_dir = Path(outdir) / "weights" + cp_dir.mkdir(parents=True, exist_ok=True) cp_callback = tf.keras.callbacks.ModelCheckpoint( - filepath=outdir + "/weights-{epoch:02d}-{val_loss:.6f}.hdf5", + filepath=str(cp_dir / 
"weights-{epoch:02d}-{val_loss:.6f}.hdf5"), save_weights_only=True, verbose=0 ) cp_callback.set_model(model) callbacks += [cp_callback] - cb = CustomCallback(outdir, X_val, y_val, dataset_transform, num_output_classes) + history_path = Path(outdir) / "history" + history_path.mkdir(parents=True, exist_ok=True) + history_path = str(history_path) + cb = CustomCallback(history_path, X_val, y_val, dataset_transform, num_output_classes) cb.set_model(model) callbacks += [cb] @@ -220,48 +230,12 @@ def get_rundir(base='experiments'): logdir = 'run_%02d' % run_number return '{}/{}'.format(base, logdir) -def make_weight_function(config): - def weight_func(X,y,w): - - w_signal_only = tf.where(y[:, 0]==0, 0.0, 1.0) - w_signal_only *= tf.cast(X[:, 0]!=0, tf.float32) - - w_none = tf.ones_like(w) - w_none *= tf.cast(X[:, 0]!=0, tf.float32) - - w_invsqrt = tf.cast(tf.shape(w)[-1], tf.float32)/tf.sqrt(w) - w_invsqrt *= tf.cast(X[:, 0]!=0, tf.float32) - - weight_d = { - "none": w_none, - "signal_only": w_signal_only, - "inverse_sqrt": w_invsqrt - } - - ret_w = {} - for loss_component, weight_type in config["sample_weights"].items(): - ret_w[loss_component] = weight_d[weight_type] - - return X,y,ret_w - return weight_func def scale_outputs(X,y,w): ynew = y-out_m ynew = ynew/out_s return X, ynew, w -def targets_multi_output(num_output_classes): - def func(X, y, w): - return X, { - "cls": tf.one_hot(tf.cast(y[:, :, 0], tf.int32), num_output_classes), - "charge": y[:, :, 1:2], - "pt": y[:, :, 2:3], - "eta": y[:, :, 3:4], - "sin_phi": y[:, :, 4:5], - "cos_phi": y[:, :, 5:6], - "energy": y[:, :, 6:7], - }, w - return func def make_model(config, dtype): model = config['parameters']['model'] @@ -362,12 +336,15 @@ def make_dense(config, dtype): def eval_model(X, ygen, ycand, model, config, outdir, global_batch_size): import scipy - for ibatch in range(X.shape[0]//global_batch_size): + for ibatch in tqdm(range(X.shape[0]//global_batch_size), desc="Evaluating model"): nb1 = ibatch*global_batch_size nb2 = (ibatch+1)*global_batch_size y_pred = model.predict(X[nb1:nb2], batch_size=global_batch_size) - y_pred_raw_ids = y_pred[:, :, :config["dataset"]["num_output_classes"]] + if type(y_pred) is dict: # for e.g. 
when the model is multi_output + y_pred_raw_ids = y_pred['cls'] + else: + y_pred_raw_ids = y_pred[:, :, :config["dataset"]["num_output_classes"]] #softmax score must be over a threshold 0.6 to call it a particle (prefer low fake rate to high efficiency) # y_pred_id_sm = scipy.special.softmax(y_pred_raw_ids, axis=-1) @@ -382,7 +359,12 @@ def eval_model(X, ygen, ycand, model, config, outdir, global_batch_size): y_pred_id = np.argmax(y_pred_raw_ids, axis=-1) - y_pred_id = np.concatenate([np.expand_dims(y_pred_id, axis=-1), y_pred[:, :, config["dataset"]["num_output_classes"]:]], axis=-1) + if type(y_pred) is dict: + y_pred_rest = np.concatenate([y_pred["charge"], y_pred["pt"], y_pred["eta"], y_pred["sin_phi"], y_pred["cos_phi"], y_pred["energy"]], axis=-1) + y_pred_id = np.concatenate([np.expand_dims(y_pred_id, axis=-1), y_pred_rest], axis=-1) + else: + y_pred_id = np.concatenate([np.expand_dims(y_pred_id, axis=-1), y_pred[:, :, config["dataset"]["num_output_classes"]:]], axis=-1) + np_outfile = "{}/pred_batch{}.npz".format(outdir, ibatch) np.savez( np_outfile, @@ -599,7 +581,7 @@ def main(args, yaml_path, config): print("Output directory exists: {}".format(outdir), file=sys.stderr) sys.exit(1) else: - outdir = os.path.dirname(weights) + outdir = str(Path(weights).parent.parent) try: gpus = [int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")] @@ -614,8 +596,6 @@ def main(args, yaml_path, config): print("fallback to CPU", e) strategy = tf.distribute.OneDeviceStrategy("cpu") num_gpus = 0 - - actual_lr = global_batch_size*float(config['setup']['lr']) Xs = [] ygens = [] @@ -640,13 +620,10 @@ def main(args, yaml_path, config): ygen_val = np.concatenate(ygens) ycand_val = np.concatenate(ycands) + lr = float(config['setup']['lr']) with strategy.scope(): - lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( - actual_lr, - decay_steps=10000, - decay_rate=0.99, - staircase=True - ) + total_steps = n_epochs * n_train // global_batch_size + lr_schedule, optim_callbacks = get_lr_schedule(config, lr, steps=total_steps) opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule) if config['setup']['dtype'] == 'float16': model_dtype = tf.dtypes.float16 @@ -734,14 +711,16 @@ def main(args, yaml_path, config): model, outdir, X_val[:config['setup']['batch_size']], ycand_val[:config['setup']['batch_size']], dataset_transform, config["dataset"]["num_output_classes"] ) - callbacks.append(LearningRateLoggingCallback()) + callbacks.append(optim_callbacks) fit_result = model.fit( ds_train_r, validation_data=ds_test_r, epochs=initial_epoch+n_epochs, callbacks=callbacks, steps_per_epoch=n_train//global_batch_size, validation_steps=n_test//global_batch_size, initial_epoch=initial_epoch ) - with open("{}/history.json".format(outdir), "w") as fi: + history_path = Path(outdir) / "history" + history_path = str(history_path) + with open("{}/history.json".format(history_path), "w") as fi: json.dump(fit_result.history, fi) model.save(outdir + "/model_full", save_format="tf") diff --git a/mlpf/tfmodel/onecycle_scheduler.py b/mlpf/tfmodel/onecycle_scheduler.py new file mode 100644 index 000000000..fefc65c78 --- /dev/null +++ b/mlpf/tfmodel/onecycle_scheduler.py @@ -0,0 +1,144 @@ +import numpy as np +import logging + +import tensorflow as tf +from tensorflow.python.framework import ops +from tensorflow.keras.optimizers.schedules import LearningRateSchedule +from tensorflow.keras.callbacks import Callback + +logging.getLogger("tensorflow").setLevel(logging.ERROR) + + +class CosineAnnealer: + def 
__init__(self, start, end, steps): + self.start = start + self.end = end + self.steps = steps + self.n = 0 + + def step(self): + cos = np.cos(np.pi * (self.n / self.steps)) + 1 + self.n += 1 + return self.end + (self.start - self.end) / 2.0 * cos + + +class OneCycleScheduler(LearningRateSchedule): + """`LearningRateSchedule` that schedules the learning rate on a 1cycle policy as per Leslie Smith's paper + (https://arxiv.org/pdf/1803.09820.pdf). + + The implementation adopts additional improvements as per the fastai library: + https://docs.fast.ai/callbacks.one_cycle.html, where only two phases are used and the adaptation is done using + cosine annealing. In the warm-up phase the LR increases from `lr_max / div_factor` to `lr_max` and momentum + decreases from `mom_max` to `mom_min`. In the second phase the LR decreases from `lr_max` to `lr_max / final_div` + and momemtum from `mom_max` to `mom_min`. By default the phases are not of equal length, with the warm-up phase + controlled by the parameter `warmup_ratio`. + + NOTE: The momentum is not controlled through this class. This class is intended to be used together with the + `MomentumOneCycleScheduler` callback defined below. + """ + + def __init__( + self, + lr_max, + steps, + mom_min=0.85, + mom_max=0.95, + warmup_ratio=0.3, + div_factor=25.0, + final_div=100000.0, + name=None, + ): + super(OneCycleScheduler, self).__init__() + lr_min = lr_max / div_factor + + if final_div is None: + final_lr = lr_max / (div_factor * 1e4) + else: + final_lr = lr_max / (final_div) + + phase_1_steps = steps * warmup_ratio + phase_2_steps = steps - phase_1_steps + + self.lr_max = lr_max + self.steps = steps + self.mom_min = mom_min + self.mom_max = mom_max + self.warmup_ratio = warmup_ratio + self.div_factor = div_factor + self.final_div = final_div + self.name = name + + phases = [CosineAnnealer(lr_min, lr_max, phase_1_steps), CosineAnnealer(lr_max, final_lr, phase_2_steps)] + + step = 0 + phase = 0 + full_lr_schedule = np.zeros(int(steps)) + for ii in np.arange(np.floor(steps), dtype=int): + step += 1 + if step >= phase_1_steps: + phase = 1 + full_lr_schedule[ii] = phases[phase].step() + + self.full_lr_schedule = tf.convert_to_tensor(full_lr_schedule) + + def __call__(self, step): + with ops.name_scope(self.name or "OneCycleScheduler"): + return self.full_lr_schedule[tf.cast(step, "int32") - 1] + + def get_config(self): + return { + "lr_max": self.lr_max, + "steps": self.steps, + "mom_min": self.mom_min, + "mom_max": self.mom_max, + "warmup_ratio": self.warmup_ratio, + "div_factor": self.div_factor, + "final_div": self.final_div, + "name": self.name, + } + + +class MomentumOneCycleScheduler(Callback): + """`Callback` that schedules the momentun according to the 1cycle policy as per Leslie Smith's paper + (https://arxiv.org/pdf/1803.09820.pdf). + NOTE: This callback only schedules the momentum parameter, not the learning rate. It is intended to be used with the + KerasOneCycle learning rate scheduler above or similar. 
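+
+    For Adam the scheduled quantity is `beta_1`; for SGD it is `momentum`. Other optimizers are
+    not supported and a NotImplementedError is raised for them.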
+ """ + + def __init__(self, steps, mom_min=0.85, mom_max=0.95, warmup_ratio=0.3): + super(MomentumOneCycleScheduler, self).__init__() + + phase_1_steps = steps * warmup_ratio + phase_2_steps = steps - phase_1_steps + + self.phase_1_steps = phase_1_steps + self.phase_2_steps = phase_2_steps + self.phase = 0 + self.step = 0 + + self.phases = [CosineAnnealer(mom_max, mom_min, phase_1_steps), CosineAnnealer(mom_min, mom_max, phase_2_steps)] + + def on_train_begin(self, logs=None): + self.set_momentum(self.mom_schedule().step()) + + def on_train_batch_end(self, batch, logs=None): + self.step += 1 + if self.step >= self.phase_1_steps: + self.phase = 1 + + self.set_momentum(self.mom_schedule().step()) + + def set_momentum(self, mom): + # In Adam, the momentum parameter is called beta_1 + if isinstance(self.model.optimizer, tf.keras.optimizers.Adam): + tf.keras.backend.set_value(self.model.optimizer.beta_1, mom) + # In SDG, the momentum parameter is called momentum + elif isinstance(self.model.optimizer, tf.keras.optimizers.SGD): + tf.keras.backend.set_value(self.model.optimizer.momentum, mom) + else: + raise NotImplementedError( + "Only SGD and Adam are supported by MomentumOneCycleScheduler: {}".format(type(self.model.optimizer)) + ) + + def mom_schedule(self): + return self.phases[self.phase] diff --git a/mlpf/tfmodel/utils.py b/mlpf/tfmodel/utils.py new file mode 100644 index 000000000..909e8c2f3 --- /dev/null +++ b/mlpf/tfmodel/utils.py @@ -0,0 +1,334 @@ +import os +import yaml +from pathlib import Path +import datetime +import platform +import random +import glob +import numpy as np +from tqdm import tqdm +import re + +import tensorflow as tf +import tensorflow_addons as tfa + +from tfmodel.data import Dataset +from tfmodel.onecycle_scheduler import OneCycleScheduler, MomentumOneCycleScheduler + + +def load_config(config_file_path): + with open(config_file_path, "r") as ymlfile: + cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) + return cfg + + +def parse_config(config, ntrain=None, ntest=None, weights=None): + config_file_stem = Path(config).stem + config = load_config(config) + tf.config.run_functions_eagerly(config["tensorflow"]["eager"]) + global_batch_size = config["setup"]["batch_size"] + n_epochs = config["setup"]["num_epochs"] + if ntrain: + n_train = ntrain + else: + n_train = config["setup"]["num_events_train"] + if ntest: + n_test = ntest + else: + n_test = config["setup"]["num_events_test"] + + if "multi_output" not in config["setup"]: + config["setup"]["multi_output"] = True + + if weights is None: + weights = config["setup"]["weights"] + + return config, config_file_stem, global_batch_size, n_train, n_test, n_epochs, weights + + +def create_experiment_dir(prefix=None, suffix=None): + if prefix is None: + train_dir = Path("experiments") / datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + else: + train_dir = Path("experiments") / (prefix + datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) + + if suffix is not None: + train_dir = train_dir.with_name(train_dir.name + "." 
+ platform.node()) + + train_dir.mkdir(parents=True) + return str(train_dir) + + +def get_best_checkpoint(train_dir): + checkpoint_list = list(Path(Path(train_dir) / "weights").glob("weights*.hdf5")) + # Sort the checkpoints according to the loss in their filenames + checkpoint_list.sort(key=lambda x: float(re.search("\d+-\d+.\d+", str(x))[0].split("-")[-1])) + # Return the checkpoint with smallest loss + return str(checkpoint_list[0]) + + +def delete_all_but_best_checkpoint(train_dir, dry_run): + checkpoint_list = list(Path(Path(train_dir) / "weights").glob("weights*.hdf5")) + # Don't remove the checkpoint with smallest loss + if len(checkpoint_list) == 1: + raise UserWarning("There is only one checkpoint. No deletion was made.") + elif len(checkpoint_list) == 0: + raise UserWarning("Couldn't find ant checkpoints. No deletion was made.") + else: + # Sort the checkpoints according to the loss in their filenames + checkpoint_list.sort(key=lambda x: float(re.search("\d+-\d+.\d+", str(x))[0].split("-")[-1])) + best_ckpt = checkpoint_list.pop(0) + for ckpt in checkpoint_list: + if not dry_run: + ckpt.unlink() + + print("Removed all checkpoints in {} except {}".format(train_dir, best_ckpt)) + + +def get_strategy(global_batch_size): + try: + gpus = [int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")] + num_gpus = len(gpus) + print("num_gpus=", num_gpus) + if num_gpus > 1: + strategy = tf.distribute.MirroredStrategy() + global_batch_size = num_gpus * global_batch_size + else: + strategy = tf.distribute.OneDeviceStrategy("gpu:0") + except Exception as e: + print("fallback to CPU", e) + strategy = tf.distribute.OneDeviceStrategy("cpu") + num_gpus = 0 + return strategy, global_batch_size + + +def get_lr_schedule(config, lr, steps): + callbacks = [] + schedule = config["setup"]["lr_schedule"] + if schedule == "onecycle": + onecycle_cfg = config["onecycle"] + lr_schedule = OneCycleScheduler( + lr_max=lr, + steps=steps, + mom_min=onecycle_cfg["mom_min"], + mom_max=onecycle_cfg["mom_max"], + warmup_ratio=onecycle_cfg["warmup_ratio"], + div_factor=onecycle_cfg["div_factor"], + final_div=onecycle_cfg["final_div"], + ) + callbacks.append( + MomentumOneCycleScheduler( + steps=steps, + mom_min=onecycle_cfg["mom_min"], + mom_max=onecycle_cfg["mom_max"], + warmup_ratio=onecycle_cfg["warmup_ratio"], + ) + ) + elif schedule == "exponentialdecay": + lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( + lr, + decay_steps=config["exponentialdecay"]["decay_steps"], + decay_rate=config["exponentialdecay"]["decay_rate"], + staircase=config["exponentialdecay"]["staircase"], + ) + else: + raise ValueError("Only supported LR schedules are 'exponentialdecay' and 'onecycle'.") + return lr_schedule, callbacks + + +def compute_weights_invsqrt(X, y, w): + wn = tf.cast(tf.shape(w)[-1], tf.float32) / tf.sqrt(w) + wn *= tf.cast(X[:, 0] != 0, tf.float32) + # wn /= tf.reduce_sum(wn) + return X, y, wn + + +def compute_weights_none(X, y, w): + wn = tf.ones_like(w) + wn *= tf.cast(X[:, 0] != 0, tf.float32) + return X, y, wn + + +def make_weight_function(config): + def weight_func(X,y,w): + + w_signal_only = tf.where(y[:, 0]==0, 0.0, 1.0) + w_signal_only *= tf.cast(X[:, 0]!=0, tf.float32) + + w_none = tf.ones_like(w) + w_none *= tf.cast(X[:, 0]!=0, tf.float32) + + w_invsqrt = tf.cast(tf.shape(w)[-1], tf.float32)/tf.sqrt(w) + w_invsqrt *= tf.cast(X[:, 0]!=0, tf.float32) + + weight_d = { + "none": w_none, + "signal_only": w_signal_only, + "inverse_sqrt": w_invsqrt + } + + ret_w = {} + for loss_component, 
weight_type in config["sample_weights"].items(): + ret_w[loss_component] = weight_d[weight_type] + + return X,y,ret_w + return weight_func + + +def targets_multi_output(num_output_classes): + def func(X, y, w): + return ( + X, + { + "cls": tf.one_hot(tf.cast(y[:, :, 0], tf.int32), num_output_classes), + "charge": y[:, :, 1:2], + "pt": y[:, :, 2:3], + "eta": y[:, :, 3:4], + "sin_phi": y[:, :, 4:5], + "cos_phi": y[:, :, 5:6], + "energy": y[:, :, 6:7], + }, + w, + ) + + return func + + +def get_dataset_def(config): + cds = config["dataset"] + + return Dataset( + num_input_features=int(cds["num_input_features"]), + num_output_features=int(cds["num_output_features"]), + padded_num_elem_size=int(cds["padded_num_elem_size"]), + raw_path=cds.get("raw_path", None), + raw_files=cds.get("raw_files", None), + processed_path=cds["processed_path"], + validation_file_path=cds["validation_file_path"], + schema=cds["schema"], + ) + + +def get_train_val_datasets(config, global_batch_size, n_train, n_test): + dataset_def = get_dataset_def(config) + + tfr_files = sorted(glob.glob(dataset_def.processed_path)) + if len(tfr_files) == 0: + raise Exception("Could not find any files in {}".format(dataset_def.processed_path)) + + random.shuffle(tfr_files) + dataset = tf.data.TFRecordDataset(tfr_files).map( + dataset_def.parse_tfr_element, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + + # Due to TFRecords format, the length of the dataset is not known beforehand + num_events = 0 + for _ in dataset: + num_events += 1 + print("dataset loaded, len={}".format(num_events)) + + weight_func = make_weight_function(config) + assert n_train + n_test <= num_events + + # Padded shapes + ps = ( + tf.TensorShape([dataset_def.padded_num_elem_size, dataset_def.num_input_features]), + tf.TensorShape([dataset_def.padded_num_elem_size, dataset_def.num_output_features]), + { + "cls": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + "charge": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + "energy": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + "pt": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + "eta": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + "sin_phi": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + "cos_phi": tf.TensorShape([dataset_def.padded_num_elem_size, ]), + } + ) + + ds_train = dataset.take(n_train).map(weight_func).padded_batch(global_batch_size, padded_shapes=ps) + ds_test = dataset.skip(n_train).take(n_test).map(weight_func).padded_batch(global_batch_size, padded_shapes=ps) + + if config["setup"]["multi_output"]: + dataset_transform = targets_multi_output(config["dataset"]["num_output_classes"]) + ds_train = ds_train.map(dataset_transform) + ds_test = ds_test.map(dataset_transform) + else: + dataset_transform = None + + ds_train_r = ds_train.repeat(config["setup"]["num_epochs"]) + ds_test_r = ds_test.repeat(config["setup"]["num_epochs"]) + + return ds_train_r, ds_test_r, dataset_transform + + +def prepare_val_data(config, dataset_def, single_file=False): + if single_file: + val_filelist = dataset_def.val_filelist[:1] + else: + val_filelist = dataset_def.val_filelist + if config["setup"]["num_val_files"] > 0: + val_filelist = val_filelist[: config["setup"]["num_val_files"]] + + Xs = [] + ygens = [] + ycands = [] + for fi in tqdm(val_filelist, desc="Preparing validation data"): + X, ygen, ycand = dataset_def.prepare_data(fi) + Xs.append(np.concatenate(X)) + ygens.append(np.concatenate(ygen)) + ycands.append(np.concatenate(ycand)) + + assert len(Xs) > 0, "Xs is 
empty" + X_val = np.concatenate(Xs) + ygen_val = np.concatenate(ygens) + ycand_val = np.concatenate(ycands) + + return X_val, ygen_val, ycand_val + + +def set_config_loss(config, trainable): + if trainable == "classification": + config["dataset"]["pt_loss_coef"] = 0.0 + config["dataset"]["eta_loss_coef"] = 0.0 + config["dataset"]["sin_phi_loss_coef"] = 0.0 + config["dataset"]["cos_phi_loss_coef"] = 0.0 + config["dataset"]["energy_loss_coef"] = 0.0 + elif trainable == "regression": + config["dataset"]["classification_loss_coef"] = 0.0 + config["dataset"]["charge_loss_coef"] = 0.0 + elif trainable == "all": + pass + return config + + +def get_class_loss(config): + if config["setup"]["classification_loss_type"] == "categorical_cross_entropy": + cls_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False) + elif config["setup"]["classification_loss_type"] == "sigmoid_focal_crossentropy": + cls_loss = tfa.losses.sigmoid_focal_crossentropy + else: + raise KeyError("Unknown classification loss type: {}".format(config["setup"]["classification_loss_type"])) + return cls_loss + + +def get_loss_dict(config): + cls_loss = get_class_loss(config) + loss_dict = { + "cls": cls_loss, + "charge": getattr(tf.keras.losses, config["dataset"].get("charge_loss", "MeanSquaredError"))(), + "pt": getattr(tf.keras.losses, config["dataset"].get("pt_loss", "MeanSquaredError"))(), + "eta": getattr(tf.keras.losses, config["dataset"].get("eta_loss", "MeanSquaredError"))(), + "sin_phi": getattr(tf.keras.losses, config["dataset"].get("sin_phi_loss", "MeanSquaredError"))(), + "cos_phi": getattr(tf.keras.losses, config["dataset"].get("cos_phi_loss", "MeanSquaredError"))(), + "energy": getattr(tf.keras.losses, config["dataset"].get("energy_loss", "MeanSquaredError"))(), + } + loss_weights = { + "cls": config["dataset"]["classification_loss_coef"], + "charge": config["dataset"]["charge_loss_coef"], + "pt": config["dataset"]["pt_loss_coef"], + "eta": config["dataset"]["eta_loss_coef"], + "sin_phi": config["dataset"]["sin_phi_loss_coef"], + "cos_phi": config["dataset"]["cos_phi_loss_coef"], + "energy": config["dataset"]["energy_loss_coef"], + } + return loss_dict, loss_weights diff --git a/parameters/cms-gnn-dense-big.yaml b/parameters/cms-gnn-dense-big.yaml index 3a03e8f20..aa3de4a6f 100644 --- a/parameters/cms-gnn-dense-big.yaml +++ b/parameters/cms-gnn-dense-big.yaml @@ -51,6 +51,7 @@ setup: sample_weights: inverse_sqrt trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense @@ -68,3 +69,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-gnn-dense-focal.yaml b/parameters/cms-gnn-dense-focal.yaml index 18cdf7231..5db4d2177 100644 --- a/parameters/cms-gnn-dense-focal.yaml +++ b/parameters/cms-gnn-dense-focal.yaml @@ -54,6 +54,7 @@ setup: focal_loss_alpha: 0.25 focal_loss_gamma: 3.0 focal_loss_from_logits: False + lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: cls: none @@ -83,3 +84,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-gnn-dense-onecycle.yaml b/parameters/cms-gnn-dense-onecycle.yaml new file mode 100644 index 000000000..ce6fcc2fb --- /dev/null +++ b/parameters/cms-gnn-dense-onecycle.yaml @@ -0,0 +1,97 @@ +backend: tensorflow + +dataset: + schema: cms + target_particles: cand 
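+  #target layout per element: (class, charge, pt, eta, sin phi, cos phi, energy)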
+ num_input_features: 15 + num_output_features: 7 +# NONE = 0, +# TRACK = 1, +# PS1 = 2, +# PS2 = 3, +# ECAL = 4, +# HCAL = 5, +# GSF = 6, +# BREM = 7, +# HFEM = 8, +# HFHAD = 9, +# SC = 10, +# HO = 11, + num_input_classes: 12 + #(none=0, ch.had=1, n.had=2, hfem=3, hfhad=4, gamma=5, e=6, mu=7) + num_output_classes: 8 + padded_num_elem_size: 6400 + #(pt, eta, sin phi, cos phi, E) + num_momentum_outputs: 5 + pt_loss: MeanSquaredLogarithmicError + energy_loss: MeanSquaredLogarithmicError + classification_loss_coef: 1.0 + charge_loss_coef: 0.1 + pt_loss_coef: 1.0 + eta_loss_coef: 0.1 + sin_phi_loss_coef: 1.0 + cos_phi_loss_coef: 1.0 + energy_loss_coef: 1.0 + raw_path: ../data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/*.pkl.bz2 + processed_path: ../data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr_cand/*.tfrecords + num_files_per_chunk: 1 + validation_file_path: ../data/TTbar_14TeV_TuneCUETP8M1_cfi/val/*.pkl.bz2 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: all + lr: 3e-4 + batch_size: 32 + num_events_train: 80000 + num_events_test: 10000 + num_epochs: 400 + num_val_files: 100 + dtype: float32 + trainable: all + classification_loss_type: categorical_cross_entropy # categorical_cross_entropy, sigmoid_focal_crossentropy + lr_schedule: onecycle # exponentialdecay, onecycle + +sample_weights: + cls: inverse_sqrt + charge: signal_only + pt: signal_only + eta: signal_only + sin_phi: signal_only + cos_phi: signal_only + energy: signal_only + +parameters: + model: gnn_dense + activation: elu + layernorm: no + hidden_dim: 256 + bin_size: 640 + clip_value_low: 0.0 + num_conv: 2 + num_gsl: 2 + normalize_degrees: yes + distance_dim: 128 + dropout: 0.0 + separate_momentum: yes + input_encoding: cms + debug: no + +timing: + num_ev: 100 + num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes + +onecycle: + mom_min: 0.85 + mom_max: 0.95 + warmup_ratio: 0.3 + div_factor: 25.0 + final_div: 100000.0 \ No newline at end of file diff --git a/parameters/cms-gnn-dense-transfer.yaml b/parameters/cms-gnn-dense-transfer.yaml index e55cc9407..8b735f859 100644 --- a/parameters/cms-gnn-dense-transfer.yaml +++ b/parameters/cms-gnn-dense-transfer.yaml @@ -51,6 +51,7 @@ setup: sample_weights: inverse_sqrt trainable: transfer classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn_dense @@ -68,3 +69,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-gnn-dense.yaml b/parameters/cms-gnn-dense.yaml index dd351fa01..d74c0d530 100644 --- a/parameters/cms-gnn-dense.yaml +++ b/parameters/cms-gnn-dense.yaml @@ -51,6 +51,7 @@ setup: dtype: float32 trainable: classification classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: cls: inverse_sqrt @@ -80,3 +81,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-gnn-skipconn-v2.yaml b/parameters/cms-gnn-skipconn-v2.yaml index c13f7d854..e69919342 100644 --- a/parameters/cms-gnn-skipconn-v2.yaml +++ b/parameters/cms-gnn-skipconn-v2.yaml @@ -51,6 +51,7 @@ setup: sample_weights: inverse_sqrt trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn @@ -73,3 +74,8 @@ parameters: timing: num_ev: 100 
num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-gnn-skipconn.yaml b/parameters/cms-gnn-skipconn.yaml index f0c9aa51e..b1d2e50f0 100644 --- a/parameters/cms-gnn-skipconn.yaml +++ b/parameters/cms-gnn-skipconn.yaml @@ -51,6 +51,7 @@ setup: sample_weights: none trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: gnn @@ -73,3 +74,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-transformer-skipconn-gun.yaml b/parameters/cms-transformer-skipconn-gun.yaml index d079d71f2..f1fdd39e9 100644 --- a/parameters/cms-transformer-skipconn-gun.yaml +++ b/parameters/cms-transformer-skipconn-gun.yaml @@ -52,6 +52,7 @@ setup: sample_weights: inverse_sqrt trainable: all multi_output: yes + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: transformer @@ -66,3 +67,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/cms-transformer-skipconn.yaml b/parameters/cms-transformer-skipconn.yaml index 767f34416..0cb6eeb31 100644 --- a/parameters/cms-transformer-skipconn.yaml +++ b/parameters/cms-transformer-skipconn.yaml @@ -50,6 +50,7 @@ setup: sample_weights: none trainable: cls multi_output: yes + lr_schedule: exponentialdecay # exponentialdecay, onecycle parameters: model: transformer @@ -64,3 +65,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/delphes-gnn-skipconn-onecycle.yaml b/parameters/delphes-gnn-skipconn-onecycle.yaml new file mode 100644 index 000000000..16259d6b6 --- /dev/null +++ b/parameters/delphes-gnn-skipconn-onecycle.yaml @@ -0,0 +1,92 @@ +backend: tensorflow + +dataset: + schema: delphes + target_particles: gen + num_input_features: 12 + num_output_features: 7 + #(none=0, track=1, cluster=2) + num_input_classes: 3 + num_output_classes: 6 + num_momentum_outputs: 5 + padded_num_elem_size: 6400 + classification_loss_coef: 1.0 + momentum_loss_coef: 1.0 + charge_loss_coef: 1.0 + pt_loss_coef: 1.0 + eta_loss_coef: 1.0 + sin_phi_loss_coef: 1.0 + cos_phi_loss_coef: 1.0 + energy_loss_coef: 0.001 + momentum_loss_coefs: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 0.001 + raw_path: ../data/mlpf_zenodo/pythia8_ttbar/raw/*.pkl.bz2 + processed_path: ../data/mlpf_zenodo/pythia8_ttbar/tfr/*.tfrecords + num_files_per_chunk: 5 + validation_file_path: ../data/mlpf_zenodo/pythia8_qcd/val/*.pkl.bz2 + +tensorflow: + eager: no + +setup: + train: yes + weights: + weights_config: all + lr: 1e-5 + batch_size: 16 + num_events_train: 40000 + num_events_test: 5000 + num_epochs: 250 + num_val_files: -1 + dtype: float32 + trainable: all + multi_output: yes + classification_loss_type: categorical_cross_entropy + lr_schedule: onecycle # exponentialdecay, onecycle + +sample_weights: + cls: none + charge: none + pt: none + eta: none + sin_phi: none + cos_phi: none + energy: none + +parameters: + model: gnn + bin_size: 128 + num_convs_id: 2 + num_convs_reg: 2 + num_hidden_id_enc: 2 + num_hidden_id_dec: 2 + num_hidden_reg_enc: 2 + num_hidden_reg_dec: 2 + num_neighbors: 16 + hidden_dim_id: 256 + hidden_dim_reg: 256 + distance_dim: 256 + dropout: 0.2 + dist_mult: 1.0 + activation: elu + skip_connection: True + +timing: + num_ev: 100 + num_iter: 3 + 
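+#settings for the LR schedule selected by setup.lr_schedule; the non-selected block is ignored by get_lr_schedule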
+exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes + +onecycle: + mom_min: 0.85 + mom_max: 0.95 + warmup_ratio: 0.3 + div_factor: 25.0 + final_div: 100000.0 \ No newline at end of file diff --git a/parameters/delphes-gnn-skipconn.yaml b/parameters/delphes-gnn-skipconn.yaml index 88fd5f189..0f83160a2 100644 --- a/parameters/delphes-gnn-skipconn.yaml +++ b/parameters/delphes-gnn-skipconn.yaml @@ -37,10 +37,19 @@ setup: num_epochs: 400 num_val_files: -1 dtype: float32 - sample_weights: none trainable: all multi_output: no classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: none + charge: none + pt: none + eta: none + sin_phi: none + cos_phi: none + energy: none parameters: model: gnn @@ -63,3 +72,8 @@ parameters: timing: num_ev: 100 num_iter: 3 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/delphes-transformer-skipconn.yaml b/parameters/delphes-transformer-skipconn.yaml index f687fd63e..9874e5289 100644 --- a/parameters/delphes-transformer-skipconn.yaml +++ b/parameters/delphes-transformer-skipconn.yaml @@ -36,9 +36,18 @@ setup: num_epochs: 300 num_val_files: -1 dtype: float16 - sample_weights: none trainable: all multi_output: no + lr_schedule: exponentialdecay # exponentialdecay, onecycle + +sample_weights: + cls: none + charge: none + pt: none + eta: none + sin_phi: none + cos_phi: none + energy: none parameters: model: transformer @@ -49,3 +58,8 @@ parameters: support: 32 skip_connection: yes dropout: 0.2 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/test-cms-v2.yaml b/parameters/test-cms-v2.yaml index 483c948c3..3b14e661a 100644 --- a/parameters/test-cms-v2.yaml +++ b/parameters/test-cms-v2.yaml @@ -38,6 +38,7 @@ setup: dtype: float32 trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: cls: none @@ -67,3 +68,8 @@ parameters: timing: num_ev: 1 num_iter: 1 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/test-cms.yaml b/parameters/test-cms.yaml index bfb983b3a..a6e4f1967 100644 --- a/parameters/test-cms.yaml +++ b/parameters/test-cms.yaml @@ -38,6 +38,7 @@ setup: dtype: float32 trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: cls: none @@ -69,3 +70,8 @@ parameters: timing: num_ev: 1 num_iter: 1 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/parameters/test-delphes.yaml b/parameters/test-delphes.yaml index c80a31666..87c7208fe 100644 --- a/parameters/test-delphes.yaml +++ b/parameters/test-delphes.yaml @@ -37,6 +37,7 @@ setup: dtype: float32 trainable: all classification_loss_type: categorical_cross_entropy + lr_schedule: exponentialdecay # exponentialdecay, onecycle sample_weights: cls: none @@ -68,3 +69,8 @@ parameters: timing: num_ev: 1 num_iter: 1 + +exponentialdecay: + decay_steps: 10000 + decay_rate: 0.99 + staircase: yes diff --git a/scripts/local_test_cms_pipeline.sh b/scripts/local_test_cms_pipeline.sh new file mode 100755 index 000000000..2f10ec7e2 --- /dev/null +++ b/scripts/local_test_cms_pipeline.sh @@ -0,0 +1,46 @@ +#!/bin/bash +set -e + +rm -Rf data/TTbar_14TeV_TuneCUETP8M1_cfi + +mkdir -p data/TTbar_14TeV_TuneCUETP8M1_cfi/root +cd data/TTbar_14TeV_TuneCUETP8M1_cfi/root + +#Only 
CMS-internal use is permitted by CMS rules +wget -q --no-check-certificate -nc https://jpata.web.cern.ch/jpata/mlpf/cms/TTbar_14TeV_TuneCUETP8M1_cfi/root/pfntuple_1.root +wget -q --no-check-certificate -nc https://jpata.web.cern.ch/jpata/mlpf/cms/TTbar_14TeV_TuneCUETP8M1_cfi/root/pfntuple_2.root +wget -q --no-check-certificate -nc https://jpata.web.cern.ch/jpata/mlpf/cms/TTbar_14TeV_TuneCUETP8M1_cfi/root/pfntuple_3.root + +cd ../../.. + +#Create the ntuples +rm -Rf data/TTbar_14TeV_TuneCUETP8M1_cfi/raw +mkdir -p data/TTbar_14TeV_TuneCUETP8M1_cfi/raw +for file in `\ls -1 data/TTbar_14TeV_TuneCUETP8M1_cfi/root/*.root`; do + python3 mlpf/data/postprocessing2.py \ + --input $file \ + --outpath data/TTbar_14TeV_TuneCUETP8M1_cfi/raw \ + --save-normalized-table --events-per-file 5 +done + +#Set aside some data for validation +mkdir -p data/TTbar_14TeV_TuneCUETP8M1_cfi/val +mv data/TTbar_14TeV_TuneCUETP8M1_cfi/raw/pfntuple_3_0.pkl data/TTbar_14TeV_TuneCUETP8M1_cfi/val/ + +mkdir -p experiments +rm -Rf experiments/test-* + +#Run a simple training on a few events +rm -Rf data/TTbar_14TeV_TuneCUETP8M1_cfi/tfr +python3 mlpf/launcher.py --model-spec parameters/test-cms.yaml --action data + +#Run a simple training on a few events +python3 mlpf/pipeline.py train -c parameters/test-cms.yaml -p test-cms- + +#Generate the pred.npz file of predictions +python3 mlpf/pipeline.py evaluate -c parameters/test-cms.yaml -t ./experiments/test-cms-* + +python3 scripts/test_load_tfmodel.py ./experiments/test-cms-*/model_frozen/frozen_graph.pb + +python3 mlpf/pipeline.py train -c parameters/test-cms-v2.yaml -p test-cms-v2- +python3 mlpf/pipeline.py evaluate -c parameters/test-cms-v2.yaml -t ./experiments/test-cms-v2-* diff --git a/scripts/local_test_cms_tf.sh b/scripts/local_test_cms_tf.sh index 02b3cd9d6..877cc2595 100755 --- a/scripts/local_test_cms_tf.sh +++ b/scripts/local_test_cms_tf.sh @@ -38,9 +38,9 @@ python3 mlpf/launcher.py --model-spec parameters/test-cms.yaml --action data python3 mlpf/launcher.py --model-spec parameters/test-cms.yaml --action train #Generate the pred.npz file of predictions -python3 mlpf/launcher.py --model-spec parameters/test-cms.yaml --action eval --weights ./experiments/test-cms-*/weights-01-*.hdf5 +python3 mlpf/launcher.py --model-spec parameters/test-cms.yaml --action eval --weights ./experiments/test-cms-*/weights/weights-01-*.hdf5 python3 scripts/test_load_tfmodel.py ./experiments/test-cms-*/model_frozen/frozen_graph.pb python3 mlpf/launcher.py --model-spec parameters/test-cms-v2.yaml --action train -python3 mlpf/launcher.py --model-spec parameters/test-cms-v2.yaml --action eval --weights ./experiments/test-cms-v2-*/weights-01-*.hdf5 +python3 mlpf/launcher.py --model-spec parameters/test-cms-v2.yaml --action eval --weights ./experiments/test-cms-v2-*/weights/weights-01-*.hdf5 diff --git a/scripts/local_test_delphes_pipeline.sh b/scripts/local_test_delphes_pipeline.sh new file mode 100755 index 000000000..3117f8033 --- /dev/null +++ b/scripts/local_test_delphes_pipeline.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e + +mkdir -p data/pythia8_ttbar +mkdir -p data/pythia8_ttbar/val +cd data/pythia8_ttbar + +#download a test input file (you can also download everything from Zenodo at 10.5281/zenodo.4559324) +wget -q --no-check-certificate -nc https://zenodo.org/record/4559324/files/tev14_pythia8_ttbar_0_0.pkl.bz2 +wget -q --no-check-certificate -nc https://zenodo.org/record/4559324/files/tev14_pythia8_ttbar_0_1.pkl.bz2 +mv tev14_pythia8_ttbar_0_1.pkl.bz2 val/ + +cd ../.. 
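+
+#Prepare a clean experiments directory for the test runs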
+ +mkdir -p experiments +rm -Rf experiments/test-* + +#Run a simple training on a few events +rm -Rf data/pythia8_ttbar/tfr +python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action data + +#Run a simple training on a few events +python3 mlpf/pipeline.py train -c parameters/test-delphes.yaml -p test-delphes- + +#Generate the pred.npz file of predictions +python3 mlpf/pipeline.py evaluate -c parameters/test-delphes.yaml -t ./experiments/test-delphes-* + +#Generate the timing file +python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action time --weights ./experiments/test-delphes-*/weights/weights-01-*.hdf5 + diff --git a/scripts/local_test_delphes_tf.sh b/scripts/local_test_delphes_tf.sh index ef8fb0117..bb41f072c 100755 --- a/scripts/local_test_delphes_tf.sh +++ b/scripts/local_test_delphes_tf.sh @@ -23,8 +23,8 @@ python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action data python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action train #Generate the pred.npz file of predictions -python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action eval --weights ./experiments/test-*/weights-01-*.hdf5 +python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action eval --weights ./experiments/test-*/weights/weights-01-*.hdf5 #Generate the timing file -python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action time --weights ./experiments/test-*/weights-01-*.hdf5 +python3 mlpf/launcher.py --model-spec parameters/test-delphes.yaml --action time --weights ./experiments/test-*/weights/weights-01-*.hdf5
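For reference, a minimal sketch of how the one-cycle components introduced above could be wired together outside the full pipeline. This assumes the mlpf/tfmodel package is importable; the step count and peak learning rate are placeholders, and in pipeline.py this wiring is instead done by get_lr_schedule from the YAML config.

import tensorflow as tf
from tfmodel.onecycle_scheduler import OneCycleScheduler, MomentumOneCycleScheduler

total_steps = 1000   # placeholder for n_epochs * n_train // global_batch_size
lr_max = 3e-4        # placeholder for setup.lr from the YAML config

# LR schedule: cosine warm-up from lr_max/div_factor to lr_max, then cosine decay to lr_max/final_div
lr_schedule = OneCycleScheduler(
    lr_max=lr_max, steps=total_steps,
    mom_min=0.85, mom_max=0.95, warmup_ratio=0.3, div_factor=25.0, final_div=100000.0,
)
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Momentum (beta_1 for Adam, momentum for SGD) is scheduled by a separate callback
momentum_callback = MomentumOneCycleScheduler(
    steps=total_steps, mom_min=0.85, mom_max=0.95, warmup_ratio=0.3,
)
# model.compile(optimizer=opt, ...)
# model.fit(..., callbacks=[momentum_callback, ...])  # alongside the callbacks from prepare_callbacks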