fix f16, transformer
jpata committed Apr 29, 2022
1 parent 3823800 commit c5b0f5f
Showing 5 changed files with 137 additions and 118 deletions.
mlpf/tfmodel/model.py (105 changes: 88 additions & 17 deletions)

@@ -937,38 +937,105 @@ class KernelEncoder(tf.keras.layers.Layer):
     def __init__(self, *args, **kwargs):
         from official.nlp.modeling.layers.kernel_attention import KernelAttention
         self.key_dim = kwargs.pop("key_dim")
-        self.attn = KernelAttention(feature_transform="elu", num_heads=4, key_dim=self.key_dim, name=kwargs.get("name") + "_attention")
+        num_heads = 2
+
+        self.attn = KernelAttention(feature_transform="elu", num_heads=num_heads, key_dim=self.key_dim, name=kwargs.get("name") + "_attention")
         self.ffn = point_wise_feed_forward_network(self.key_dim, self.key_dim, kwargs.get("name") + "_ffn", num_layers=1, activation="elu")
-        self.norm0 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0")
-        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1")
+        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0")
+        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1")
         super(KernelEncoder, self).__init__(*args, **kwargs)

     def call(self, args, training=False):
         X, mask = args
-        X_attended = self.attn(X, X, training=training)
-        X = self.norm0(X + X_attended, training=training)
-        X = self.norm1(X + self.ffn(X), training=training)
-        return X
+        attn_output = self.attn(query=X, value=X, key=X, training=training)
+        out1 = self.norm1(X + attn_output)
+        ffn_output = self.ffn(out1)
+        out2 = self.norm2(out1 + ffn_output)
+        return out2
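
The rewritten call follows the standard post-norm Transformer encoder block: self-attention, residual add and LayerNorm, feed-forward, residual add and LayerNorm. A minimal self-contained sketch of the same pattern, with tf.keras.layers.MultiHeadAttention standing in for KernelAttention and a single Dense layer standing in for point_wise_feed_forward_network (both stand-ins are assumptions, not the code above):

    import tensorflow as tf

    class EncoderBlockSketch(tf.keras.layers.Layer):
        # Post-norm encoder block: attention -> add & norm -> FFN -> add & norm.
        def __init__(self, key_dim=64, num_heads=2, **kwargs):
            super().__init__(**kwargs)
            # Stand-in for official.nlp's KernelAttention (linear-time attention).
            self.attn = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
            # Stand-in for point_wise_feed_forward_network.
            self.ffn = tf.keras.layers.Dense(key_dim, activation="elu")
            self.norm1 = tf.keras.layers.LayerNormalization(axis=-1)
            self.norm2 = tf.keras.layers.LayerNormalization(axis=-1)

        def call(self, X, training=False):
            attn_output = self.attn(query=X, value=X, key=X, training=training)
            out1 = self.norm1(X + attn_output)        # first residual connection
            out2 = self.norm2(out1 + self.ffn(out1))  # second residual connection
            return out2

    # The (batch, elements, features) shape is preserved end to end.
    print(EncoderBlockSketch()(tf.random.normal((2, 100, 64))).shape)  # (2, 100, 64)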

+class KernelDecoder(tf.keras.layers.Layer):
+    def __init__(self, *args, **kwargs):
+        from official.nlp.modeling.layers.kernel_attention import KernelAttention
+        self.key_dim = kwargs.pop("key_dim")
+        num_heads = 2
+
+        self.attn1 = KernelAttention(feature_transform="elu", num_heads=num_heads, key_dim=self.key_dim, name=kwargs.get("name") + "_attention1")
+        self.attn2 = KernelAttention(feature_transform="elu", num_heads=num_heads, key_dim=self.key_dim, name=kwargs.get("name") + "_attention2")
+
+        self.ffn = point_wise_feed_forward_network(self.key_dim, self.key_dim, kwargs.get("name") + "_ffn", num_layers=1, activation="elu")
+
+        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln0")
+        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln1")
+        self.norm3 = tf.keras.layers.LayerNormalization(axis=-1, name=kwargs.get("name") + "_ln2")
+        super(KernelDecoder, self).__init__(*args, **kwargs)
+
+    def call(self, args, training=False):
+        X, enc_output, mask = args
+
+        attn1 = self.attn1(query=X, value=X, key=X, training=training)
+        out1 = self.norm1(attn1 + X, training=training)
+
+        attn2 = self.attn2(query=enc_output, value=enc_output, key=out1, training=training)
+        out2 = self.norm2(attn2 + out1)
+
+        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
+        out3 = self.norm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
+
+        return out3
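
Note that the cross-attention above passes query=enc_output and key=out1, whereas in the textbook decoder block the query comes from the decoder stream and key/value come from the encoder output. For comparison only, a sketch of the canonical wiring (again with MultiHeadAttention as a stand-in for KernelAttention):

    import tensorflow as tf

    class DecoderBlockSketch(tf.keras.layers.Layer):
        # Canonical post-norm decoder block: self-attention, then cross-attention
        # where the query comes from the decoder stream and key/value from the
        # encoder output, then a feed-forward sublayer.
        def __init__(self, key_dim=64, num_heads=2, **kwargs):
            super().__init__(**kwargs)
            self.attn1 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
            self.attn2 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
            self.ffn = tf.keras.layers.Dense(key_dim, activation="elu")
            self.norm1 = tf.keras.layers.LayerNormalization(axis=-1)
            self.norm2 = tf.keras.layers.LayerNormalization(axis=-1)
            self.norm3 = tf.keras.layers.LayerNormalization(axis=-1)

        def call(self, X, enc_output, training=False):
            out1 = self.norm1(X + self.attn1(query=X, value=X, key=X, training=training))
            out2 = self.norm2(out1 + self.attn2(query=out1, value=enc_output, key=enc_output, training=training))
            return self.norm3(out2 + self.ffn(out2))

    x = tf.random.normal((2, 100, 64))
    enc = tf.random.normal((2, 100, 64))
    print(DecoderBlockSketch()(x, enc).shape)  # (2, 100, 64)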

+class Transformer(tf.keras.layers.Layer):
+    def __init__(self, *args, **kwargs):
+        self.encoders = []
+
+        key_dim = kwargs.pop("key_dim")
+        num_layers = kwargs.pop("num_layers")
+
+        for i in range(num_layers):
+            self.encoders.append(KernelEncoder(key_dim=key_dim, name="{}-enc{}".format(kwargs.get("name"), i)))
+
+        self.decoders = []
+        for i in range(num_layers):
+            self.decoders.append(KernelDecoder(key_dim=key_dim, name="{}-dec{}".format(kwargs.get("name"), i)))
+        super(Transformer, self).__init__(*args, **kwargs)
+
+    def call(self, inputs, training=False):
+        X, msk_input = inputs
+
+        for enc in self.encoders:
+            X = enc([X, msk_input], training=training)*msk_input
+
+        X_dec = X
+        for dec in self.decoders:
+            X_dec = dec([X_dec, X, msk_input], training=training)*msk_input
+
+        return X_dec
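
The elementwise multiplication by msk_input after every layer re-zeroes the padded elements, since the residual adds and LayerNormalization would otherwise give them nonzero values. A small demonstration of the broadcast, with the (batch, elements, 1) mask shape taken from the model's call method below:

    import tensorflow as tf

    # msk_input has shape (batch, elements, 1) and broadcasts over the feature
    # axis, so every padded element is zeroed after each encoder/decoder layer.
    X = tf.random.normal((2, 5, 64))
    msk_input = tf.constant([[[1.0]] * 3 + [[0.0]] * 2] * 2)  # last 2 elements padded
    X = X * msk_input
    print(tf.reduce_max(tf.abs(X[:, 3:])).numpy())  # 0.0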


 class PFNetTransformer(tf.keras.Model):
     def __init__(self,
                  num_input_classes=8,
                  num_output_classes=3,
                  input_encoding="cms",
-                 output_decoding={}):
+                 schema="cms",
+                 output_decoding={},
+                 multi_output=True):
         super(PFNetTransformer, self).__init__()

+        self.multi_output = multi_output
+
         if input_encoding == "cms":
             self.enc = InputEncodingCMS(num_input_classes)
         elif input_encoding == "default":
             self.enc = InputEncoding(num_input_classes)

-        key_dim = 128
+        key_dim = 64
         self.ffn = point_wise_feed_forward_network(key_dim, key_dim, "ffn", num_layers=1, activation="elu")

-        self.encoders = []
-        for i in range(2):
-            self.encoders.append(KernelEncoder(key_dim=key_dim, name="enc{}".format(i)))
+        self.tf1 = Transformer(key_dim=key_dim, num_layers=2, name="tf1")
+        self.tf2 = Transformer(key_dim=key_dim, num_layers=2, name="tf2")

+        output_decoding["schema"] = schema
         output_decoding["num_output_classes"] = num_output_classes
         self.output_dec = OutputDecoding(**output_decoding)

     def call(self, inputs, training=False):

@@ -980,10 +1047,14 @@ def call(self, inputs, training=False):
         msk_input = tf.expand_dims(tf.cast(msk, tf.float32), -1)

         X_enc = self.enc(X)

         X_enc = self.ffn(X_enc)
-        for enc in self.encoders:
-            X_enc = enc([X_enc, msk_input], training=training)*msk_input
-        ret = self.output_dec([X, X_enc, X_enc, msk_input], training=training)
-
-        return ret
+
+        X_enc_1 = self.tf1([X_enc, msk_input], training=training)
+        X_enc_2 = self.tf2([X_enc, msk_input], training=training)
+
+        ret = self.output_dec([X, X_enc_1, X_enc_2, msk_input], training=training)
+
+        if self.multi_output:
+            return ret
+        else:
+            return tf.concat([ret["cls"], ret["charge"], ret["pt"], ret["eta"], ret["sin_phi"], ret["cos_phi"], ret["energy"]], axis=-1)
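
With multi_output disabled, the per-particle output dictionary is flattened into a single tensor along the feature axis, which is convenient for export formats that expect one output. A shape-level sketch (the per-key widths are assumptions; only cls spans the class dimension):

    import tensorflow as tf

    batch, elems, num_classes = 2, 100, 3
    # Assumed widths: one column per class for cls, one column per scalar target.
    ret = {k: tf.zeros((batch, elems, 1)) for k in
           ["charge", "pt", "eta", "sin_phi", "cos_phi", "energy"]}
    ret["cls"] = tf.zeros((batch, elems, num_classes))
    flat = tf.concat([ret["cls"], ret["charge"], ret["pt"], ret["eta"],
                      ret["sin_phi"], ret["cos_phi"], ret["energy"]], axis=-1)
    print(flat.shape)  # (2, 100, 9): num_classes + 6 feature columns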
mlpf/tfmodel/model_setup.py (32 changes: 7 additions & 25 deletions)

@@ -207,7 +207,6 @@ def plot_event_visualization(self, epoch, outpath, ypred, ypred_id, msk, ievent=

         #Plot the target particles
         plt.axes(ax2)
-
         msk = self.ytrue_id[ievent] != 0
         eta = self.ytrue["eta"][ievent][msk]
         sphi = self.ytrue["sin_phi"][ievent][msk]
@@ -266,47 +265,28 @@ def plot_corr(self, epoch, outpath, ypred, ypred_id, icls, reg_variable):
         vals_pred = ypred[reg_variable][sel].flatten()
         vals_true = self.ytrue[reg_variable][sel].flatten()

-        loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
-        loss_vals = loss(np.expand_dims(vals_true, -1), np.expand_dims(vals_pred, axis=-1)).numpy()
-
         #save scatterplot of raw values
-        plt.figure()
+        plt.figure(figsize=(6,5))
+        bins = self.reg_bins[reg_variable]
+        if bins is None:
+            bins = 100
-        plt.hist2d(vals_true, vals_pred, bins=100, cmap="Blues")
+        plt.hist2d(vals_true, vals_pred, bins=(bins, bins), cmap="Blues", norm=matplotlib.colors.LogNorm(vmin=1))
         plt.colorbar()

         if len(vals_true) > 0:
             minval = np.min(vals_true)
             maxval = np.max(vals_true)
             if not (math.isnan(minval) or math.isnan(maxval) or math.isinf(minval) or math.isinf(maxval)):
                 plt.plot([minval, maxval], [minval, maxval], color="black", ls="--", lw=0.5)
         plt.xlabel("true")
         plt.ylabel("predicted")
-        plt.title("{}, particle weighted, L={:.4f}".format(reg_variable, np.sum(loss_vals)))
+        plt.title(reg_variable)
         image_path = str(outpath / "{}_cls{}_corr.png".format(reg_variable, icls))
         plt.savefig(image_path, bbox_inches="tight")
         if self.comet_experiment:
             self.comet_experiment.log_image(image_path, step=epoch)
         plt.close("all")

-        #save loss-weighted correlation histogram
-        plt.figure()
-        plt.hist2d(vals_true, vals_pred, bins=(bins, bins), weights=loss_vals, cmap="Blues")
-        plt.colorbar()
-        if len(vals_true) > 0:
-            minval = np.min(vals_true)
-            maxval = np.max(vals_true)
-            if not (math.isnan(minval) or math.isnan(maxval) or math.isinf(minval) or math.isinf(maxval)):
-                plt.plot([minval, maxval], [minval, maxval], color="black", ls="--", lw=0.5)
-        plt.xlabel("true")
-        plt.ylabel("predicted")
-        plt.title("{}, loss weighted, L={:.4f}".format(reg_variable, np.sum(loss_vals)))
-        image_path = str(outpath / "{}_cls{}_corr_weighted.png".format(reg_variable, icls))
-        plt.savefig(image_path, bbox_inches="tight")
-        if self.comet_experiment:
-            self.comet_experiment.log_image(image_path, step=epoch)
-
         #Also plot the residuals, as we have the true and predicted values already available here
         plt.figure()
         residual = vals_true - vals_pred
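
The scatter histogram now uses a logarithmic color scale, so sparsely populated off-diagonal bins stay visible next to the dense diagonal. A standalone sketch of the new plotting call on synthetic data (matplotlib is assumed to be imported at module level, as the norm= argument requires):

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    vals_true = np.random.exponential(10.0, 10000)
    vals_pred = vals_true * np.random.normal(1.0, 0.1, 10000)

    plt.figure(figsize=(6, 5))
    # LogNorm(vmin=1) spans decades of bin counts; empty bins stay white
    # instead of being pulled to the lowest color of a linear scale.
    plt.hist2d(vals_true, vals_pred, bins=(100, 100), cmap="Blues",
               norm=matplotlib.colors.LogNorm(vmin=1))
    plt.colorbar()
    plt.xlabel("true")
    plt.ylabel("predicted")
    plt.savefig("corr_sketch.png", bbox_inches="tight")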
@@ -604,8 +584,10 @@ def make_transformer(config, dtype):
             kwargs[par] = config['parameters'][par]

     model = PFNetTransformer(
+        multi_output=config["setup"]["multi_output"],
         num_input_classes=config["dataset"]["num_input_classes"],
         num_output_classes=config["dataset"]["num_output_classes"],
+        schema=config["dataset"]["schema"],
         **kwargs
     )
     return model
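
The two new constructor arguments are read from the setup and dataset sections of the config. A minimal fragment that would satisfy these lookups (placeholder values, not the full cms-gen.yaml):

    # Minimal config covering the lookups in make_transformer (placeholders).
    config = {
        "setup": {"multi_output": True},
        "dataset": {
            "num_input_classes": 8,
            "num_output_classes": 3,
            "schema": "cms",
        },
        "parameters": {},  # optional PFNetTransformer kwargs would go here
    }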
mlpf/tfmodel/utils.py (1 change: 1 addition & 0 deletions)

@@ -392,6 +392,7 @@ def load_and_interleave(dataset_names, config, num_gpus, split, batch_size):
     steps = []
     for ds_name in dataset_names:
         ds, _ = get_heptfds_dataset(ds_name, config, num_gpus, split)
+        #ds = ds.take(500)
         num_steps = ds.cardinality().numpy()
         assert(num_steps > 0)
         print("Loaded {}:{} with {} steps".format(ds_name, split, num_steps))
parameters/cms-gen.yaml (35 changes: 17 additions & 18 deletions)

@@ -23,15 +23,15 @@ dataset:
   padded_num_elem_size: 12000
   #(pt, eta, sin phi, cos phi, E)
   num_momentum_outputs: 5
-  classification_loss_coef: 100.0
-  charge_loss_coef: 0.01
-  pt_loss_coef: 0.0001
-  eta_loss_coef: 100.0
-  sin_phi_loss_coef: 10.0
-  cos_phi_loss_coef: 10.0
-  energy_loss_coef: 0.0001
-  sum_energy_loss_coef: 0.00000001
-  sum_pt_loss_coef: 0.00000001
+  classification_loss_coef: 1.0
+  charge_loss_coef: 1.0
+  pt_loss_coef: 1.0
+  eta_loss_coef: 1.0
+  sin_phi_loss_coef: 1.0
+  cos_phi_loss_coef: 1.0
+  energy_loss_coef: 1.0
+  sum_energy_loss_coef: 0.0
+  sum_pt_loss_coef: 0.0
   energy_loss:
     type: Huber
   pt_loss:
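
These coefficients scale the per-task terms in the combined loss, so resetting them all to 1.0 (and the event-level sum terms to 0.0) makes the total an unweighted sum of the per-particle losses. Schematically (a sketch of the weighting only, not the repository's exact loss code):

    # Schematic combined loss under the coefficients above (a sketch only).
    def total_loss(per_task_losses, coefs):
        return sum(coefs[name] * value for name, value in per_task_losses.items())

    coefs = {"classification": 1.0, "charge": 1.0, "pt": 1.0, "eta": 1.0,
             "sin_phi": 1.0, "cos_phi": 1.0, "energy": 1.0,
             "sum_energy": 0.0, "sum_pt": 0.0}  # event-level sum terms disabled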
@@ -95,7 +95,7 @@ parameters:
     bin_size: 100
     max_num_bins: 200
     distance_dim: 64
-    layernorm: no
+    layernorm: yes
     dropout: 0.0
     dist_activation: elu
     ffn_dist_num_layers: 2
@@ -111,8 +111,8 @@
       activation: elu
       normalize_degrees: yes
     activation: elu
-  num_graph_layers_common: 2
-  num_graph_layers_energy: 2
+  num_graph_layers_common: 3
+  num_graph_layers_energy: 3
   output_decoding:
     activation: elu
     regression_use_classification: yes
@@ -129,23 +129,22 @@
     phi_dim_decrease: yes
     energy_dim_decrease: yes

-    id_hidden_dim: 512
+    id_hidden_dim: 256
     charge_hidden_dim: 256
     pt_hidden_dim: 256
     eta_hidden_dim: 256
     phi_hidden_dim: 256
     energy_hidden_dim: 256

-    id_num_layers: 3
+    id_num_layers: 2
     charge_num_layers: 2
     pt_num_layers: 2
     eta_num_layers: 2
     phi_num_layers: 2
-    energy_num_layers: 3
+    energy_num_layers: 2
     layernorm: yes
-    mask_reg_cls0: no
-
-    energy_multimodal: no
+    mask_reg_cls0: yes
+    energy_multimodal: yes

   skip_connection: yes
   debug: no
(diff for the fifth changed file not loaded)