-
Notifications
You must be signed in to change notification settings - Fork 3
/
scopes.py
46 lines (33 loc) · 1.1 KB
/
scopes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from collections import deque
from collections import defaultdict
from spn.linked.nodes import SumNode
from spn.linked.nodes import ProductNode
from spn.linked.nodes import CategoricalIndicatorNode
from spn.linked.layers import CategoricalIndicatorLayer
from spn.linked.layers import SumLayer
from spn.linked.layers import ProductLayer
from spn.linked.spn import Spn as LinkedSpn
import numpy
import itertools
def topological_layer_sort(layers):
    """Order layers so that each one appears after all of its input layers.

    The algorithm works in passes over the layers not yet placed. In each pass,
    every layer whose ``input_layers`` are no longer pending is appended to the
    result. Dependencies placed earlier in the same pass count as satisfied.

    :param layers: a sequence of (hashable) layer objects. Each must expose an
        iterable ``input_layers`` attribute listing the layers it reads from.
    :return: a new list holding each distinct layer of ``layers`` exactly once.
        Every layer follows all of its ``input_layers`` that also belong to
        ``layers``. Inputs outside ``layers`` are ignored.
    :raises RuntimeError: if a full pass places no layer, i.e. the remaining
        layers depend on each other cyclically.
    """
    # pending layer -> the layers it consumes; a repeated layer keeps its first slot
    pending = {}
    for lyr in layers:
        pending[lyr] = lyr.input_layers

    ordered = []
    while len(pending) > 0:
        progressed = False
        # walk a snapshot of the pairs, because entries are removed mid-scan
        for candidate, inputs in list(pending.items()):
            # not ready while any of its inputs is still waiting to be placed
            if any(inp in pending for inp in inputs):
                continue
            progressed = True
            del pending[candidate]
            ordered.append(candidate)

        if not progressed:
            raise RuntimeError("A cyclic dependency occurred")

    return ordered