Skip to content

Commit

Permalink
Merge pull request #1200 from kuzudb/get-as-pyg
Browse files Browse the repository at this point in the history
Implement `get_as_torch_geometric`
  • Loading branch information
mewim committed Jan 29, 2023
2 parents 0a8d48f + f766a02 commit 3981257
Show file tree
Hide file tree
Showing 13 changed files with 809 additions and 45 deletions.
74 changes: 47 additions & 27 deletions .github/workflows/ci-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@ jobs:
runs-on: kuzu-self-hosted-testing
steps:
- uses: actions/checkout@v3
- run: pip install --user -r tools/python_api/requirements_dev.txt

- name: build
- name: Ensure Python dependencies
run: |
pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\
pip install --user -r tools/python_api/requirements_dev.txt
- name: Build
run: CC=gcc CXX=g++ make release NUM_THREADS=32

- name: test
- name: Test
run: CC=gcc CXX=g++ make test NUM_THREADS=32

- name: python test
- name: Python test
run: CC=gcc CXX=g++ make pytest NUM_THREADS=32

gcc-build-test-with-asan:
Expand All @@ -32,64 +36,80 @@ jobs:
runs-on: kuzu-self-hosted-testing
steps:
- uses: actions/checkout@v3
- run: pip install --user -r tools/python_api/requirements_dev.txt

- name: build debug
- name: Ensure Python dependencies
run: |
pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\
pip install --user -r tools/python_api/requirements_dev.txt
- name: Build debug
run: CC=gcc CXX=g++ make alldebug NUM_THREADS=32
- name: run test with asan

- name: Run test with ASan
run: ctest
env:
LD_PRELOAD: '/usr/lib/x86_64-linux-gnu/libasan.so.6'
ASAN_OPTIONS: 'detect_leaks=1:log_path=/tmp/asan.log'
env:
LD_PRELOAD: "/usr/lib/x86_64-linux-gnu/libasan.so.6"
ASAN_OPTIONS: "detect_leaks=1:log_path=/tmp/asan.log"
working-directory: ./build/debug/test
continue-on-error: true
- name: cat asan log

- name: Display ASan log
run: cat /tmp/asan.log* || true
shell: bash
- name: clean-up asan log

- name: Clean up ASan log
run: rm -rf /tmp/asan.log*

clang-build-test:
name: clang build & test
needs: [clang-formatting-check]
runs-on: kuzu-self-hosted-testing
steps:
- uses: actions/checkout@v3
- run: pip3 install --user -r tools/python_api/requirements_dev.txt

- name: build
- name: Ensure Python dependencies
run: |
pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\
pip install --user -r tools/python_api/requirements_dev.txt
- name: Build
run: CC=clang-14 CXX=clang++-14 make release NUM_THREADS=32

- name: test
- name: Test
run: CC=clang-14 CXX=clang++-14 make test NUM_THREADS=32
- name: python test

- name: Python test
run: CC=clang-14 CXX=clang++-14 make pytest NUM_THREADS=32

clang-formatting-check:
name: clang-formatting-check
name: clang-format check
runs-on: kuzu-self-hosted-testing
steps:
- uses: actions/checkout@v3
with:
repository: Sarcasm/run-clang-format
path: run-clang-format
- run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r src/
- run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r test/

- name: Check source format
run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r src/

- name: Check test format
run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r test/

benchmark:
name: benchmark
needs: [gcc-build-test, clang-build-test]
runs-on: kuzu-self-hosted-benchmarking
steps:
- uses: actions/checkout@v3
- run: pip3 install --user -r tools/python_api/requirements_dev.txt

- name: Ensure Python dependencies
run: |
pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\
pip install --user -r tools/python_api/requirements_dev.txt
- name: build
- name: Build
run: make benchmark NUM_THREADS=30

- name: benchmark
- name: Benchmark
run: python3 benchmark/benchmark_runner.py --dataset ldbc-sf100 --thread 1
1 change: 1 addition & 0 deletions dataset/tensor-list/copy.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
COPY tensor FROM "dataset/tensor-list/vTensor.csv" (HEADER=true);
1 change: 1 addition & 0 deletions dataset/tensor-list/schema.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));
7 changes: 7 additions & 0 deletions dataset/tensor-list/vTensor.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ID,boolTensor,doubleTensor,intTensor,oneDimInt
0,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[1,2],[3,4]],[[5,6],[7,8]]]",1
3,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[3,4],[5,6]],[[7,8],[9,10]]]",2
4,"[false,true]","[[0.4,0.8],[0.7,0.6]]","[[[5,6],[7,8]],[[9,10],[11,12]]]",
5,"[true,true]","[[0.4,0.9],[0.5,0.2]]","[[[7,8],[9,10]],[[11,12],[13,14]]]",
6,"[false,true]","[[0.2,0.4],[0.5,0.1]]","[[[9,10],[11,12]],[[13,14],[15,16]]]",5
8,"[false,true]","[[0.6,0.4],[0.6,0.1]]","[[[11,12],[13,14]],[[15,16],[17,18]]]",6
5 changes: 5 additions & 0 deletions tools/python_api/requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ pandas
networkx~=3.0.0
numpy
pyarrow~=10.0.0
torch-cluster~=1.6.0
torch-geometric~=2.2.0
torch-scatter~=2.1.0
torch-sparse~=0.6.16
torch-spline-conv~=1.2.1
2 changes: 2 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class PyConnection {

void setMaxNumThreadForExec(uint64_t numThreads);

py::str getNodePropertyNames(const string& tableName);

private:
unordered_map<string, shared_ptr<Value>> transformPythonParameters(py::list params);

Expand Down
7 changes: 6 additions & 1 deletion tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ void PyConnection::initialize(py::handle& m) {
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"));
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"));
PyDateTime_IMPORT;
}

Expand Down Expand Up @@ -37,6 +38,10 @@ void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
conn->setMaxNumThreadForExec(numThreads);
}

// Thin pybind11-facing wrapper: delegates to the underlying Connection and
// returns its result as a Python str. Exposed to Python as
// `get_node_property_names(table_name)` (registered in PyConnection::initialize).
// NOTE(review): the exact string format (presumably a listing of the table's
// property names) is defined by Connection::getNodePropertyNames — confirm there.
py::str PyConnection::getNodePropertyNames(const string& tableName) {
return conn->getNodePropertyNames(tableName);
}

unordered_map<string, shared_ptr<Value>> PyConnection::transformPythonParameters(py::list params) {
unordered_map<string, shared_ptr<Value>> result;
for (auto param : params) {
Expand Down
48 changes: 32 additions & 16 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
NODE_TYPE = "NODE"
REL_TYPE = "REL"
from .torch_geometric_result_converter import TorchGeometricResultConverter
from .types import Type


class QueryResult:
Expand Down Expand Up @@ -65,16 +65,7 @@ def get_as_networkx(self, directed=True):
nx_graph = nx.DiGraph()
else:
nx_graph = nx.Graph()
column_names = self.get_column_names()
column_types = self.get_column_data_types()
column_to_extract = {}

# Iterate over columns and extract nodes and rels, ignoring other columns
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
if column_type in [NODE_TYPE, REL_TYPE]:
column_to_extract[i] = (column_type, column_name)
properties_to_extract = self._get_properties_to_extract()

self.reset_iterator()

Expand All @@ -85,14 +76,14 @@ def get_as_networkx(self, directed=True):
# De-duplicate nodes and rels
while self.has_next():
row = self.get_next()
for i in column_to_extract:
column_type, _ = column_to_extract[i]
if column_type == NODE_TYPE:
for i in properties_to_extract:
column_type, _ = properties_to_extract[i]
if column_type == Type.NODE.value:
_id = row[i]["_id"]
nodes[(_id["table"], _id["offset"])] = row[i]
table_to_label_dict[_id["table"]] = row[i]["_label"]

elif column_type == REL_TYPE:
elif column_type == Type.REL.value:
_src = row[i]["_src"]
_dst = row[i]["_dst"]
rels[(_src["table"], _src["offset"], _dst["table"],
Expand All @@ -115,3 +106,28 @@ def get_as_networkx(self, directed=True):
table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
nx_graph.add_edge(src_id, dst_id, **rel)
return nx_graph

def _get_properties_to_extract(self):
    """Locate the graph-valued columns of this query result.

    Returns:
        dict: column index -> (column type, column name) for every column
        whose type is NODE or REL; all other columns are omitted.
    """
    names = self.get_column_names()
    types = self.get_column_data_types()
    graph_types = (Type.NODE.value, Type.REL.value)
    # Non-graph columns carry no structural information for the
    # networkx / torch_geometric converters, so they are skipped.
    return {
        idx: (col_type, col_name)
        for idx, (col_name, col_type) in enumerate(zip(names, types))
        if col_type in graph_types
    }

def get_as_torch_geometric(self):
    """Materialize this query result as PyTorch Geometric data.

    Delegates the conversion to TorchGeometricResultConverter.

    Raises:
        ImportError: if torch or torch_geometric is not installed.
    """
    self.check_for_query_result_close()
    # torch / torch_geometric are not used directly in this file, but
    # importing them here surfaces a clear ImportError up front when the
    # optional dependencies are missing, instead of deep inside the converter.
    import torch  # noqa: F401
    import torch_geometric  # noqa: F401

    return TorchGeometricResultConverter(self).get_as_torch_geometric()
Loading

0 comments on commit 3981257

Please sign in to comment.