
Commit 19c2746
Merge pull request #78 from jpata/jpata_sept21
LSH configurable, update gun sample training [TF]

Former-commit-id: 504cab7
jpata authored Sep 3, 2021
2 parents ddd605b + fa7d31f commit 19c2746
Showing 20 changed files with 2,112 additions and 289 deletions.
6 changes: 4 additions & 2 deletions clic/dumper.py
@@ -61,7 +61,8 @@ def pfParticleToDict(par):
"px": mom[0],
"py": mom[1],
"pz": mom[2],
"energy": par.getEnergy()
"energy": par.getEnergy(),
"charge": par.getCharge()
}
return vec

@@ -210,6 +211,7 @@ def caloHitToDict(par, calohit_to_cluster, genparticle_dict, calohit_recotosim):
nPF=colPF.getNumberOfElements()
nCl=colCl.getNumberOfElements()
nTr=colTr.getNumberOfElements()
nHit=simTrackHits.getNumberOfElements()
nHCB=colHCB.getNumberOfElements()
nHCE=colHCE.getNumberOfElements()
nECB=colECB.getNumberOfElements()
@@ -223,7 +225,7 @@ def caloHitToDict(par, calohit_to_cluster, genparticle_dict, calohit_recotosim):
assert(not (recohit in calohit_recotosim))
calohit_recotosim[recohit] = simhit

print "Event %d, nGen=%d, nPF=%d, nClusters=%d, nTracks=%d, nHCAL=%d, nECAL=%d" % (nEvent, nMc, nPF, nCl, nTr, nHCB+nHCE, nECB+nECE)
print "Event %d, nGen=%d, nPF=%d, nClusters=%d, nTracks=%d, nHCAL=%d, nECAL=%d, nHits=%d" % (nEvent, nMc, nPF, nCl, nTr, nHCB+nHCE, nECB+nECE, nHit)

genparticles = []
genparticle_dict = {}
36 changes: 24 additions & 12 deletions mlpf/pipeline.py
@@ -148,6 +148,13 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
if customize:
prefix += customize + "_"
config = customization_functions[customize](config)
#FIXME: refactor this
global_batch_size = config["setup"]["batch_size"]

if recreate or (weights is None):
outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node())
else:
outdir = str(Path(weights).parent)

# Decide tf.distribute.strategy depending on number of available GPUs
strategy, maybe_global_batch_size = get_strategy(global_batch_size)
@@ -166,10 +173,6 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,

X_val, ygen_val, ycand_val = prepare_val_data(config, dataset_def, single_file=False)

if recreate or (weights is None):
outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_", suffix=platform.node())
else:
outdir = str(Path(weights).parent)
if experiment:
experiment.set_name(outdir)
experiment.log_code("mlpf/tfmodel/model.py")
@@ -196,7 +199,11 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,

# Run model once to build the layers
print(X_val.shape)
model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]))

if config["tensorflow"]["eager"]:
model(X_val[:1])
else:
model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]))

initial_epoch = 0
if weights:
@@ -210,9 +217,10 @@ def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
configure_model_weights(model, config["setup"]["trainable"])
model.build((1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]))

print("trainable weights")
for w in model.trainable_weights:
print(w.name)
print("model weights")
tw_names = [m.name for m in model.trainable_weights]
for w in model.weights:
print("layer={} trainable={} shape={} num_weights={}".format(w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

loss_dict, loss_weights = get_loss_dict(config)
model.compile(
@@ -401,15 +409,18 @@ def find_lr(config, outdir, figname, logscale):


def customize_gun_sample(config):
config["dataset"]["padded_num_elem_size"] = 640

#FIXME: must be at least 2x bin_size
config["dataset"]["padded_num_elem_size"] = 1280

config["dataset"]["processed_path"] = "data/SinglePiFlatPt0p7To10_cfi/tfr_cand/*.tfrecords"
config["dataset"]["raw_path"] = "data/SinglePiFlatPt0p7To10_cfi/raw/*.pkl.bz2"
config["dataset"]["raw_path"] = "data/SinglePiFlatPt0p7To10_cfi/raw/*.pkl*"
config["dataset"]["classification_loss_coef"] = 0.0
config["dataset"]["charge_loss_coef"] = 0.0
config["dataset"]["eta_loss_coef"] = 0.0
config["dataset"]["sin_phi_loss_coef"] = 0.0
config["dataset"]["cos_phi_loss_coef"] = 0.0
config["setup"]["trainable"] = "ffn_energy"
config["setup"]["trainable"] = "regression"
config["setup"]["batch_size"] = 10*config["setup"]["batch_size"]
return config
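
The FIXME above refers to the LSH binning in MessageBuildingLayerLSH, which computes n_bins = floordiv(n_points, bin_size) and only works with at least two bins. A minimal sanity-check sketch (illustration only; the bin_size argument is hypothetical and not read from the diff's config):

def check_gun_sample_config(config, bin_size):
    # bin_size: the LSH bin size used by the model (passed in here for illustration)
    n_points = config["dataset"]["padded_num_elem_size"]
    n_bins = n_points // bin_size  # mirrors tf.math.floordiv(n_points, self.bin_size)
    assert n_bins >= 2, "padded_num_elem_size must be at least 2x the LSH bin_size"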

@@ -471,6 +482,7 @@ def hypertune(config, outdir, ntrain, ntest, recreate):
config["dataset"]["num_output_classes"],
dataset_def,
)

callbacks.append(optim_callbacks)
callbacks.append(tf.keras.callbacks.EarlyStopping(patience=20, monitor='val_loss'))

@@ -486,7 +498,7 @@ def hypertune(config, outdir, ntrain, ntest, recreate):
#callbacks=[tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss')]
callbacks=callbacks,
)
print("Hyperparamter search complete.")
print("Hyperparameter search complete.")
shutil.copy(config_file_path, outdir + "/config.yaml") # Copy the config file to the train dir for later reference

tuner.results_summary()
14 changes: 14 additions & 0 deletions mlpf/tallinn/test-gnn.sh
@@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gpus 1
#SBATCH --mem-per-gpu=8G

IMG=/home/software/singularity/base.simg:latest
cd ~/particleflow

#TF training
singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-0l.yaml --plot-freq 10
singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-lsh-1l.yaml --plot-freq 10
singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-lsh-2l.yaml --plot-freq 10
singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-lsh-3l.yaml --plot-freq 10
singularity exec --nv $IMG python3 mlpf/pipeline.py train -c parameters/test-gnn/cms-nolsh-1l.yaml --plot-freq 10
2 changes: 1 addition & 1 deletion mlpf/tfmodel/data.py
@@ -237,7 +237,7 @@ def serialize_chunk(self, path, files, ichunk):
Xs = np.concatenate(Xs)
ys = np.concatenate(ys)

#set weights for each sample to be equal to the number of samples of this type
#set weights for each sample to be equal to the number of target particles of this type
#in the training script, this can be used to compute either inverse or class-balanced weights
uniq_vals, uniq_counts = np.unique(np.concatenate([y[:, 0] for y in ys]), return_counts=True)
for i in range(len(ys)):
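
To illustrate the comment above (a sketch only, not the training-script code), the stored per-class counts can be turned into per-particle weights, either inverse or class-balanced:

import numpy as np

def inverse_class_weights(class_ids):
    # class_ids: target class labels per particle, as in y[:, 0] above
    uniq_vals, uniq_counts = np.unique(class_ids, return_counts=True)
    counts = dict(zip(uniq_vals, uniq_counts))
    # inverse weights; a class-balanced variant would use
    # len(class_ids) / (len(uniq_vals) * counts[c]) instead of 1.0 / counts[c]
    return np.array([1.0 / counts[c] for c in class_ids])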
131 changes: 109 additions & 22 deletions mlpf/tfmodel/model.py
@@ -38,9 +38,9 @@ def pairwise_learnable_dist(A, B, ffn, training=False):
shp = tf.shape(A)

#stack node feature vectors of src[i], dst[j] into a matrix res[i,j] = (src[i], dst[j])
a, b, c, d = tf.meshgrid(tf.range(shp[0]), tf.range(shp[1]), tf.range(shp[2]), tf.range(shp[2]), indexing="ij")
inds1 = tf.stack([a,b,c], axis=-1)
inds2 = tf.stack([a,b,d], axis=-1)
mg = tf.meshgrid(tf.range(shp[0]), tf.range(shp[1]), tf.range(shp[2]), tf.range(shp[2]), indexing="ij")
inds1 = tf.stack([mg[0],mg[1],mg[2]], axis=-1)
inds2 = tf.stack([mg[0],mg[1],mg[3]], axis=-1)
res = tf.concat([
tf.gather_nd(A, inds1),
tf.gather_nd(B, inds2)], axis=-1
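
For intuition (illustration only, not part of the diff): for a single batch and bin, the meshgrid/gather above builds res[i, j] = concat(A[i], B[j]) over all point pairs, which in NumPy looks like:

import numpy as np
A = np.arange(6).reshape(3, 2)             # 3 src points, 2 features each
B = np.arange(6, 12).reshape(3, 2)         # 3 dst points, 2 features each
res = np.concatenate(
    [np.repeat(A[:, None, :], 3, axis=1),  # broadcast src[i] over j
     np.repeat(B[None, :, :], 3, axis=0)], # broadcast dst[j] over i
    axis=-1)                               # res.shape == (3, 3, 4), res[i, j] = (A[i], B[j])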
@@ -239,7 +239,13 @@ def __init__(self, *args, **kwargs):
elif self.aggregation_direction == "src":
self.agg_dim = -3

self.ffn = point_wise_feed_forward_network(self.output_dim, self.hidden_dim, num_layers=self.num_layers, activation=self.activation, name=kwargs.get("name")+"_ffn")
self.ffn = point_wise_feed_forward_network(
self.output_dim,
self.hidden_dim,
num_layers=self.num_layers,
activation=self.activation,
name=kwargs.get("name")+"_ffn"
)
super(NodeMessageLearnable, self).__init__(*args, **kwargs)

def call(self, inputs):
@@ -376,6 +382,7 @@ def call(self, x_msg, x_node, msk, training=False):
n_bins = tf.math.floordiv(n_points, self.bin_size)

#put each input item into a bin defined by the argmax output across the LSH embedding
#FIXME: this needs n_bins to be at least 2 to work correctly!
mul = tf.linalg.matmul(x_msg, self.codebook_random_rotations[:, :n_bins//2])
cmul = tf.concat([mul, -mul], axis=-1)
bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk)
@@ -392,6 +399,33 @@ def call(self, x_msg, x_node, msk, training=False):

return bins_split, x_features_binned, dm, msk_f_binned
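
For intuition, a minimal sketch of the bin assignment above (illustration only; the actual layer additionally splits points into fixed-size bins and applies the mask):

import numpy as np
n_points, n_features, n_bins = 8, 4, 2
x_msg = np.random.randn(n_points, n_features)
R = np.random.randn(n_features, n_bins // 2)   # plays the role of codebook_random_rotations[:, :n_bins//2]
mul = x_msg @ R
cmul = np.concatenate([mul, -mul], axis=-1)    # shape (n_points, n_bins)
bin_of_point = np.argmax(cmul, axis=-1)        # each point lands in the bin with the largest projection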

class MessageBuildingLayerFull(tf.keras.layers.Layer):
def __init__(self, distance_dim=128, kernel=NodePairGaussianKernel(), **kwargs):
self.distance_dim = distance_dim
self.kernel = kernel

super(MessageBuildingLayerFull, self).__init__(**kwargs)

"""
x_msg: (n_batch, n_points, n_msg_features)
"""
def call(self, x_msg, msk, training=False):
msk_f = tf.expand_dims(tf.cast(msk, x_msg.dtype), -1)

shp = tf.shape(x_msg)
n_batches = shp[0]
n_points = shp[1]
n_message_features = shp[2]

#Run the node-to-node kernel (distance computation / graph building / attention)
dm = self.kernel(x_msg, training=training)

#remove the masked points row-wise and column-wise
dm = tf.einsum("bijk,bi->bijk", dm, tf.squeeze(msk_f, axis=-1))
dm = tf.einsum("bijk,bj->bijk", dm, tf.squeeze(msk_f, axis=-1))

return dm
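
The two einsum calls above zero out rows and columns of the dense distance matrix that belong to padded points; a small NumPy illustration (not part of the diff):

import numpy as np
dm = np.random.rand(1, 4, 4, 1)          # (batch, i, j, kernel features)
msk = np.array([[1.0, 1.0, 0.0, 0.0]])   # last two points are padding
dm = np.einsum("bijk,bi->bijk", dm, msk) # zero the masked rows
dm = np.einsum("bijk,bj->bijk", dm, msk) # zero the masked columns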

class OutputDecoding(tf.keras.Model):
def __init__(self,
activation="elu",
@@ -601,6 +635,16 @@ def set_trainable_regression(self):
self.ffn_eta.trainable = False
self.ffn_pt.trainable = False
self.ffn_energy.trainable = True
self.ffn_energy_classwise.trainable = True

def set_trainable_classification(self):
self.ffn_id.trainable = True
self.ffn_charge.trainable = True
self.ffn_phi.trainable = False
self.ffn_eta.trainable = False
self.ffn_pt.trainable = False
self.ffn_energy.trainable = False
self.ffn_energy_classwise.trainable = False

class CombinedGraphLayer(tf.keras.layers.Layer):
def __init__(self, *args, **kwargs):
@@ -614,7 +658,9 @@ def __init__(self, *args, **kwargs):
self.kernel = kwargs.pop("kernel")
self.node_message = kwargs.pop("node_message")
self.hidden_dim = kwargs.pop("hidden_dim")
self.do_lsh = kwargs.pop("do_lsh", True)
self.activation = getattr(tf.keras.activations, kwargs.pop("activation"))
self.dist_activation = getattr(tf.keras.activations, kwargs.pop("dist_activation", "linear"))

if self.do_layernorm:
self.layernorm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-6, name=kwargs.get("name")+"_layernorm")
@@ -627,12 +673,20 @@ def __init__(self, *args, **kwargs):
num_layers=2, activation=self.activation,
dropout=self.dropout
)
self.message_building_layer = MessageBuildingLayerLSH(
distance_dim=self.distance_dim,
max_num_bins=self.max_num_bins,
bin_size=self.bin_size,
kernel=build_kernel_from_conf(self.kernel, kwargs.get("name")+"_kernel")
)

if self.do_lsh:
self.message_building_layer = MessageBuildingLayerLSH(
distance_dim=self.distance_dim,
max_num_bins=self.max_num_bins,
bin_size=self.bin_size,
kernel=build_kernel_from_conf(self.kernel, kwargs.get("name")+"_kernel")
)
else:
self.message_building_layer = MessageBuildingLayerFull(
distance_dim=self.distance_dim,
kernel=build_kernel_from_conf(self.kernel, kwargs.get("name")+"_kernel")
)

self.message_passing_layers = [
get_message_layer(self.node_message, "{}_msg_{}".format(kwargs.get("name"), iconv)) for iconv in range(self.num_node_messages)
]
@@ -648,27 +702,39 @@ def call(self, x, msk, training=False):
x = self.layernorm(x, training=training)

#compute node features for graph building
x_dist = self.activation(self.ffn_dist(x, training=training))
x_dist = self.dist_activation(self.ffn_dist(x, training=training))

#x_dist = self.gaussian_noise(x_dist, training=training)

#compute the element-to-element messages / distance matrix / graph structure
bins_split, x_binned, dm, msk_binned = self.message_building_layer(x_dist, x, msk)
if self.do_lsh:
bins_split, x, dm, msk_f = self.message_building_layer(x_dist, x, msk)
else:
dm = self.message_building_layer(x_dist, msk)
msk_f = tf.expand_dims(tf.cast(msk, x.dtype), axis=-1)
bins_split = None

#run the node update with message passing
for msg in self.message_passing_layers:
x_binned = msg((x_binned, dm, msk_binned))
x = msg((x, dm, msk_f))

#x_binned = self.gaussian_noise(x_binned, training=training)
#x = self.gaussian_noise(x, training=training)

if self.dropout_layer:
x_binned = self.dropout_layer(x_binned, training=training)
x = self.dropout_layer(x, training=training)

x_enc = reverse_lsh(bins_split, x_binned)
#undo the binning according to the element-to-bin indices
if self.do_lsh:
x = reverse_lsh(bins_split, x)

return {"enc": x_enc, "dist": x_dist, "bins": bins_split, "dm": dm}
return {"enc": x, "dist": x_dist, "bins": bins_split, "dm": dm}

class PFNetDense(tf.keras.Model):
def __init__(self,
do_node_encoding=False,
hidden_dim=128,
dropout=0.0,
activation="gelu",
multi_output=False,
num_input_classes=8,
num_output_classes=3,
@@ -689,6 +755,21 @@ def __init__(self,
self.debug = debug

self.skip_connection = skip_connection

self.do_node_encoding = do_node_encoding
self.hidden_dim = hidden_dim
self.dropout = dropout
self.activation = getattr(tf.keras.activations, activation)

if self.do_node_encoding:
self.node_encoding = point_wise_feed_forward_network(
self.hidden_dim,
self.hidden_dim,
"node_encoding",
num_layers=2,
activation=self.activation,
dropout=self.dropout
)

if input_encoding == "cms":
self.enc = InputEncodingCMS(num_input_classes)
@@ -713,10 +794,13 @@ def call(self, inputs, training=False):
#encode the elements for classification (id)
enc = self.enc(X)

enc_cg = enc

encs = []
if self.skip_connection:
encs.append(enc)
enc_cg = enc
if self.do_node_encoding:
enc_cg = self.node_encoding(enc_cg, training=training)
for cg in self.cg:
enc_all = cg(enc_cg, msk, training=training)
enc_cg = enc_all["enc"]
@@ -790,17 +874,18 @@ def set_trainable_named(self, layer_names):
# msk2 = (pred_cls==icls)
# import matplotlib
# import matplotlib.pyplot as plt
# bins = np.linspace(0,6,100)

# plt.figure(figsize=(4,4))
# minval = np.min(y["energy"][msk1].numpy().flatten())
# maxval = np.max(y["energy"][msk1].numpy().flatten())
# plt.scatter(
# y["energy"][msk1&msk2].numpy().flatten(),
# y_pred["energy"][msk1&msk2].numpy().flatten(),
# marker=".", alpha=0.5
# )
# plt.xlabel("true")
# plt.ylabel("pred")
# plt.plot([0,6], [0,6])
# plt.plot([minval,maxval], [minval,maxval], color="black", ls="--", lw=1.0)
# plt.savefig("train_cls{}_{}.png".format(icls, self.step), bbox_inches="tight")
# plt.close("all")

Expand Down Expand Up @@ -832,20 +917,22 @@ def set_trainable_named(self, layer_names):
# msk2 = (pred_cls==icls)
# import matplotlib
# import matplotlib.pyplot as plt
# bins = np.linspace(0,6,100)

# plt.figure(figsize=(4,4))
# minval = np.min(y["energy"][msk1].numpy().flatten())
# maxval = np.max(y["energy"][msk1].numpy().flatten())
# plt.scatter(
# y["energy"][msk1&msk2].numpy().flatten(),
# y_pred["energy"][msk1&msk2].numpy().flatten(),
# marker=".", alpha=0.5
# )
# plt.xlabel("true")
# plt.ylabel("pred")
# plt.plot([0,6], [0,6])
# plt.plot([minval,maxval], [minval,maxval], color="black", ls="--", lw=1.0)
# plt.savefig("test_cls{}_{}.png".format(icls, self.step), bbox_inches="tight")
# plt.close("all")


# # Updates the metrics tracking the loss
# self.compiled_loss(y, y_pred, sample_weights, regularization_losses=self.losses)
# # Update the metrics.
