Learning rate schedules and Mamba layer (jpata#282)
* fix: update parameter files

* fix: better comet-ml logging

* update flatiron Ray Train submission scripts

* update sbatch script

* log overridden config to comet-ml instead of original

* fix: checkpoint loading

specify full path to checkpoint using --load-checkpoint

* feat: implement LR schedules in the PyTorch training code

* update sbatch scripts

* feat: LR schedules support checkpointing and resuming training (see the sketch after this list)

* update sbatch scripts

* update ray tune search space

* fix: dropout parameter not taking effect on torch gnn-lsh model

* make more gnn-lsh parameters configurable

* make activation function configurable

* update raytune search space

* feat: add MambaLayer

* update raytune search space

* update pyg-cms.yaml

* fix loading of checkpoint in testing with raytrain-based run
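
For the LR-schedule items above ("implement LR schedules in the PyTorch training code", "LR schedules support checkpointing and resuming training"), the following is a minimal, hedged sketch of the general pattern in plain PyTorch. It is not the repository's implementation; the helper name build_scheduler, the schedule names "onecycle"/"cosinedecay", and the checkpoint keys are illustrative assumptions.

# Minimal sketch (assumptions noted above), not the repository's actual training code.
import torch

def build_scheduler(optimizer, name, steps_per_epoch, epochs, lr):
    # Hypothetical helper: map a config string to a standard torch.optim scheduler.
    if name == "onecycle":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=epochs
        )
    if name == "cosinedecay":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps_per_epoch * epochs)
    return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)  # constant-LR fallback

model = torch.nn.Linear(16, 8)  # stand-in module; the real code trains the MLPF model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = build_scheduler(optimizer, "onecycle", steps_per_epoch=100, epochs=10, lr=1e-3)

# When checkpointing, save optimizer and scheduler state alongside the model so that
# a resumed run continues the schedule where it left off instead of restarting it.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_schedule_state_dict": scheduler.state_dict(),
    },
    "checkpoint.pth",  # illustrative path; per the commit notes, a full checkpoint path is passed on the CLI
)

# When resuming, restore all three state dicts before continuing the training loop.
ckpt = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["lr_schedule_state_dict"])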
erwulff authored Dec 11, 2023
1 parent ea1c15f commit 56fbf57
Showing 21 changed files with 558 additions and 124 deletions.
78 changes: 66 additions & 12 deletions mlpf/pyg/mlpf.py
@@ -41,6 +41,32 @@ def forward(self, x, mask):
         return x
 
 
+class MambaLayer(nn.Module):
+    def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1, d_state=16, d_conv=4, expand=2):
+        super(MambaLayer, self).__init__()
+        self.act = nn.ELU
+        from mamba_ssm import Mamba
+
+        self.mamba = Mamba(
+            d_model=embedding_dim,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+        )
+        self.norm0 = torch.nn.LayerNorm(embedding_dim)
+        self.seq = torch.nn.Sequential(
+            nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act()
+        )
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, x, mask):
+        x = self.mamba(x)
+        x = self.norm0(x + self.seq(x))
+        x = self.dropout(x)
+        x = x * (~mask.unsqueeze(-1))
+        return x
+
+
 def ffn(input_dim, output_dim, width, act, dropout):
     return nn.Sequential(
         nn.Linear(input_dim, width),
@@ -59,22 +85,45 @@ def __init__(
         embedding_dim=128,
         width=128,
         num_convs=2,
+        dropout=0.0,
+        activation="elu",
+        # gravnet specific parameters
         k=32,
         propagate_dimensions=32,
         space_dimensions=4,
-        dropout=0.4,
         conv_type="gravnet",
+        # gnn-lsh specific parameters
+        bin_size=640,
+        max_num_bins=200,
+        distance_dim=128,
+        layernorm=True,
+        num_node_messages=2,
+        ffn_dist_hidden_dim=128,
+        # self-attention specific parameters
+        num_heads=2,
+        # mamba specific parameters
+        d_state=16,
+        d_conv=4,
+        expand=2,
     ):
         super(MLPF, self).__init__()
 
         self.conv_type = conv_type
 
-        self.act = nn.ELU
+        if activation == "elu":
+            self.act = nn.ELU
+        elif activation == "relu":
+            self.act = nn.ReLU
+        elif activation == "relu6":
+            self.act = nn.ReLU6
+        elif activation == "leakyrelu":
+            self.act = nn.LeakyReLU
 
+        self.dropout = dropout
         self.input_dim = input_dim
         self.num_convs = num_convs
 
-        self.bin_size = 640
+        self.bin_size = bin_size
 
         # embedding of the inputs
         if num_convs != 0:
@@ -89,21 +138,27 @@ def __init__(
             self.conv_id = nn.ModuleList()
             self.conv_reg = nn.ModuleList()
             for i in range(num_convs):
-                self.conv_id.append(SelfAttentionLayer(embedding_dim))
-                self.conv_reg.append(SelfAttentionLayer(embedding_dim))
+                self.conv_id.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout))
+                self.conv_reg.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout))
+        elif self.conv_type == "mamba":
+            self.conv_id = nn.ModuleList()
+            self.conv_reg = nn.ModuleList()
+            for i in range(num_convs):
+                self.conv_id.append(MambaLayer(embedding_dim, num_heads, width, dropout, d_state, d_conv, expand))
+                self.conv_reg.append(MambaLayer(embedding_dim, num_heads, width, dropout, d_state, d_conv, expand))
         elif self.conv_type == "gnn_lsh":
             self.conv_id = nn.ModuleList()
             self.conv_reg = nn.ModuleList()
             for i in range(num_convs):
                 gnn_conf = {
                     "inout_dim": embedding_dim,
                     "bin_size": self.bin_size,
-                    "max_num_bins": 200,
-                    "distance_dim": 128,
-                    "layernorm": True,
-                    "num_node_messages": 2,
-                    "dropout": 0.0,
-                    "ffn_dist_hidden_dim": 128,
+                    "max_num_bins": max_num_bins,
+                    "distance_dim": distance_dim,
+                    "layernorm": layernorm,
+                    "num_node_messages": num_node_messages,
+                    "dropout": dropout,
+                    "ffn_dist_hidden_dim": ffn_dist_hidden_dim,
                 }
                 self.conv_id.append(CombinedGraphLayer(**gnn_conf))
                 self.conv_reg.append(CombinedGraphLayer(**gnn_conf))
@@ -123,7 +178,6 @@ def __init__(
         self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout)
 
     def forward(self, X_features, batch_or_mask):
-
         embeddings_id, embeddings_reg = [], []
         if self.num_convs != 0:
             embedding = self.nn0(X_features)
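The MambaLayer introduced in this file wraps the Mamba selective state-space block from the mamba_ssm package behind the same (x, mask) interface as SelfAttentionLayer, so it can be selected with conv_type="mamba". Below is a rough usage sketch, not part of the PR; the shapes and the mask convention (True marks padded elements, which are zeroed after the block) are read off the code above, and the example assumes mamba_ssm is installed and a CUDA device is available, since its fused kernels are GPU implementations.

# Rough usage sketch for the MambaLayer shown in the diff above (assumptions noted above).
import torch
from mlpf.pyg.mlpf import MambaLayer  # module path follows this file's location, mlpf/pyg/mlpf.py

device = "cuda"  # mamba_ssm's selective-scan kernels run on CUDA

batch_size, num_elems, embedding_dim = 2, 256, 128
layer = MambaLayer(embedding_dim=embedding_dim, width=128, dropout=0.1).to(device)

x = torch.randn(batch_size, num_elems, embedding_dim, device=device)
# Boolean padding mask: True marks padded elements; the layer zeroes those
# positions via x * (~mask.unsqueeze(-1)) after the Mamba block and FFN.
mask = torch.zeros(batch_size, num_elems, dtype=torch.bool, device=device)
mask[:, 200:] = True  # pretend the trailing elements of each event are padding

out = layer(x, mask)
print(out.shape)  # torch.Size([2, 256, 128]), with padded positions set to zero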
