-
Notifications
You must be signed in to change notification settings - Fork 37
/
dist_sampling_ogb_products_quiver.py
163 lines (129 loc) · 6.02 KB
/
dist_sampling_ogb_products_quiver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Reaches around 0.7870 ± 0.0036 test accuracy.
import os
import torch
from tqdm import tqdm
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import time
####################
# Import Quiver
####################
import quiver
class SAGE(torch.nn.Module):
    """GraphSAGE network made of a stack of ``SAGEConv`` layers.

    The constructor always creates one input layer and one output layer and
    inserts ``num_layers - 2`` hidden layers between them, so ``num_layers``
    is expected to be at least 2.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        """Build the convolution stack.

        Args:
            in_channels: Size of each input node feature vector.
            hidden_channels: Output size of every layer except the last one.
            out_channels: Output size of the last layer.
            num_layers: Number of layers used by ``forward``/``inference``.
        """
        super().__init__()

        self.num_layers = num_layers

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def reset_parameters(self):
        """Re-initialize the parameters of every convolution layer."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adjs):
        """Run the sampled mini-batch through all layers.

        Args:
            x: Features of the sampled nodes; target nodes come first.
            adjs: One ``(edge_index, e_id, size)`` triple per layer.

        Returns:
            Log-softmax scores (over the last dimension) for the target nodes
            of the final layer.
        """
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            # No activation/dropout after the last layer.
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

    def inference(self, x_all, device, subgraph_loader):
        """Compute the output of the network for all nodes, layer by layer.

        Args:
            x_all: Feature store that supports ``size(0)`` and indexing by a
                node-id tensor.
            device: Device on which each batch is processed.
            subgraph_loader: Iterable yielding ``(batch_size, n_id, adj)``
                triples, where ``adj`` unpacks to ``(edge_index, e_id, size)``
                after ``adj.to(device)``.

        Returns:
            A CPU tensor holding the concatenated last-layer outputs (raw
            scores; no softmax is applied here).
        """
        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        # The context manager closes the progress bar even if a batch fails.
        with tqdm(total=x_all.size(0) * self.num_layers) as pbar:
            pbar.set_description('Evaluating')

            for i in range(self.num_layers):
                xs = []
                for batch_size, n_id, adj in subgraph_loader:
                    edge_index, _, size = adj.to(device)
                    x = x_all[n_id].to(device)
                    x_target = x[:size[1]]
                    x = self.convs[i]((x, x_target), edge_index)
                    if i != self.num_layers - 1:
                        x = F.relu(x)
                    # Keep per-batch results on the CPU between layers.
                    xs.append(x.cpu())

                    pbar.update(batch_size)

                # The outputs of this layer become the inputs of the next one.
                x_all = torch.cat(xs, dim=0)

        return x_all
def run(rank, world_size, quiver_sampler, quiver_feature, y, edge_index,
        split_idx, num_features, num_classes):
    """Per-process training entry point (one process per GPU).

    Joins the NCCL process group, trains a 3-layer ``SAGE`` model wrapped in
    ``DistributedDataParallel`` for 20 epochs on this rank's shard of the
    training nodes, and lets rank 0 evaluate on all nodes every 5 epochs.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of participating processes.
        quiver_sampler: Object whose ``sample(seeds)`` returns
            ``(n_id, batch_size, adjs)``.
        quiver_feature: Feature store indexed by node ids.
        y: Node label tensor, indexed by node id.
        edge_index: Graph connectivity used for full-neighborhood evaluation.
        split_idx: Mapping with ``'train'``, ``'valid'`` and ``'test'`` node
            index tensors.
        num_features: Input feature dimension of the model.
        num_classes: Number of output classes of the model.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)

    full_train_idx = split_idx['train']
    val_idx, test_idx = split_idx['valid'], split_idx['test']
    # Each rank trains on its own contiguous chunk of the training nodes.
    # The full index is kept separately so that the reported training
    # accuracy covers the whole split instead of only rank 0's chunk.
    train_idx = full_train_idx.split(full_train_idx.size(0) // world_size)[rank]

    train_loader = torch.utils.data.DataLoader(train_idx, batch_size=1024, pin_memory=True)

    # Only rank 0 evaluates, so only rank 0 needs the full-neighborhood loader.
    if rank == 0:
        subgraph_loader = NeighborSampler(edge_index, node_idx=None,
                                          sizes=[-1], batch_size=512,
                                          shuffle=False, num_workers=6)

    # Same seed on every rank before the model is constructed.
    torch.manual_seed(12345)
    model = SAGE(num_features, 256, num_classes, num_layers=3).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    y = y.to(rank)

    for epoch in range(1, 21):
        model.train()
        epoch_start = time.time()
        for seeds in train_loader:
            n_id, batch_size, adjs = quiver_sampler.sample(seeds)
            adjs = [adj.to(rank) for adj in adjs]

            optimizer.zero_grad()
            out = model(quiver_feature[n_id], adjs)
            # The first `batch_size` entries of `n_id` are the seed nodes.
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        dist.barrier()

        if rank == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {time.time() - epoch_start}')

        if rank == 0 and epoch % 5 == 0:  # We evaluate on a single GPU for now
            model.eval()
            with torch.no_grad():
                out = model.module.inference(quiver_feature, rank, subgraph_loader)
            res = out.argmax(dim=-1) == y.cpu()
            acc1 = int(res[full_train_idx].sum()) / full_train_idx.numel()
            acc2 = int(res[val_idx].sum()) / val_idx.numel()
            acc3 = int(res[test_idx].sum()) / test_idx.numel()
            print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}')

        # Keep the other ranks waiting while rank 0 evaluates.
        dist.barrier()

    dist.destroy_process_group()
if __name__ == '__main__':
    # Load the ogbn-products graph and its train/valid/test node split.
    dataset_root = "/data/products"
    products = PygNodePropPredDataset('ogbn-products', dataset_root)
    graph = products[0]
    node_splits = products.get_idx_split()

    # One training process is spawned per visible CUDA device.
    num_gpus = torch.cuda.device_count()
    gpu_ids = list(range(num_gpus))

    ##############################
    # Create Sampler And Feature
    ##############################
    topology = quiver.CSRTopo(graph.edge_index)
    sampler = quiver.pyg.GraphSageSampler(topology, [15, 10, 5], 0, mode="GPU")

    # Copy the node features into a freshly allocated tensor before handing
    # them to Quiver.
    cpu_features = torch.zeros(graph.x.shape)
    cpu_features[:] = graph.x

    feature_store = quiver.Feature(
        rank=0,
        device_list=gpu_ids,
        device_cache_size="2G",
        cache_policy="device_replicate",
        csr_topo=topology,
    )
    feature_store.from_cpu_tensor(cpu_features)

    print('Let\'s use', num_gpus, 'GPUs!')

    # Arguments forwarded to `run` after the process rank.
    worker_args = (
        num_gpus,
        sampler,
        feature_store,
        graph.y.squeeze(),
        graph.edge_index,
        node_splits,
        products.num_features,
        products.num_classes,
    )
    mp.spawn(run, args=worker_args, nprocs=num_gpus, join=True)