-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
dna.py
102 lines (82 loc) · 3.24 KB
/
dna.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os.path as osp
import torch
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import DNAConv
# Load the 'Cora' Planetoid dataset, using ``../data/Cora`` (relative to the
# directory of this script) as the dataset root.
dataset = 'Cora'
path = osp.join(
    osp.dirname(osp.realpath(__file__)),
    '..',
    'data',
    dataset,
)
dataset = Planetoid(root=path, name=dataset)  # rebinds: name -> dataset object
data = dataset[0]

# Set the three mask attributes to None; this script works with index tensors
# (train_idx / val_idx / test_idx) rather than masks.
data.train_mask = None
data.val_mask = None
data.test_mask = None
def gen_uniform_20_20_60_split(data):
    """Attach a stratified 20/20/60 train/val/test index split to ``data``.

    The labels ``data.y`` are partitioned into five stratified, shuffled folds
    using a fixed seed (``random_state=55``). The held-out indices of the
    first fold become ``data.train_idx``, those of the second fold become
    ``data.val_idx``, and the remaining three folds are concatenated into
    ``data.test_idx``.

    Args:
        data: Object exposing its labels as ``data.y``; modified in place.

    Returns:
        The same ``data`` object, now carrying the three index tensors.
    """
    splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=55)
    # Each split yields a pair; only its second element (the fold's own
    # indices) is kept.
    folds = [
        torch.from_numpy(held_out)
        for _, held_out in splitter.split(data.y, data.y)
    ]
    train_fold, val_fold, *test_folds = folds

    data.train_idx = train_fold.to(torch.long)
    data.val_idx = val_fold.to(torch.long)
    data.test_idx = torch.cat(test_folds, dim=0).to(torch.long)
    return data
# Attach the train/val/test index tensors to the graph. The function returns
# the same object it was given, so `data` is simply rebound to it.
data = gen_uniform_20_20_60_split(data)
class Net(torch.nn.Module):
    """Linear input projection, a stack of ``DNAConv`` layers, linear head.

    Args:
        in_channels: Size of each input feature vector.
        hidden_channels: Width of the input projection and of every conv.
        out_channels: Number of outputs of the final linear layer.
        num_layers: Number of ``DNAConv`` layers to stack.
        heads: Forwarded to every ``DNAConv`` as its second argument.
        groups: Forwarded to every ``DNAConv`` as its third argument.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 heads=1, groups=1):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Creation order (lin1, convs, lin2) is kept so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.convs = torch.nn.ModuleList([
            DNAConv(hidden_channels, heads, groups, dropout=0.8)
            for _ in range(num_layers)
        ])
        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the input projection, every conv, and the head."""
        for layer in (self.lin1, *self.convs, self.lin2):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Return log-softmax scores (over dim 1) for every row of ``x``.

        Args:
            x: Input features, fed to the first linear layer.
            edge_index: Graph connectivity, handed unchanged to each conv.
        """
        width = self.hidden_channels
        h = F.dropout(F.relu(self.lin1(x)), p=0.5, training=self.training)

        # `history` stacks one slice per stage along dim 1: the projected
        # input first, then the output of each conv layer in turn. Every conv
        # receives the whole stack accumulated so far.
        history = h.view(-1, 1, width)
        for conv in self.convs:
            h = F.relu(conv(history, edge_index)).view(-1, 1, width)
            history = torch.cat([history, h], dim=1)

        # Only the most recent slice is classified.
        h = F.dropout(history[:, -1], p=0.5, training=self.training)
        return torch.log_softmax(self.lin2(h), dim=1)
# Device choice: CUDA if available, otherwise MPS if available, otherwise CPU.
# The hasattr guard keeps this working when `torch.backends` has no `mps`
# attribute.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')

# Five DNAConv layers of width 128 with 8 heads and 16 groups.
model = Net(
    dataset.num_features,
    128,
    dataset.num_classes,
    num_layers=5,
    heads=8,
    groups=16,
)
model = model.to(device)
data = data.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=0.0005,
)
def train():
    """Perform one optimisation step using the nodes in ``data.train_idx``.

    Puts the module-level ``model`` in training mode, runs it on the whole
    graph, computes the negative log-likelihood loss on the training indices
    only, back-propagates, and steps the module-level ``optimizer``.
    Returns nothing.
    """
    model.train()
    optimizer.zero_grad()

    log_probs = model(data.x, data.edge_index)
    train_idx = data.train_idx
    F.nll_loss(log_probs[train_idx], data.y[train_idx]).backward()

    optimizer.step()
@torch.no_grad()
def test():
    """Evaluate the module-level ``model`` on each index split.

    Runs the model once on the whole graph in eval mode with gradients
    disabled, then scores the arg-max prediction against ``data.y`` for each
    entry produced by ``data('train_idx', 'val_idx', 'test_idx')``.

    Returns:
        A list with one accuracy (correct / number of indices) per split, in
        the order ``data(...)`` yields them; the caller unpacks it as
        train, val, test.
    """
    model.eval()
    log_probs = model(data.x, data.edge_index)

    def accuracy(idx):
        # Fraction of nodes in `idx` whose arg-max class equals the label.
        hits = (log_probs[idx].argmax(1) == data.y[idx]).sum().item()
        return hits / idx.numel()

    return [
        accuracy(idx)
        for _, idx in data('train_idx', 'val_idx', 'test_idx')
    ]
# Train for 200 epochs. The reported test accuracy is the one measured at the
# epoch with the best validation accuracy so far, not the latest one.
best_val_acc, test_acc = 0, 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    # Strict '>' keeps the earliest epoch when validation accuracy ties.
    if val_acc > best_val_acc:
        best_val_acc, test_acc = val_acc, tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')