Skip to content

Commit

Permalink
Implement get_as_torch_geometric
Browse files Browse the repository at this point in the history
  • Loading branch information
mewim committed Jan 25, 2023
1 parent f2b3dba commit 277112a
Show file tree
Hide file tree
Showing 12 changed files with 761 additions and 16 deletions.
1 change: 1 addition & 0 deletions dataset/tensor-list/copy.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
COPY tensor FROM "dataset/tensor-list/vTensor.csv" (HEADER=true);
1 change: 1 addition & 0 deletions dataset/tensor-list/schema.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));
7 changes: 7 additions & 0 deletions dataset/tensor-list/vTensor.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ID,boolTensor,doubleTensor,intTensor,oneDimInt
0,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[1,2],[3,4]],[[5,6],[7,8]]]",1
3,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[3,4],[5,6]],[[7,8],[9,10]]]",2
4,"[false,true]","[[0.4,0.8],[0.7,0.6]]","[[[5,6],[7,8]],[[9,10],[11,12]]]",
5,"[true,true]","[[0.4,0.9],[0.5,0.2]]","[[[7,8],[9,10]],[[11,12],[13,14]]]",
6,"[false,true]","[[0.2,0.4],[0.5,0.1]]","[[[9,10],[11,12]],[[13,14],[15,16]]]",5
8,"[false,true]","[[0.6,0.4],[0.6,0.1]]","[[[11,12],[13,14]],[[15,16],[17,18]]]",6
6 changes: 6 additions & 0 deletions tools/python_api/requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ pandas
networkx~=3.0.0
numpy
pyarrow~=10.0.0
torch~=1.13.1
torch-cluster~=1.6.0
torch-geometric~=2.2.0
torch-scatter~=2.1.0
torch-sparse~=0.6.16
torch-spline-conv~=1.2.1
2 changes: 2 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class PyConnection {

void setMaxNumThreadForExec(uint64_t numThreads);

py::str getNodePropertyNames(const string& tableName);

private:
unordered_map<string, shared_ptr<Value>> transformPythonParameters(py::list params);

Expand Down
7 changes: 6 additions & 1 deletion tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ void PyConnection::initialize(py::handle& m) {
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"));
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"));
PyDateTime_IMPORT;
}

Expand Down Expand Up @@ -37,6 +38,10 @@ void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
conn->setMaxNumThreadForExec(numThreads);
}

// Thin wrapper forwarding to the underlying Connection. Returns the textual
// description of the given node table's properties produced by
// conn->getNodePropertyNames; pybind11 converts the returned string to a
// Python str.
py::str PyConnection::getNodePropertyNames(const string& tableName) {
    return conn->getNodePropertyNames(tableName);
}

unordered_map<string, shared_ptr<Value>> PyConnection::transformPythonParameters(py::list params) {
unordered_map<string, shared_ptr<Value>> result;
for (auto param : params) {
Expand Down
44 changes: 30 additions & 14 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
NODE_TYPE = "NODE"
REL_TYPE = "REL"
from .torch_geometric_result_converter import TorchGeometricResultConverter
from .types import Type


class QueryResult:
Expand Down Expand Up @@ -65,16 +65,7 @@ def get_as_networkx(self, directed=True):
nx_graph = nx.DiGraph()
else:
nx_graph = nx.Graph()
column_names = self.get_column_names()
column_types = self.get_column_data_types()
column_to_extract = {}

# Iterate over columns and extract nodes and rels, ignoring other columns
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
if column_type in [NODE_TYPE, REL_TYPE]:
column_to_extract[i] = (column_type, column_name)
column_to_extract = self._get_columns_to_extract()

self.reset_iterator()

Expand All @@ -87,12 +78,12 @@ def get_as_networkx(self, directed=True):
row = self.get_next()
for i in column_to_extract:
column_type, _ = column_to_extract[i]
if column_type == NODE_TYPE:
if column_type == Type.NODE.value:
_id = row[i]["_id"]
nodes[(_id["table"], _id["offset"])] = row[i]
table_to_label_dict[_id["table"]] = row[i]["_label"]

elif column_type == REL_TYPE:
elif column_type == Type.REL.value:
_src = row[i]["_src"]
_dst = row[i]["_dst"]
rels[(_src["table"], _src["offset"], _dst["table"],
Expand All @@ -115,3 +106,28 @@ def get_as_networkx(self, directed=True):
table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
nx_graph.add_edge(src_id, dst_id, **rel)
return nx_graph

def _get_columns_to_extract(self):
    """Map result-column index -> (column type, column name) for the
    columns that hold graph objects.

    Only NODE and REL columns are kept; every other column type is
    silently skipped.
    """
    names = self.get_column_names()
    types = self.get_column_data_types()
    graph_column_types = (Type.NODE.value, Type.REL.value)
    return {
        idx: (col_type, col_name)
        for idx, (col_name, col_type) in enumerate(zip(names, types))
        if col_type in graph_column_types
    }

def get_as_torch_geometric(self):
    """Convert the query result into a torch_geometric graph object.

    Delegates the actual conversion to TorchGeometricResultConverter and
    returns whatever it produces.
    """
    self.check_for_query_result_close()
    # Import torch and torch_geometric eagerly even though this module does
    # not use them directly: a missing installation then fails right here
    # with a clear ImportError instead of deep inside the converter.
    import torch  # noqa: F401
    import torch_geometric  # noqa: F401

    return TorchGeometricResultConverter(self).get_as_torch_geometric()
237 changes: 237 additions & 0 deletions tools/python_api/src_py/torch_geometric_result_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from .types import Type
import warnings


class TorchGeometricResultConverter:
    """Converts a query result containing nodes and rels into a
    torch_geometric ``Data`` (single node label) or ``HeteroData``
    (multiple node labels) object.

    Properties whose type, shape, or nullability is unsupported by
    torch_geometric are skipped with a warning rather than aborting the
    whole conversion.
    """

    def __init__(self, query_result):
        self.query_result = query_result
        # label -> property name -> list of per-node values (scalars or tensors)
        self.nodes_dict = {}
        # src label -> dst label -> list of (src_pos, dst_pos) pairs
        self.edges_dict = {}
        # Deduplicated (src_table, src_offset, dst_table, dst_offset) tuples
        self.rels = set()
        # label -> property name -> {"type", "depth", "is_primary_key"}
        self.nodes_property_names_dict = {}
        # Internal table id -> node label
        self.table_to_label_dict = {}
        # (table, offset) -> row position inside the output tensors
        self.internal_id_to_pos_dict = {}
        # label -> position -> primary key value (maps tensors back to the DB)
        self.pos_to_primary_key_dict = {}
        self.warning_messages = set()
        # (label, property name) pairs excluded from the conversion
        self.ignored_columns = set()
        self.column_to_extract = self.query_result._get_columns_to_extract()

    def __get_node_property_names(self, table_name):
        """Return {prop_name: {"type", "depth", "is_primary_key"}} for a node
        table, parsed from the textual description returned by the C++
        binding. Results are cached per table."""
        if table_name in self.nodes_property_names_dict:
            return self.nodes_property_names_dict[table_name]

        PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)"
        LIST_SYMBOL = "[]"
        result_str = self.query_result.connection._connection.get_node_property_names(
            table_name)
        results = {}
        for i, line in enumerate(result_str.splitlines()):
            # The first line is a header; skip it.
            if i == 0:
                continue
            line = line.strip()
            if line == "":
                continue
            line_parts = line.split(" ")
            if len(line_parts) < 2:
                continue

            prop_name = line_parts[0]
            prop_type = " ".join(line_parts[1:])

            is_primary_key = PRIMARY_KEY_SYMBOL in prop_type
            prop_type = prop_type.replace(PRIMARY_KEY_SYMBOL, "")
            # Each "[]" suffix adds one level of list nesting.
            depth = prop_type.count(LIST_SYMBOL)
            prop_type = prop_type.replace(LIST_SYMBOL, "")
            results[prop_name] = {
                "type": prop_type,
                "depth": depth,
                "is_primary_key": is_primary_key
            }
        self.nodes_property_names_dict[table_name] = results
        return results

    def __populate_nodes_dict_and_deduplicate_edges(self):
        """First pass over the result: collect node properties (each node is
        processed exactly once) and the deduplicated set of edges."""
        self.query_result.reset_iterator()
        while self.query_result.has_next():
            row = self.query_result.get_next()
            for i in self.column_to_extract:
                column_type, _ = self.column_to_extract[i]
                if column_type == Type.NODE.value:
                    node = row[i]
                    label = node["_label"]
                    _id = node["_id"]
                    self.table_to_label_dict[_id["table"]] = label

                    # A node may appear in many result rows; process it once.
                    if (_id["table"], _id["offset"]) in self.internal_id_to_pos_dict:
                        continue

                    node_property_names = self.__get_node_property_names(label)
                    pos, primary_key = self.__extract_properties_from_node(
                        node, label, node_property_names)

                    # pos < 0 means no property could be extracted; the node
                    # is dropped from the output.
                    if pos >= 0:
                        self.internal_id_to_pos_dict[(
                            _id["table"], _id["offset"])] = pos
                        if label not in self.pos_to_primary_key_dict:
                            self.pos_to_primary_key_dict[label] = {}
                        self.pos_to_primary_key_dict[label][pos] = primary_key

                elif column_type == Type.REL.value:
                    _src = row[i]["_src"]
                    _dst = row[i]["_dst"]
                    self.rels.add((_src["table"], _src["offset"],
                                   _dst["table"], _dst["offset"]))

    def __extract_properties_from_node(self, node, label, node_property_names):
        """Append the supported properties of one node to ``nodes_dict``.

        Returns ``(pos, primary_key)``: ``pos`` is the node's row index
        within its label (-1 if nothing was extracted) and ``primary_key``
        is the node's primary key value (None if no primary-key column was
        encountered).
        """
        import torch
        # BUGFIX: both were previously unbound when every property was
        # skipped, raising UnboundLocalError instead of reaching the
        # caller's `pos >= 0` guard.
        pos = -1
        primary_key = None
        for prop_name in node_property_names:
            # Skip properties previously marked as ignored for this label.
            if (label, prop_name) in self.ignored_columns:
                continue

            # The primary key is reported separately, not stored as a feature.
            if node_property_names[prop_name]["is_primary_key"]:
                primary_key = node[prop_name]
                continue

            # Only numeric/boolean properties map onto torch tensors.
            if node_property_names[prop_name]["type"] not in [Type.INT64.value, Type.DOUBLE.value, Type.BOOL.value]:
                self.warning_messages.add(
                    "Property {}.{} of type {} is not supported by torch_geometric. The property is ignored."
                    .format(label, prop_name, node_property_names[prop_name]["type"]))
                self.__ignore_property(label, prop_name)
                continue
            if node[prop_name] is None:
                self.warning_messages.add(
                    "Property {}.{} has a null value. torch_geometric does not support null values. The property is ignored."
                    .format(label, prop_name))
                self.__ignore_property(label, prop_name)
                continue

            if node_property_names[prop_name]['depth'] == 0:
                # Scalar property: keep the raw Python value; it is turned
                # into a tensor column later, in one batch.
                curr_value = node[prop_name]
            else:
                # List property: convert to a tensor of the matching dtype.
                try:
                    if node_property_names[prop_name]['type'] == Type.INT64.value:
                        curr_value = torch.LongTensor(node[prop_name])
                    elif node_property_names[prop_name]['type'] == Type.DOUBLE.value:
                        curr_value = torch.FloatTensor(node[prop_name])
                    elif node_property_names[prop_name]['type'] == Type.BOOL.value:
                        curr_value = torch.BoolTensor(node[prop_name])
                except ValueError:
                    self.warning_messages.add(
                        "Property {}.{} cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored."
                        .format(label, prop_name))
                    self.__ignore_property(label, prop_name)
                    continue
                # All nodes of a label must share the same tensor shape,
                # otherwise torch.stack would fail during conversion.
                if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
                    if curr_value.shape != self.nodes_dict[label][prop_name][0].shape:
                        self.warning_messages.add(
                            "Property {}.{} has an inconsistent shape. The property is ignored."
                            .format(label, prop_name))
                        self.__ignore_property(label, prop_name)
                        continue

            # Lazily create the per-label / per-property buckets.
            if label not in self.nodes_dict:
                self.nodes_dict[label] = {}
            if prop_name not in self.nodes_dict[label]:
                self.nodes_dict[label][prop_name] = []

            self.nodes_dict[label][prop_name].append(curr_value)

            # pos is overwritten per property but is identical for all of
            # them, since every list for this label grows by one per node.
            pos = len(self.nodes_dict[label][prop_name]) - 1
        return pos, primary_key

    def __ignore_property(self, label, prop_name):
        """Mark (label, prop_name) as ignored and drop any values already
        collected for it, so partially-filled columns never reach the
        output."""
        self.ignored_columns.add((label, prop_name))
        if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
            del self.nodes_dict[label][prop_name]
            if len(self.nodes_dict[label]) == 0:
                del self.nodes_dict[label]

    def __populate_edges_dict(self):
        """Second pass: translate internal node ids into row positions and
        bucket edges by (src_label, dst_label)."""
        for r in self.rels:
            # BUGFIX: an endpoint may have been dropped (all of its
            # properties were ignored); skip the edge with a warning
            # instead of raising KeyError.
            if ((r[0], r[1]) not in self.internal_id_to_pos_dict
                    or (r[2], r[3]) not in self.internal_id_to_pos_dict):
                self.warning_messages.add(
                    "At least one edge is ignored because one of its endpoint nodes was ignored.")
                continue
            src_pos = self.internal_id_to_pos_dict[(r[0], r[1])]
            dst_pos = self.internal_id_to_pos_dict[(r[2], r[3])]
            src_label = self.table_to_label_dict[r[0]]
            dst_label = self.table_to_label_dict[r[2]]
            if src_label not in self.edges_dict:
                self.edges_dict[src_label] = {}
            if dst_label not in self.edges_dict[src_label]:
                self.edges_dict[src_label][dst_label] = []
            self.edges_dict[src_label][dst_label].append((src_pos, dst_pos))

    def __print_warnings(self):
        """Emit every accumulated warning exactly once."""
        for message in self.warning_messages:
            warnings.warn(message)

    def __convert_to_torch_geometric(self):
        """Materialize nodes_dict/edges_dict into a torch_geometric object.

        Returns ``(data, pos_to_primary_key_dict)``; both are None when no
        node could be converted.
        """
        import torch
        import torch_geometric
        if len(self.nodes_dict) == 0:
            self.warning_messages.add(
                "No nodes found or all nodes were ignored. Returning None.")
            # BUGFIX: was `return None` (a single value), which made the
            # caller's two-value unpacking raise TypeError.
            return None, None

        # A single node label maps onto homogeneous Data; multiple labels
        # require HeteroData.
        if len(self.nodes_dict) == 1:
            data = torch_geometric.data.Data()
            is_hetero = False
        else:
            data = torch_geometric.data.HeteroData()
            is_hetero = True

        # Stack each property's per-node values into one tensor column.
        for label in self.nodes_dict:
            for prop_name in self.nodes_dict[label]:
                prop_type = self.nodes_property_names_dict[label][prop_name]["type"]
                prop_depth = self.nodes_property_names_dict[label][prop_name]["depth"]
                if prop_depth == 0:
                    # Scalars were kept raw; build the tensor in one shot.
                    if prop_type == Type.INT64.value:
                        converted = torch.LongTensor(
                            self.nodes_dict[label][prop_name])
                    elif prop_type == Type.BOOL.value:
                        converted = torch.BoolTensor(
                            self.nodes_dict[label][prop_name])
                    elif prop_type == Type.DOUBLE.value:
                        converted = torch.FloatTensor(
                            self.nodes_dict[label][prop_name])
                else:
                    # List properties are already equal-shaped tensors.
                    converted = torch.stack(
                        self.nodes_dict[label][prop_name], dim=0)
                if is_hetero:
                    data[label][prop_name] = converted
                else:
                    data[prop_name] = converted

        # Edge lists become 2 x num_edges index tensors.
        for src_label in self.edges_dict:
            for dst_label in self.edges_dict[src_label]:
                edge_idx = torch.tensor(
                    self.edges_dict[src_label][dst_label],
                    dtype=torch.long).t().contiguous()
                if is_hetero:
                    data[src_label, dst_label].edge_index = edge_idx
                else:
                    data.edge_index = edge_idx

        if is_hetero:
            pos_to_primary_key_dict = self.pos_to_primary_key_dict
        else:
            # Exactly one label exists; name it explicitly instead of
            # relying on the leaked loop variable from above.
            only_label = next(iter(self.nodes_dict))
            pos_to_primary_key_dict = self.pos_to_primary_key_dict[only_label]
        return data, pos_to_primary_key_dict

    def get_as_torch_geometric(self):
        """Run the full conversion pipeline.

        Returns (data, pos_to_primary_key_dict); warnings collected along
        the way are emitted just before returning.
        """
        self.__populate_nodes_dict_and_deduplicate_edges()
        self.__populate_edges_dict()
        data, pos_to_primary_key_dict = self.__convert_to_torch_geometric()
        self.__print_warnings()
        return data, pos_to_primary_key_dict
16 changes: 16 additions & 0 deletions tools/python_api/src_py/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from enum import Enum


class Type(Enum):
    """Data types a value stored in the database may carry.

    Each member's value is the type's name as the database reports it,
    so ``Type(s)`` round-trips a type string back to its member.
    """
    # Scalar types
    BOOL = "BOOL"
    INT64 = "INT64"
    DOUBLE = "DOUBLE"
    STRING = "STRING"
    # Temporal types
    DATE = "DATE"
    TIMESTAMP = "TIMESTAMP"
    INTERVAL = "INTERVAL"
    # Container type
    LIST = "LIST"
    # Graph object types
    NODE = "NODE"
    REL = "REL"
    NODE_ID = "NODE_ID"
3 changes: 3 additions & 0 deletions tools/python_api/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def init_tiny_snb(tmp_path):
conn.execute('COPY movies FROM "../../../dataset/tinysnb/vMovies.csv"')
conn.execute('create rel table workAt (FROM person TO organisation, year INT64, MANY_ONE)')
conn.execute('COPY workAt FROM "../../../dataset/tinysnb/eWorkAt.csv"')
conn.execute('create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));')
conn.execute(
'COPY tensor FROM "../../../dataset/tensor-list/vTensor.csv" (HEADER=true)')
return output_path


Expand Down
Loading

0 comments on commit 277112a

Please sign in to comment.