Skip to content

Commit

Permalink
Implmentget_as_torch_geometric
Browse files Browse the repository at this point in the history
  • Loading branch information
mewim committed Jan 25, 2023
1 parent f2b3dba commit 252e441
Show file tree
Hide file tree
Showing 13 changed files with 773 additions and 25 deletions.
22 changes: 13 additions & 9 deletions .github/workflows/ci-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,30 @@ jobs:
runs-on: kuzu-self-hosted-testing
steps:
- uses: actions/checkout@v3
- run: pip install --user -r tools/python_api/requirements_dev.txt

- name: ensure python dependencies
- run: |
pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\
pip install --user -r tools/python_api/requirements_dev.txt
- name: build debug
run: CC=gcc CXX=g++ make alldebug NUM_THREADS=32

- name: run test with asan
run: ctest
env:
LD_PRELOAD: '/usr/lib/x86_64-linux-gnu/libasan.so.6'
ASAN_OPTIONS: 'detect_leaks=1:log_path=/tmp/asan.log'
env:
LD_PRELOAD: "/usr/lib/x86_64-linux-gnu/libasan.so.6"
ASAN_OPTIONS: "detect_leaks=1:log_path=/tmp/asan.log"
working-directory: ./build/debug/test
continue-on-error: true

- name: cat asan log
run: cat /tmp/asan.log* || true
shell: bash

- name: clean-up asan log
run: rm -rf /tmp/asan.log*

clang-build-test:
name: clang build & test
needs: [clang-formatting-check]
Expand All @@ -65,7 +69,7 @@ jobs:

- name: test
run: CC=clang-14 CXX=clang++-14 make test NUM_THREADS=32

- name: python test
run: CC=clang-14 CXX=clang++-14 make pytest NUM_THREADS=32

Expand Down
1 change: 1 addition & 0 deletions dataset/tensor-list/copy.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
COPY tensor FROM "dataset/tensor-list/vTensor.csv" (HEADER=true);
1 change: 1 addition & 0 deletions dataset/tensor-list/schema.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));
7 changes: 7 additions & 0 deletions dataset/tensor-list/vTensor.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ID,boolTensor,doubleTensor,intTensor,oneDimInt
0,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[1,2],[3,4]],[[5,6],[7,8]]]",1
3,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[3,4],[5,6]],[[7,8],[9,10]]]",2
4,"[false,true]","[[0.4,0.8],[0.7,0.6]]","[[[5,6],[7,8]],[[9,10],[11,12]]]",
5,"[true,true]","[[0.4,0.9],[0.5,0.2]]","[[[7,8],[9,10]],[[11,12],[13,14]]]",
6,"[false,true]","[[0.2,0.4],[0.5,0.1]]","[[[9,10],[11,12]],[[13,14],[15,16]]]",5
8,"[false,true]","[[0.6,0.4],[0.6,0.1]]","[[[11,12],[13,14]],[[15,16],[17,18]]]",6
5 changes: 5 additions & 0 deletions tools/python_api/requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ pandas
networkx~=3.0.0
numpy
pyarrow~=10.0.0
torch-cluster~=1.6.0
torch-geometric~=2.2.0
torch-scatter~=2.1.0
torch-sparse~=0.6.16
torch-spline-conv~=1.2.1
2 changes: 2 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class PyConnection {

void setMaxNumThreadForExec(uint64_t numThreads);

py::str getNodePropertyNames(const string& tableName);

private:
unordered_map<string, shared_ptr<Value>> transformPythonParameters(py::list params);

Expand Down
7 changes: 6 additions & 1 deletion tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ void PyConnection::initialize(py::handle& m) {
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"));
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"));
PyDateTime_IMPORT;
}

Expand Down Expand Up @@ -37,6 +38,10 @@ void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
conn->setMaxNumThreadForExec(numThreads);
}

py::str PyConnection::getNodePropertyNames(const string& tableName) {
return conn->getNodePropertyNames(tableName);
}

unordered_map<string, shared_ptr<Value>> PyConnection::transformPythonParameters(py::list params) {
unordered_map<string, shared_ptr<Value>> result;
for (auto param : params) {
Expand Down
44 changes: 30 additions & 14 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
NODE_TYPE = "NODE"
REL_TYPE = "REL"
from .torch_geometric_result_converter import TorchGeometricResultConverter
from .types import Type


class QueryResult:
Expand Down Expand Up @@ -65,16 +65,7 @@ def get_as_networkx(self, directed=True):
nx_graph = nx.DiGraph()
else:
nx_graph = nx.Graph()
column_names = self.get_column_names()
column_types = self.get_column_data_types()
column_to_extract = {}

# Iterate over columns and extract nodes and rels, ignoring other columns
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
if column_type in [NODE_TYPE, REL_TYPE]:
column_to_extract[i] = (column_type, column_name)
column_to_extract = self._get_columns_to_extract()

self.reset_iterator()

Expand All @@ -87,12 +78,12 @@ def get_as_networkx(self, directed=True):
row = self.get_next()
for i in column_to_extract:
column_type, _ = column_to_extract[i]
if column_type == NODE_TYPE:
if column_type == Type.NODE.value:
_id = row[i]["_id"]
nodes[(_id["table"], _id["offset"])] = row[i]
table_to_label_dict[_id["table"]] = row[i]["_label"]

elif column_type == REL_TYPE:
elif column_type == Type.REL.value:
_src = row[i]["_src"]
_dst = row[i]["_dst"]
rels[(_src["table"], _src["offset"], _dst["table"],
Expand All @@ -115,3 +106,28 @@ def get_as_networkx(self, directed=True):
table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
nx_graph.add_edge(src_id, dst_id, **rel)
return nx_graph

def _get_columns_to_extract(self):
column_names = self.get_column_names()
column_types = self.get_column_data_types()
column_to_extract = {}

# Iterate over columns and extract nodes and rels, ignoring other columns
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
if column_type in [Type.NODE.value, Type.REL.value]:
column_to_extract[i] = (column_type, column_name)
return column_to_extract

def get_as_torch_geometric(self):
self.check_for_query_result_close()
# Despite we are not using torch_geometric in this file, we need to
# import it here to throw an error early if the user does not have
# torch_geometric or torch installed.

import torch
import torch_geometric

converter = TorchGeometricResultConverter(self)
return converter.get_as_torch_geometric()
Loading

0 comments on commit 252e441

Please sign in to comment.