-
Notifications
You must be signed in to change notification settings - Fork 11
/
model.py
76 lines (52 loc) · 2.16 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
from torch.nn import init
import math
class MODEL(nn.Module):
    """Classification head over node embeddings produced by an
    inter-relation aggregator.

    Two scoring branches share the aggregator output:
      * a two-layer head (``weight_model`` -> LeakyReLU -> ``weight_model2``)
        over the full concatenated embedding of size
        ``(2**(K+1) - 1) * embed_dim``;
      * a one-layer head (``weight_mlp``) over only the first ``embed_dim``
        columns of that embedding.
    Training uses cross-entropy on the two-layer head with logit adjustment
    by ``log(prior)``.
    """

    def __init__(self, K, num_classes, embed_dim, agg, prior):
        """Initialize the model.

        :param K: number of convolution layers; the aggregator output is
            expected to have ``(2**(K+1) - 1) * embed_dim`` columns.
        :param num_classes: number of classes (2 in the paper).
        :param embed_dim: output dimension of the MLP layer.
        :param agg: inter-relation aggregator that outputs the final
            embedding; called as ``agg(nodes, train_flag)``.
        :param prior: class prior probabilities; ``log(prior)`` is added to
            the logits inside :meth:`loss` (logit adjustment).
        """
        super().__init__()
        self.agg = agg
        self.K = K  # number of aggregation layers
        self.prior = prior
        self.xent = nn.CrossEntropyLoss()
        self.embed_dim = embed_dim
        self.fun = nn.LeakyReLU(0.3)
        # One-layer head over the first embed_dim features only.
        self.weight_mlp = nn.Parameter(torch.FloatTensor(self.embed_dim, num_classes))
        # Two-layer head over the full concatenated embedding:
        # (2**(K+1) - 1) * embed_dim -> 64 -> num_classes.
        # (2 ** (K + 1) - 1 is exact integer arithmetic; the original
        # int(math.pow(...)) round-tripped through float.)
        self.weight_model = nn.Parameter(
            torch.FloatTensor((2 ** (K + 1) - 1) * self.embed_dim, 64))
        self.weight_model2 = nn.Parameter(torch.FloatTensor(64, num_classes))
        init.xavier_uniform_(self.weight_mlp)
        init.xavier_uniform_(self.weight_model)
        init.xavier_uniform_(self.weight_model2)

    def forward(self, nodes, train_flag=True):
        """Return ``(scores_model, scores_mlp)`` logits for ``nodes``.

        Both tensors have shape ``(len(nodes), num_classes)``.
        """
        embedding = self.agg(nodes, train_flag)
        # Two-layer branch on the full embedding.
        scores_model = embedding.mm(self.weight_model)
        scores_model = self.fun(scores_model)
        scores_model = scores_model.mm(self.weight_model2)
        # One-layer branch on the first embed_dim columns of the embedding.
        scores_mlp = embedding[:, 0:self.embed_dim].mm(self.weight_mlp)
        scores_mlp = self.fun(scores_mlp)
        return scores_model, scores_mlp

    def to_prob(self, nodes, train_flag=False):
        """Map the two-layer head's logits to per-class scores.

        NOTE(review): an element-wise sigmoid (not softmax) is applied to
        multi-class logits, so rows do not sum to 1 — kept as-is to
        preserve the original behavior.
        """
        scores_model, _ = self.forward(nodes, train_flag)
        return torch.sigmoid(scores_model)

    def loss(self, nodes, labels, train_flag=True):
        """Cross-entropy loss of the two-layer head for ``nodes``.

        Logits are shifted by ``log(prior)`` (logit adjustment) before the
        cross-entropy. The MLP branch's loss is intentionally disabled
        (its ``lambda_1`` weight was removed), so ``scores_mlp`` is unused.
        """
        scores_model, _ = self.forward(nodes, train_flag)
        scores_model = scores_model + torch.log(self.prior)
        loss_model = self.xent(scores_model, labels.squeeze())
        return loss_model