-
Notifications
You must be signed in to change notification settings - Fork 1
/
graph.py
166 lines (135 loc) · 5.89 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import numpy as np
from collections import defaultdict
class Vertex(object):
    """A node of the knowledge graph.

    A vertex is either a plain (entity) vertex or a predicate vertex:

    * A plain vertex (``predicate`` falsy) is identified by ``name`` alone,
      so two plain vertices with the same name are equal and hash alike.
    * A predicate vertex is identified by ``(id, _from, _to, name)``.  As
      ``id`` is unique per instance, two separately created predicate
      vertices are never equal.

    Args:
        name: Label of the vertex.
        predicate: Whether this vertex stands for a predicate.
        _from: Vertex the predicate starts from (only part of the identity
            of predicate vertices).
        _to: Vertex the predicate points to (only part of the identity of
            predicate vertices).
    """

    # Class-wide counter; every new instance takes the current value as its
    # ``id`` and then increments it.
    vertex_counter = 0

    def __init__(self, name, predicate=False, _from=None, _to=None):
        self.name = name
        self.predicate = predicate
        self._from = _from
        self._to = _to

        self.id = Vertex.vertex_counter
        Vertex.vertex_counter += 1

    def _identity(self):
        """Return the value that defines equality and hashing."""
        if self.predicate:
            return (self.id, self._from, self._to, self.name)
        return self.name

    def __eq__(self, other):
        """Compare by identity key instead of by hash value.

        Comparing the keys themselves avoids false positives on hash
        collisions and never raises for unhashable operands.
        """
        if other is None:
            return False
        if isinstance(other, Vertex):
            return (bool(self.predicate) == bool(other.predicate)
                    and self._identity() == other._identity())
        # Backward compatibility: a plain vertex hashes like its name and
        # has always compared equal to the bare name, so ``name`` and the
        # vertex address the same dictionary slot.
        if self.predicate:
            return False
        return self.name == other

    def __hash__(self):
        return hash(self._identity())

    def __lt__(self, other):
        """Order vertices by name."""
        return self.name < other.name

    def __repr__(self):
        return "Vertex(name={!r}, predicate={!r})".format(
            self.name, self.predicate)
class KnowledgeGraph(object):
    """Directed graph kept as a set of vertices plus adjacency sets.

    Edges are stored in ``_transition_matrix``, a ``defaultdict`` mapping a
    vertex to the set of vertices its outgoing edges point to.  Vertices
    only need to be hashable for the edge and walk methods; ``visualise``
    and ``weisfeiler_lehman`` additionally read ``name`` (and ``predicate``)
    from them.
    """

    def __init__(self):
        self._vertices = set()
        self._transition_matrix = defaultdict(set)
        # Both maps are (re)built by weisfeiler_lehman().
        self._label_map = {}
        self._inv_label_map = {}

    def add_vertex(self, vertex):
        """Add a vertex to the Knowledge Graph."""
        # Predicate and non-predicate vertices are stored the same way.
        self._vertices.add(vertex)

    def add_edge(self, v1, v2):
        """Add a uni-directional edge v1 -> v2."""
        self._transition_matrix[v1].add(v2)

    def remove_edge(self, v1, v2):
        """Remove the edge v1 -> v2 if present; otherwise do nothing."""
        # .get() keeps an unknown v1 from being inserted into the
        # defaultdict as a side effect of the removal attempt.
        neighbors = self._transition_matrix.get(v1)
        if neighbors is not None:
            neighbors.discard(v2)

    def get_neighbors(self, vertex):
        """Get all the neighbors of vertex (vertex -> neighbor).

        Returns the internal adjacency set itself (not a copy).  Because
        the store is a ``defaultdict``, asking for a vertex without edges
        creates and returns an empty set for it.
        """
        return self._transition_matrix[vertex]

    def visualise(self):
        """Visualise the graph using networkx & matplotlib.

        Every non-predicate vertex becomes a node named after the part of
        its ``name`` following the last ``/``; vertices sharing that suffix
        end up on the same node.  Each path vertex -> predicate -> object
        is drawn as one edge labelled with the predicate's short name.
        """
        import matplotlib.pyplot as plt
        import networkx as nx

        def short_name(vertex):
            return vertex.name.split('/')[-1]

        nx_graph = nx.DiGraph()
        entities = [v for v in self._vertices if not v.predicate]

        for v in entities:
            name = short_name(v)
            nx_graph.add_node(name, name=name, pred=v.predicate)

        for v in entities:
            v_name = short_name(v)
            # Neighbors of an entity are predicates; their neighbors are
            # the objects the edge ends in.
            for pred in self.get_neighbors(v):
                pred_name = short_name(pred)
                for obj in self.get_neighbors(pred):
                    nx_graph.add_edge(v_name, short_name(obj),
                                      name=pred_name)

        plt.figure(figsize=(10, 10))
        _pos = nx.circular_layout(nx_graph)
        nx.draw_networkx_nodes(nx_graph, pos=_pos)
        nx.draw_networkx_edges(nx_graph, pos=_pos)
        nx.draw_networkx_labels(nx_graph, pos=_pos)
        names = nx.get_edge_attributes(nx_graph, 'name')
        nx.draw_networkx_edge_labels(nx_graph, pos=_pos, edge_labels=names)
        plt.show()

    def _create_label(self, vertex, n):
        """Build the not-yet-hashed label of ``vertex`` for iteration ``n``.

        The label is the vertex's own label of iteration ``n - 1`` followed
        by the distinct, lexicographically sorted labels its neighbors had
        in iteration ``n - 1``, all joined with ``-``.
        """
        neighbor_names = [self._label_map[x][n - 1]
                          for x in self.get_neighbors(vertex)]
        suffix = '-'.join(sorted(set(map(str, neighbor_names))))
        return self._label_map[vertex][n - 1] + '-' + suffix

    def weisfeiler_lehman(self, iterations=3):
        """Perform Weisfeiler-Lehman relabeling of the nodes.

        After the call ``_label_map[vertex][n]`` holds the label of
        ``vertex`` after ``n`` iterations: the vertex name for ``n == 0``
        and ``str()`` of an md5 digest for ``n >= 1``.
        ``_inv_label_map[vertex][label]`` holds the iteration ``label``
        belongs to.

        Args:
            iterations: Number of relabeling rounds to run.
        """
        # The idea of using a hashing function is taken from:
        # https://github.com/benedekrozemberczki/graph2vec
        from hashlib import md5

        # Store the WL labels in a dictionary with a two-level key:
        # First level is the vertex identifier
        # Second level is the WL iteration
        self._label_map = defaultdict(dict)
        self._inv_label_map = defaultdict(dict)

        for v in self._vertices:
            self._label_map[v][0] = v.name
            # NOTE(review): this entry is keyed by the vertex *name* and
            # maps iteration 0 to the vertex, whereas the loop at the end
            # keys by vertex and maps label -> iteration.  Kept as is
            # because the intended layout cannot be told from this file.
            self._inv_label_map[v.name][0] = v

        for n in range(1, iterations + 1):
            for vertex in self._vertices:
                # Create multi-set label
                s_n = self._create_label(vertex, n)
                # Store its hash in our label_map; only labels of
                # iteration n - 1 are read while writing iteration n.
                self._label_map[vertex][n] = str(md5(s_n.encode()).digest())

        for vertex in self._vertices:
            for key, val in self._label_map[vertex].items():
                self._inv_label_map[vertex][val] = key

    def extract_random_walks(self, depth, root, max_walks=None):
        """Extract the walks of at most ``depth`` hops rooted in ``root``.

        Every walk is a tuple of vertices starting with ``root``.  A walk
        stops early when its last vertex has no neighbors, so walks hold
        between 1 and ``depth + 1`` vertices.

        Args:
            depth: Number of extension rounds (hops).
            root: Vertex every walk starts from.
            max_walks: If not ``None``, after each round a uniformly random
                subset of at most this many walks is kept.  A value of 0
                leaves the walks unpruned.

        Returns:
            A list of walks (tuples of vertices) in arbitrary order.
        """
        # Initialize one walk of length 1 (the root)
        walks = {(root,)}

        for _ in range(depth):
            # Extend every walk by each neighbor of its last hop; walks
            # that cannot be extended are carried over unchanged.
            extended = set()
            for walk in walks:
                neighbors = self.get_neighbors(walk[-1])
                if neighbors:
                    extended.update(walk + (neighbor,)
                                    for neighbor in neighbors)
                else:
                    extended.add(walk)
            walks = extended

            # TODO: Should we prune in every iteration?
            if max_walks is not None:
                walks_ix = np.random.choice(len(walks), replace=False,
                                            size=min(len(walks), max_walks))
                if len(walks_ix) > 0:
                    walks_list = list(walks)
                    walks = {walks_list[ix] for ix in walks_ix}

        return list(walks)
def rdflib_to_kg(rdflib_g, label_predicates=None):
    """Convert a rdflib.Graph to our KnowledgeGraph.

    Every ``(s, p, o)`` triple yielded by iterating ``rdflib_g`` becomes
    three vertices (subject, predicate, object) connected by the two edges
    ``subject -> predicate`` and ``predicate -> object``.  Vertex names are
    the ``str()`` of the triple's terms.

    Args:
        rdflib_g: Iterable of ``(s, p, o)`` triples.
        label_predicates: Predicates whose triples are left out of the
            graph.  ``None`` (the default) skips nothing.

    Returns:
        The populated KnowledgeGraph.
    """
    # A None default avoids sharing one mutable list between calls.
    if label_predicates is None:
        label_predicates = ()

    kg = KnowledgeGraph()
    for (s, p, o) in rdflib_g:
        if p in label_predicates:
            continue

        s_v, o_v = Vertex(str(s)), Vertex(str(o))
        # The predicate vertex records its endpoints, so each triple gets
        # its own predicate vertex even when predicate names repeat.
        p_v = Vertex(str(p), predicate=True, _from=s_v, _to=o_v)

        kg.add_vertex(s_v)
        kg.add_vertex(p_v)
        kg.add_vertex(o_v)
        kg.add_edge(s_v, p_v)
        kg.add_edge(p_v, o_v)
    return kg