diff --git a/.github/workflows/ci-workflow.yml b/.github/workflows/ci-workflow.yml index cb1263079f0..ca575e06cb9 100644 --- a/.github/workflows/ci-workflow.yml +++ b/.github/workflows/ci-workflow.yml @@ -15,15 +15,19 @@ jobs: runs-on: kuzu-self-hosted-testing steps: - uses: actions/checkout@v3 - - run: pip install --user -r tools/python_api/requirements_dev.txt - - name: build + - name: Ensure Python dependencies + run: | + pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\ + pip install --user -r tools/python_api/requirements_dev.txt + + - name: Build run: CC=gcc CXX=g++ make release NUM_THREADS=32 - - name: test + - name: Test run: CC=gcc CXX=g++ make test NUM_THREADS=32 - - name: python test + - name: Python test run: CC=gcc CXX=g++ make pytest NUM_THREADS=32 gcc-build-test-with-asan: @@ -32,53 +36,65 @@ jobs: runs-on: kuzu-self-hosted-testing steps: - uses: actions/checkout@v3 - - run: pip install --user -r tools/python_api/requirements_dev.txt - - name: build debug + - name: Ensure Python dependencies + run: | + pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\ + pip install --user -r tools/python_api/requirements_dev.txt + + - name: Build debug run: CC=gcc CXX=g++ make alldebug NUM_THREADS=32 - - - name: run test with asan + + - name: Run test with ASan run: ctest - env: - LD_PRELOAD: '/usr/lib/x86_64-linux-gnu/libasan.so.6' - ASAN_OPTIONS: 'detect_leaks=1:log_path=/tmp/asan.log' + env: + LD_PRELOAD: "/usr/lib/x86_64-linux-gnu/libasan.so.6" + ASAN_OPTIONS: "detect_leaks=1:log_path=/tmp/asan.log" working-directory: ./build/debug/test continue-on-error: true - - - name: cat asan log + + - name: Display ASan log run: cat /tmp/asan.log* || true shell: bash - - - name: clean-up asan log + + - name: Clean up ASan log run: rm -rf /tmp/asan.log* - + clang-build-test: name: clang build & test needs: [clang-formatting-check] runs-on: kuzu-self-hosted-testing steps: - uses: actions/checkout@v3 - - run: pip3 install --user -r tools/python_api/requirements_dev.txt - - name: build + - name: Ensure python dependencies + run: | + pip install torch~=1.13 --extra-index-url https://download.pytorch.org/whl/cpu &&\ + pip install --user -r tools/python_api/requirements_dev.txt + + - name: Build run: CC=clang-14 CXX=clang++-14 make release NUM_THREADS=32 - - name: test + - name: Test run: CC=clang-14 CXX=clang++-14 make test NUM_THREADS=32 - - - name: python test + + - name: Python test run: CC=clang-14 CXX=clang++-14 make pytest NUM_THREADS=32 clang-formatting-check: - name: clang-formatting-check + name: clang-format check runs-on: kuzu-self-hosted-testing steps: - uses: actions/checkout@v3 with: repository: Sarcasm/run-clang-format path: run-clang-format - - run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r src/ - - run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r test/ + + - name: Check source format + run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r src/ + + - name: Check test format + run: python3 run-clang-format/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r test/ benchmark: name: benchmark @@ -86,10 +102,14 @@ jobs: runs-on: kuzu-self-hosted-benchmarking steps: - uses: actions/checkout@v3 - - run: pip3 install --user -r tools/python_api/requirements_dev.txt + + - name: Ensure Python dependencies + run: | + pip install torch~=1.13 
--extra-index-url https://download.pytorch.org/whl/cpu &&\ + pip install --user -r tools/python_api/requirements_dev.txt - - name: build + - name: Build run: make benchmark NUM_THREADS=30 - - name: benchmark + - name: Benchmark run: python3 benchmark/benchmark_runner.py --dataset ldbc-sf100 --thread 1 diff --git a/dataset/tensor-list/copy.cypher b/dataset/tensor-list/copy.cypher new file mode 100644 index 00000000000..0d5aeb6bd5b --- /dev/null +++ b/dataset/tensor-list/copy.cypher @@ -0,0 +1 @@ +COPY tensor FROM "dataset/tensor-list/vTensor.csv" (HEADER=true); diff --git a/dataset/tensor-list/schema.cypher b/dataset/tensor-list/schema.cypher new file mode 100644 index 00000000000..2039c6e9eea --- /dev/null +++ b/dataset/tensor-list/schema.cypher @@ -0,0 +1 @@ +create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID)); diff --git a/dataset/tensor-list/vTensor.csv b/dataset/tensor-list/vTensor.csv new file mode 100644 index 00000000000..809ea7530f5 --- /dev/null +++ b/dataset/tensor-list/vTensor.csv @@ -0,0 +1,7 @@ +ID,boolTensor,doubleTensor,intTensor,oneDimInt +0,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[1,2],[3,4]],[[5,6],[7,8]]]",1 +3,"[true,false]","[[0.1,0.2],[0.3,0.4]]","[[[3,4],[5,6]],[[7,8],[9,10]]]",2 +4,"[false,true]","[[0.4,0.8],[0.7,0.6]]","[[[5,6],[7,8]],[[9,10],[11,12]]]", +5,"[true,true]","[[0.4,0.9],[0.5,0.2]]","[[[7,8],[9,10]],[[11,12],[13,14]]]", +6,"[false,true]","[[0.2,0.4],[0.5,0.1]]","[[[9,10],[11,12]],[[13,14],[15,16]]]",5 +8,"[false,true]","[[0.6,0.4],[0.6,0.1]]","[[[11,12],[13,14]],[[15,16],[17,18]]]",6 diff --git a/tools/python_api/requirements_dev.txt b/tools/python_api/requirements_dev.txt index e297f226778..b1ee48444fd 100644 --- a/tools/python_api/requirements_dev.txt +++ b/tools/python_api/requirements_dev.txt @@ -4,3 +4,8 @@ pandas networkx~=3.0.0 numpy pyarrow~=10.0.0 +torch-cluster~=1.6.0 +torch-geometric~=2.2.0 +torch-scatter~=2.1.0 +torch-sparse~=0.6.16 +torch-spline-conv~=1.2.1 \ No newline at end of file diff --git a/tools/python_api/src_cpp/include/py_connection.h b/tools/python_api/src_cpp/include/py_connection.h index fd343fea485..90d15fe453b 100644 --- a/tools/python_api/src_cpp/include/py_connection.h +++ b/tools/python_api/src_cpp/include/py_connection.h @@ -16,6 +16,8 @@ class PyConnection { void setMaxNumThreadForExec(uint64_t numThreads); + py::str getNodePropertyNames(const string& tableName); + private: unordered_map> transformPythonParameters(py::list params); diff --git a/tools/python_api/src_cpp/py_connection.cpp b/tools/python_api/src_cpp/py_connection.cpp index 6ac1f535b6e..5e2e525d7a7 100644 --- a/tools/python_api/src_cpp/py_connection.cpp +++ b/tools/python_api/src_cpp/py_connection.cpp @@ -8,7 +8,8 @@ void PyConnection::initialize(py::handle& m) { .def( "execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list()) .def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec, - py::arg("num_threads")); + py::arg("num_threads")) + .def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name")); PyDateTime_IMPORT; } @@ -37,6 +38,10 @@ void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) { conn->setMaxNumThreadForExec(numThreads); } +py::str PyConnection::getNodePropertyNames(const string& tableName) { + return conn->getNodePropertyNames(tableName); +} + unordered_map> PyConnection::transformPythonParameters(py::list params) { unordered_map> result; for (auto param : params) { diff 
--git a/tools/python_api/src_py/query_result.py b/tools/python_api/src_py/query_result.py index ca2c6208941..f068f2d3d13 100644 --- a/tools/python_api/src_py/query_result.py +++ b/tools/python_api/src_py/query_result.py @@ -1,5 +1,5 @@ -NODE_TYPE = "NODE" -REL_TYPE = "REL" +from .torch_geometric_result_converter import TorchGeometricResultConverter +from .types import Type class QueryResult: @@ -65,16 +65,7 @@ def get_as_networkx(self, directed=True): nx_graph = nx.DiGraph() else: nx_graph = nx.Graph() - column_names = self.get_column_names() - column_types = self.get_column_data_types() - column_to_extract = {} - - # Iterate over columns and extract nodes and rels, ignoring other columns - for i in range(len(column_names)): - column_name = column_names[i] - column_type = column_types[i] - if column_type in [NODE_TYPE, REL_TYPE]: - column_to_extract[i] = (column_type, column_name) + column_to_extract = self._get_columns_to_extract() self.reset_iterator() @@ -87,12 +78,12 @@ def get_as_networkx(self, directed=True): row = self.get_next() for i in column_to_extract: column_type, _ = column_to_extract[i] - if column_type == NODE_TYPE: + if column_type == Type.NODE.value: _id = row[i]["_id"] nodes[(_id["table"], _id["offset"])] = row[i] table_to_label_dict[_id["table"]] = row[i]["_label"] - elif column_type == REL_TYPE: + elif column_type == Type.REL.value: _src = row[i]["_src"] _dst = row[i]["_dst"] rels[(_src["table"], _src["offset"], _dst["table"], @@ -115,3 +106,28 @@ def get_as_networkx(self, directed=True): table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"]) nx_graph.add_edge(src_id, dst_id, **rel) return nx_graph + + def _get_columns_to_extract(self): + column_names = self.get_column_names() + column_types = self.get_column_data_types() + column_to_extract = {} + + # Iterate over columns and extract nodes and rels, ignoring other columns + for i in range(len(column_names)): + column_name = column_names[i] + column_type = column_types[i] + if column_type in [Type.NODE.value, Type.REL.value]: + column_to_extract[i] = (column_type, column_name) + return column_to_extract + + def get_as_torch_geometric(self): + self.check_for_query_result_close() + # Despite we are not using torch_geometric in this file, we need to + # import it here to throw an error early if the user does not have + # torch_geometric or torch installed. 
+ + import torch + import torch_geometric + + converter = TorchGeometricResultConverter(self) + return converter.get_as_torch_geometric() diff --git a/tools/python_api/src_py/torch_geometric_result_converter.py b/tools/python_api/src_py/torch_geometric_result_converter.py new file mode 100644 index 00000000000..99874c15061 --- /dev/null +++ b/tools/python_api/src_py/torch_geometric_result_converter.py @@ -0,0 +1,237 @@ +from .types import Type +import warnings + + +class TorchGeometricResultConverter: + def __init__(self, query_result): + self.query_result = query_result + self.nodes_dict = {} + self.edges_dict = {} + self.rels = set() + self.nodes_property_names_dict = {} + self.table_to_label_dict = {} + self.internal_id_to_pos_dict = {} + self.pos_to_primary_key_dict = {} + self.warning_messages = set() + self.ignored_columns = set() + self.column_to_extract = self.query_result._get_columns_to_extract() + + def __get_node_property_names(self, table_name): + if table_name in self.nodes_property_names_dict: + return self.nodes_property_names_dict[table_name] + + PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)" + LIST_SYMBOL = "[]" + result_str = self.query_result.connection._connection.get_node_property_names( + table_name) + results = {} + for (i, line) in enumerate(result_str.splitlines()): + # ignore first line + if i == 0: + continue + line = line.strip() + if line == "": + continue + line_splited = line.split(" ") + if len(line_splited) < 2: + continue + + prop_name = line_splited[0] + prop_type = " ".join(line_splited[1:]) + + is_primary_key = PRIMARY_KEY_SYMBOL in prop_type + prop_type = prop_type.replace(PRIMARY_KEY_SYMBOL, "") + depth = prop_type.count(LIST_SYMBOL) + prop_type = prop_type.replace(LIST_SYMBOL, "") + results[prop_name] = { + "type": prop_type, + "depth": depth, + "is_primary_key": is_primary_key + } + self.nodes_property_names_dict[table_name] = results + return results + + def __populate_nodes_dict_and_deduplicte_edges(self): + self.query_result.reset_iterator() + while self.query_result.has_next(): + row = self.query_result.get_next() + for i in self.column_to_extract: + column_type, _ = self.column_to_extract[i] + if column_type == Type.NODE.value: + node = row[i] + label = node["_label"] + _id = node["_id"] + self.table_to_label_dict[_id["table"]] = label + + if (_id["table"], _id["offset"]) in self.internal_id_to_pos_dict: + continue + + node_property_names = self.__get_node_property_names( + label) + + pos, primary_key = self.__extract_properties_from_node( + node, label, node_property_names) + + # If no properties were extracted, then ignore the node + if pos >= 0: + self.internal_id_to_pos_dict[( + _id["table"], _id["offset"])] = pos + if label not in self.pos_to_primary_key_dict: + self.pos_to_primary_key_dict[label] = {} + self.pos_to_primary_key_dict[label][pos] = primary_key + + elif column_type == Type.REL.value: + _src = row[i]["_src"] + _dst = row[i]["_dst"] + self.rels.add((_src["table"], _src["offset"], + _dst["table"], _dst["offset"])) + + def __extract_properties_from_node(self, node, label, node_property_names): + import torch + for prop_name in node_property_names: + # Ignore properties that are marked as ignored + if (label, prop_name) in self.ignored_columns: + continue + + # Ignore primary key + if node_property_names[prop_name]["is_primary_key"]: + primary_key = node[prop_name] + continue + + # Ignore properties that are not supported by torch_geometric + if node_property_names[prop_name]["type"] not in [Type.INT64.value, Type.DOUBLE.value, 
Type.BOOL.value]: + self.warning_messages.add( + "Property {}.{} of type {} is not supported by torch_geometric. The property is ignored." + .format(label, prop_name, node_property_names[prop_name]["type"])) + self.__ignore_property(label, prop_name) + continue + if node[prop_name] is None: + self.warning_messages.add( + "Property {}.{} has a null value. torch_geometric does not support null values. The property is ignored." + .format(label, prop_name)) + self.__ignore_property(label, prop_name) + continue + + if node_property_names[prop_name]['depth'] == 0: + curr_value = node[prop_name] + else: + try: + if node_property_names[prop_name]['type'] == Type.INT64.value: + curr_value = torch.LongTensor(node[prop_name]) + elif node_property_names[prop_name]['type'] == Type.DOUBLE.value: + curr_value = torch.FloatTensor(node[prop_name]) + elif node_property_names[prop_name]['type'] == Type.BOOL.value: + curr_value = torch.BoolTensor(node[prop_name]) + except ValueError: + self.warning_messages.add( + "Property {}.{} cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored." + .format(label, prop_name)) + self.__ignore_property(label, prop_name) + continue + # Check if the shape of the property is consistent + if label in self.nodes_dict and prop_name in self.nodes_dict[label]: + # If the shape is inconsistent, then ignore the property + if curr_value.shape != self.nodes_dict[label][prop_name][0].shape: + self.warning_messages.add( + "Property {}.{} has an inconsistent shape. The property is ignored." + .format(label, prop_name)) + self.__ignore_property(label, prop_name) + continue + + # Create the dictionary for the label if it does not exist + if label not in self.nodes_dict: + self.nodes_dict[label] = {} + if prop_name not in self.nodes_dict[label]: + self.nodes_dict[label][prop_name] = [] + + # Add the property to the dictionary + self.nodes_dict[label][prop_name].append(curr_value) + + # The pos will be overwritten for each property, but + # it should be the same for all properties + pos = len(self.nodes_dict[label][prop_name]) - 1 + return pos, primary_key + + def __ignore_property(self, label, prop_name): + self.ignored_columns.add((label, prop_name)) + if label in self.nodes_dict and prop_name in self.nodes_dict[label]: + del self.nodes_dict[label][prop_name] + if len(self.nodes_dict[label]) == 0: + del self.nodes_dict[label] + + def __populate_edges_dict(self): + # Post-process edges, map internal ids to positions + for r in self.rels: + src_pos = self.internal_id_to_pos_dict[(r[0], r[1])] + dst_pos = self.internal_id_to_pos_dict[(r[2], r[3])] + src_label = self.table_to_label_dict[r[0]] + dst_label = self.table_to_label_dict[r[2]] + if src_label not in self.edges_dict: + self.edges_dict[src_label] = {} + if dst_label not in self.edges_dict[src_label]: + self.edges_dict[src_label][dst_label] = [] + self.edges_dict[src_label][dst_label].append((src_pos, dst_pos)) + + def __print_warnings(self): + for message in self.warning_messages: + warnings.warn(message) + + def __convert_to_torch_geometric(self): + import torch + import torch_geometric + if len(self.nodes_dict) == 0: + self.warning_messages.add( + "No nodes found or all nodes were ignored. 
Returning None.") + return None + + # If there is only one node type, then convert to torch_geometric.data.Data + # Otherwise, convert to torch_geometric.data.HeteroData + if len(self.nodes_dict) == 1: + data = torch_geometric.data.Data() + is_hetero = False + else: + data = torch_geometric.data.HeteroData() + is_hetero = True + + # Convert nodes to tensors + for label in self.nodes_dict: + for prop_name in self.nodes_dict[label]: + prop_type = self.nodes_property_names_dict[label][prop_name]["type"] + prop_depth = self.nodes_property_names_dict[label][prop_name]["depth"] + if prop_depth == 0: + if prop_type == Type.INT64.value: + converted = torch.LongTensor( + self.nodes_dict[label][prop_name]) + elif prop_type == Type.BOOL.value: + converted = torch.BoolTensor( + self.nodes_dict[label][prop_name]) + elif prop_type == Type.DOUBLE.value: + converted = torch.FloatTensor( + self.nodes_dict[label][prop_name]) + else: + converted = torch.stack( + self.nodes_dict[label][prop_name], dim=0) + if is_hetero: + data[label][prop_name] = converted + else: + data[prop_name] = converted + + # Convert edges to tensors + for src_label in self.edges_dict: + for dst_label in self.edges_dict[src_label]: + edge_idx = torch.tensor( + self.edges_dict[src_label][dst_label], dtype=torch.long).t().contiguous() + if is_hetero: + data[src_label, dst_label].edge_index = edge_idx + else: + data.edge_index = edge_idx + pos_to_primary_key_dict = self.pos_to_primary_key_dict[ + label] if not is_hetero else self.pos_to_primary_key_dict + return data, pos_to_primary_key_dict + + def get_as_torch_geometric(self): + self.__populate_nodes_dict_and_deduplicte_edges() + self.__populate_edges_dict() + data, pos_to_primary_key_dict = self.__convert_to_torch_geometric() + self.__print_warnings() + return data, pos_to_primary_key_dict diff --git a/tools/python_api/src_py/types.py b/tools/python_api/src_py/types.py new file mode 100644 index 00000000000..d9d8030ae01 --- /dev/null +++ b/tools/python_api/src_py/types.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class Type(Enum): + """The type of a value in the database.""" + BOOL = "BOOL" + INT64 = "INT64" + DOUBLE = "DOUBLE" + STRING = "STRING" + DATE = "DATE" + TIMESTAMP = "TIMESTAMP" + INTERVAL = "INTERVAL" + LIST = "LIST" + NODE = "NODE" + REL = "REL" + NODE_ID = "NODE_ID" diff --git a/tools/python_api/test/conftest.py b/tools/python_api/test/conftest.py index 110f7927a48..50cbb21e0c0 100644 --- a/tools/python_api/test/conftest.py +++ b/tools/python_api/test/conftest.py @@ -31,6 +31,9 @@ def init_tiny_snb(tmp_path): conn.execute('COPY movies FROM "../../../dataset/tinysnb/vMovies.csv"') conn.execute('create rel table workAt (FROM person TO organisation, year INT64, MANY_ONE)') conn.execute('COPY workAt FROM "../../../dataset/tinysnb/eWorkAt.csv"') + conn.execute('create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));') + conn.execute( + 'COPY tensor FROM "../../../dataset/tensor-list/vTensor.csv" (HEADER=true)') return output_path diff --git a/tools/python_api/test/test_main.py b/tools/python_api/test/test_main.py index f13d2befc61..93bdb5cc4ca 100644 --- a/tools/python_api/test/test_main.py +++ b/tools/python_api/test/test_main.py @@ -1,5 +1,6 @@ import pytest +from test_arrow import * from test_datatype import * from test_df import * from test_exception import * @@ -7,8 +8,8 @@ from test_networkx import * from test_parameter import * from test_query_result_close import * +from 
test_torch_geometric import * from test_write_to_csv import * -from test_arrow import * if __name__ == "__main__": raise SystemExit(pytest.main([__file__])) diff --git a/tools/python_api/test/test_torch_geometric.py b/tools/python_api/test/test_torch_geometric.py new file mode 100644 index 00000000000..3922a0c8791 --- /dev/null +++ b/tools/python_api/test/test_torch_geometric.py @@ -0,0 +1,450 @@ +import datetime +import torch +import warnings + +TINY_SNB_PERSONS_GROUND_TRUTH = {0: {'ID': 0, + 'fName': 'Alice', + 'gender': 1, + 'isStudent': True, + 'isWorker': False, + 'age': 35, + 'eyeSight': 5.0, + 'birthdate': datetime.date(1900, 1, 1), + 'registerTime': datetime.datetime(2011, 8, 20, 11, 25, 30), + 'lastJobDuration': datetime.timedelta(days=1082, seconds=46920), + 'workedHours': [10, 5], + 'usedNames': ['Aida'], + 'courseScoresPerTerm': [[10, 8], [6, 7, 8]], + '_label': 'person', + '_id': {'offset': 0, 'table': 0}}, + 2: {'ID': 2, + 'fName': 'Bob', + 'gender': 2, + 'isStudent': True, + 'isWorker': False, + 'age': 30, + 'eyeSight': 5.1, + 'birthdate': datetime.date(1900, 1, 1), + 'registerTime': datetime.datetime(2008, 11, 3, 15, 25, 30, 526), + 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'workedHours': [12, 8], + 'usedNames': ['Bobby'], + 'courseScoresPerTerm': [[8, 9], [9, 10]], + '_label': 'person', + '_id': {'offset': 1, 'table': 0}}, + 3: {'ID': 3, + 'fName': 'Carol', + 'gender': 1, + 'isStudent': False, + 'isWorker': True, + 'age': 45, + 'eyeSight': 5.0, + 'birthdate': datetime.date(1940, 6, 22), + 'registerTime': datetime.datetime(1911, 8, 20, 2, 32, 21), + 'lastJobDuration': datetime.timedelta(days=2, seconds=1451), + 'workedHours': [4, 5], + 'usedNames': ['Carmen', 'Fred'], + 'courseScoresPerTerm': [[8, 10]], + '_label': 'person', + '_id': {'offset': 2, 'table': 0}}, + 5: {'ID': 5, + 'fName': 'Dan', + 'gender': 2, + 'isStudent': False, + 'isWorker': True, + 'age': 20, + 'eyeSight': 4.8, + 'birthdate': datetime.date(1950, 7, 23), + 'registerTime': datetime.datetime(2031, 11, 30, 12, 25, 30), + 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'workedHours': [1, 9], + 'usedNames': ['Wolfeschlegelstein', 'Daniel'], + 'courseScoresPerTerm': [[7, 4], [8, 8], [9]], + '_label': 'person', + '_id': {'offset': 3, 'table': 0}}, + 7: {'ID': 7, + 'fName': 'Elizabeth', + 'gender': 1, + 'isStudent': False, + 'isWorker': True, + 'age': 20, + 'eyeSight': 4.7, + 'birthdate': datetime.date(1980, 10, 26), + 'registerTime': datetime.datetime(1976, 12, 23, 11, 21, 42), + 'lastJobDuration': datetime.timedelta(days=2, seconds=1451), + 'workedHours': [2], + 'usedNames': ['Ein'], + 'courseScoresPerTerm': [[6], [7], [8]], + '_label': 'person', + '_id': {'offset': 4, 'table': 0}}, + 8: {'ID': 8, + 'fName': 'Farooq', + 'gender': 2, + 'isStudent': True, + 'isWorker': False, + 'age': 25, + 'eyeSight': 4.5, + 'birthdate': datetime.date(1980, 10, 26), + 'registerTime': datetime.datetime(1972, 7, 31, 13, 22, 30, 678559), + 'lastJobDuration': datetime.timedelta(seconds=1080, microseconds=24000), + 'workedHours': [3, 4, 5, 6, 7], + 'usedNames': ['Fesdwe'], + 'courseScoresPerTerm': [[8]], + '_label': 'person', + '_id': {'offset': 5, 'table': 0}}, + 9: {'ID': 9, + 'fName': 'Greg', + 'gender': 2, + 'isStudent': False, + 'isWorker': False, + 'age': 40, + 'eyeSight': 4.9, + 'birthdate': datetime.date(1980, 10, 26), + 'registerTime': datetime.datetime(1976, 12, 23, 4, 41, 42), + 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, 
microseconds=24), + 'workedHours': [1], + 'usedNames': ['Grad'], + 'courseScoresPerTerm': [[10]], + '_label': 'person', + '_id': {'offset': 6, 'table': 0}}, + 10: {'ID': 10, + 'fName': 'Hubert Blaine Wolfeschlegelsteinhausenbergerdorff', + 'gender': 2, + 'isStudent': False, + 'isWorker': True, + 'age': 83, + 'eyeSight': 4.9, + 'birthdate': datetime.date(1990, 11, 27), + 'registerTime': datetime.datetime(2023, 2, 21, 13, 25, 30), + 'lastJobDuration': datetime.timedelta(days=1082, seconds=46920), + 'workedHours': [10, 11, 12, 3, 4, 5, 6, 7], + 'usedNames': ['Ad', 'De', 'Hi', 'Kye', 'Orlan'], + 'courseScoresPerTerm': [[7], [10], [6, 7]], + '_label': 'person', + '_id': {'offset': 7, 'table': 0}}} + +TINY_SNB_ORGANISATIONS_GROUND_TRUTH = {1: {'ID': 1, + 'name': 'ABFsUni', + 'orgCode': 325, + 'mark': 3.7, + 'score': -2, + 'history': '10 years 5 months 13 hours 24 us', + 'licenseValidInterval': datetime.timedelta(days=1085), + 'rating': 1.0, + '_label': 'organisation', + '_id': {'offset': 0, 'table': 2}}, + 4: {'ID': 4, + 'name': 'CsWork', + 'orgCode': 934, + 'mark': 4.1, + 'score': -100, + 'history': '2 years 4 days 10 hours', + 'licenseValidInterval': datetime.timedelta(days=9414), + 'rating': 0.78, + '_label': 'organisation', + '_id': {'offset': 1, 'table': 2}}, + 6: {'ID': 6, + 'name': 'DEsWork', + 'orgCode': 824, + 'mark': 4.1, + 'score': 7, + 'history': '2 years 4 hours 22 us 34 minutes', + 'licenseValidInterval': datetime.timedelta(days=3, seconds=36000, microseconds=100000), + 'rating': 0.52, + '_label': 'organisation', + '_id': {'offset': 2, 'table': 2}}} + +TINY_SNB_KNOWS_GROUND_TRUTH = { + 0: [2, 3, 5], + 2: [0, 3, 5], + 3: [0, 2, 5], + 5: [0, 2, 3], + 7: [8, 9], +} + +TINY_SNB_WORKS_AT_GROUND_TRUTH = { + 3: [4], + 5: [6], + 7: [6], +} + +TENSOR_LIST_GROUND_TRUTH = { + 0: { + 'boolTensor': [True, False], + 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], + 'intTensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + }, + 3: { + 'boolTensor': [True, False], + 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], + 'intTensor': [[[3, 4], [5, 6]], [[7, 8], [9, 10]]] + }, + 4: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.4, 0.8], [0.7, 0.6]], + 'intTensor': [[[5, 6], [7, 8]], [[9, 10], [11, 12]]] + }, + 5: { + 'boolTensor': [True, True], + 'doubleTensor': [[0.4, 0.9], [0.5, 0.2]], + 'intTensor': [[[7, 8], [9, 10]], [[11, 12], [13, 14]]] + }, + 6: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.2, 0.4], [0.5, 0.1]], + 'intTensor': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]] + }, + 8: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.6, 0.4], [0.6, 0.1]], + 'intTensor': [[[11, 12], [13, 14]], [[15, 16], [17, 18]]] + } +} + + +def test_to_torch_geometric_nodes_only(establish_connection): + conn, _ = establish_connection + query = "MATCH (p:person) return p" + + res = conn.execute(query) + with warnings.catch_warnings(record=True) as ws: + torch_geometric_data, pos_to_idx = res.get_as_torch_geometric() + warnings_ground_truth = set([ + "Property person.courseScoresPerTerm cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored.", + "Property person.lastJobDuration of type INTERVAL is not supported by torch_geometric. The property is ignored.", + "Property person.registerTime of type TIMESTAMP is not supported by torch_geometric. The property is ignored.", + "Property person.birthdate of type DATE is not supported by torch_geometric. The property is ignored.", + "Property person.fName of type STRING is not supported by torch_geometric. 
The property is ignored.", + "Property person.workedHours has an inconsistent shape. The property is ignored.", + "Property person.usedNames of type STRING is not supported by torch_geometric. The property is ignored.", + ]) + assert len(ws) == 7 + for w in ws: + assert str(w.message) in warnings_ground_truth + + assert torch_geometric_data.gender.shape == torch.Size([8]) + assert torch_geometric_data.gender.dtype == torch.int64 + for i in range(8): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['gender'] == torch_geometric_data.gender[i].item() + + assert torch_geometric_data.isStudent.shape == torch.Size([8]) + assert torch_geometric_data.isStudent.dtype == torch.bool + for i in range(8): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['isStudent'] == torch_geometric_data.isStudent[i].item() + + assert torch_geometric_data.isWorker.shape == torch.Size([8]) + assert torch_geometric_data.isWorker.dtype == torch.bool + for i in range(8): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['isWorker'] == torch_geometric_data.isWorker[i].item() + + assert torch_geometric_data.age.shape == torch.Size([8]) + assert torch_geometric_data.age.dtype == torch.int64 + for i in range(8): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['age'] == torch_geometric_data.age[i].item() + + assert torch_geometric_data.eyeSight.shape == torch.Size([8]) + assert torch_geometric_data.eyeSight.dtype == torch.float32 + for i in range(8): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]]['eyeSight'] - \ + torch_geometric_data.eyeSight[i].item() < 1e-6 + + +def test_to_torch_geometric_homogeneous_graph(establish_connection): + conn, _ = establish_connection + query = "MATCH (p:person)-[r:knows]->(q:person) RETURN p, r, q" + + res = conn.execute(query) + with warnings.catch_warnings(record=True) as ws: + torch_geometric_data, pos_to_idx = res.get_as_torch_geometric() + warnings_ground_truth = set([ + "Property person.courseScoresPerTerm cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored.", + "Property person.lastJobDuration of type INTERVAL is not supported by torch_geometric. The property is ignored.", + "Property person.registerTime of type TIMESTAMP is not supported by torch_geometric. The property is ignored.", + "Property person.birthdate of type DATE is not supported by torch_geometric. The property is ignored.", + "Property person.fName of type STRING is not supported by torch_geometric. The property is ignored.", + "Property person.workedHours has an inconsistent shape. The property is ignored.", + "Property person.usedNames of type STRING is not supported by torch_geometric. 
The property is ignored.", + ]) + assert len(ws) == 7 + for w in ws: + assert str(w.message) in warnings_ground_truth + + assert torch_geometric_data.gender.shape == torch.Size([7]) + assert torch_geometric_data.gender.dtype == torch.int64 + for i in range(7): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['gender'] == torch_geometric_data.gender[i].item() + + assert torch_geometric_data.isStudent.shape == torch.Size([7]) + assert torch_geometric_data.isStudent.dtype == torch.bool + for i in range(7): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['isStudent'] == torch_geometric_data.isStudent[i].item() + + assert torch_geometric_data.isWorker.shape == torch.Size([7]) + assert torch_geometric_data.isWorker.dtype == torch.bool + for i in range(7): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['isWorker'] == torch_geometric_data.isWorker[i].item() + + assert torch_geometric_data.age.shape == torch.Size([7]) + assert torch_geometric_data.age.dtype == torch.int64 + for i in range(7): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + ]['age'] == torch_geometric_data.age[i].item() + + assert torch_geometric_data.eyeSight.shape == torch.Size([7]) + assert torch_geometric_data.eyeSight.dtype == torch.float32 + for i in range(7): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]]['eyeSight'] - \ + torch_geometric_data.eyeSight[i].item() < 1e-6 + + assert torch_geometric_data.edge_index.shape == torch.Size([2, 14]) + for i in range(14): + src, dst = torch_geometric_data.edge_index[0][i].item( + ), torch_geometric_data.edge_index[1][i].item() + assert src in pos_to_idx + assert dst in pos_to_idx + assert src != dst + assert pos_to_idx[dst] in TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx[src]] + + +def test_to_torch_geometric_heterogeneous_graph(establish_connection): + conn, _ = establish_connection + query = "MATCH (p:person)-[r1:knows]->(q:person)-[r2:workAt]->(o:organisation) RETURN p, q, o, r1, r2" + + res = conn.execute(query) + with warnings.catch_warnings(record=True) as ws: + torch_geometric_data, pos_to_idx = res.get_as_torch_geometric() + + assert len(ws) == 9 + warnings_ground_truth = set([ + "Property organisation.name of type STRING is not supported by torch_geometric. The property is ignored.", + "Property person.courseScoresPerTerm cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored.", + "Property person.lastJobDuration of type INTERVAL is not supported by torch_geometric. The property is ignored.", + "Property person.registerTime of type TIMESTAMP is not supported by torch_geometric. The property is ignored.", + "Property person.birthdate of type DATE is not supported by torch_geometric. The property is ignored.", + "Property person.fName of type STRING is not supported by torch_geometric. The property is ignored.", + "Property organisation.history of type STRING is not supported by torch_geometric. The property is ignored.", + "Property person.usedNames of type STRING is not supported by torch_geometric. The property is ignored.", + "Property organisation.licenseValidInterval of type INTERVAL is not supported by torch_geometric. 
The property is ignored.", + ]) + + for w in ws: + assert str(w.message) in warnings_ground_truth + + assert torch_geometric_data['person'].gender.shape == torch.Size([4]) + assert torch_geometric_data['person'].gender.dtype == torch.int64 + for i in range(4): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + ]['gender'] == torch_geometric_data['person'].gender[i].item() + + assert torch_geometric_data['person'].isStudent.shape == torch.Size([4]) + assert torch_geometric_data['person'].isStudent.dtype == torch.bool + for i in range(4): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + ]['isStudent'] == torch_geometric_data['person'].isStudent[i].item() + + assert torch_geometric_data['person'].isWorker.shape == torch.Size([4]) + assert torch_geometric_data['person'].isWorker.dtype == torch.bool + for i in range(4): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + ]['isWorker'] == torch_geometric_data['person'].isWorker[i].item() + + assert torch_geometric_data['person'].age.shape == torch.Size([4]) + assert torch_geometric_data['person'].age.dtype == torch.int64 + for i in range(4): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + ]['age'] == torch_geometric_data['person'].age[i].item() + + assert torch_geometric_data['person'].eyeSight.shape == torch.Size([4]) + assert torch_geometric_data['person'].eyeSight.dtype == torch.float32 + for i in range(4): + assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i]]['eyeSight'] - \ + torch_geometric_data['person'].eyeSight[i].item() < 1e-6 + + assert torch_geometric_data['person', 'person'].edge_index.shape == torch.Size([ + 2, 6]) + for i in range(3): + src, dst = torch_geometric_data['person', 'person'].edge_index[0][i].item( + ), torch_geometric_data['person', 'person'].edge_index[1][i].item() + assert src in pos_to_idx['person'] + assert dst in pos_to_idx['person'] + assert src != dst + assert pos_to_idx['person'][dst] in TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx['person'][src]] + + assert torch_geometric_data['organisation'].orgCode.shape == torch.Size([ + 2]) + assert torch_geometric_data['organisation'].orgCode.dtype == torch.int64 + for i in range(2): + assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + ]['orgCode'] == torch_geometric_data['organisation'].orgCode[i].item() + + assert torch_geometric_data['organisation'].mark.shape == torch.Size([2]) + assert torch_geometric_data['organisation'].mark.dtype == torch.float32 + for i in range(2): + assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + ]['mark'] - torch_geometric_data['organisation'].mark[i].item() < 1e-6 + + assert torch_geometric_data['organisation'].score.shape == torch.Size([2]) + assert torch_geometric_data['organisation'].score.dtype == torch.int64 + for i in range(2): + assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + ]['score'] - torch_geometric_data['organisation'].score[i].item() < 1e-6 + + assert torch_geometric_data['organisation'].rating.shape == torch.Size([2]) + assert torch_geometric_data['organisation'].rating.dtype == torch.float32 + for i in range(2): + assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + ]['rating'] - torch_geometric_data['organisation'].rating[i].item() < 1e-6 + + assert torch_geometric_data['person', 'organisation'].edge_index.shape == torch.Size([ + 2, 2]) + for i in range(2): + src, dst = torch_geometric_data['person', 'organisation'].edge_index[0][i].item( + ), 
torch_geometric_data['person', 'organisation'].edge_index[1][i].item() + assert src in pos_to_idx['person'] + assert dst in pos_to_idx['organisation'] + assert src != dst + assert pos_to_idx['organisation'][dst] in TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]] + + +def test_to_torch_geometric_multi_dimensonal_lists(establish_connection): + conn, _ = establish_connection + query = "MATCH (t:tensor) RETURN t" + + res = conn.execute(query) + with warnings.catch_warnings(record=True) as ws: + torch_geometric_data, pos_to_idx = res.get_as_torch_geometric() + assert len(ws) == 1 + assert str(ws[0].message) == "Property tensor.oneDimInt has a null value. torch_geometric does not support null values. The property is ignored." + + bool_list = [] + float_list = [] + int_list = [] + + for i in range(len(pos_to_idx)): + idx = pos_to_idx[i] + bool_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['boolTensor']) + float_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['doubleTensor']) + int_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['intTensor']) + + bool_tensor = torch.tensor(bool_list, dtype=torch.bool) + float_tensor = torch.tensor(float_list, dtype=torch.float32) + int_tensor = torch.tensor(int_list, dtype=torch.int64) + + assert torch_geometric_data.boolTensor.shape == bool_tensor.shape + assert torch_geometric_data.boolTensor.dtype == bool_tensor.dtype + assert torch.all(torch_geometric_data.boolTensor == bool_tensor) + + assert torch_geometric_data.doubleTensor.shape == float_tensor.shape + assert torch_geometric_data.doubleTensor.dtype == float_tensor.dtype + assert torch.all(torch_geometric_data.doubleTensor == float_tensor) + + assert torch_geometric_data.intTensor.shape == int_tensor.shape + assert torch_geometric_data.intTensor.dtype == int_tensor.dtype + assert torch.all(torch_geometric_data.intTensor == int_tensor)
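
Note for reviewers: below is a minimal usage sketch (not part of the patch) of the get_as_torch_geometric() API introduced above. It assumes an already-open connection `conn` (the same object the `establish_connection` fixture in conftest.py yields), the tinysnb and tensor-list datasets loaded by the test fixtures, and that torch and torch_geometric are installed; the helper function names are illustrative only.

import warnings

def person_graph_example(conn):
    # Homogeneous case: a single node table in the result produces a
    # torch_geometric.data.Data object. Numeric and boolean node properties
    # become per-property tensors; rel columns become edge_index.
    res = conn.execute("MATCH (p:person)-[r:knows]->(q:person) RETURN p, r, q")
    with warnings.catch_warnings(record=True) as ws:
        data, pos_to_primary_key = res.get_as_torch_geometric()
    for w in ws:
        # STRING/DATE/TIMESTAMP/INTERVAL properties, null values, and
        # variable-length nested lists are skipped with a warning.
        print("ignored:", w.message)
    print(data.age.shape, data.age.dtype)   # e.g. torch.Size([7]) torch.int64
    print(data.edge_index.shape)            # torch.Size([2, num_edges])
    # pos_to_primary_key maps a tensor row position back to the node's
    # primary key value, so rows can be joined back to the database.
    print(pos_to_primary_key[0])
    return data

def hetero_graph_example(conn):
    # Heterogeneous case: more than one node table produces a
    # torch_geometric.data.HeteroData object, keyed by node label for
    # properties and by (src_label, dst_label) pairs for edges.
    res = conn.execute(
        "MATCH (p:person)-[r1:knows]->(q:person)-[r2:workAt]->(o:organisation) "
        "RETURN p, q, o, r1, r2")
    data, pos_to_primary_key = res.get_as_torch_geometric()
    print(data["person"].age.shape)
    print(data["person", "organisation"].edge_index.shape)
    print(pos_to_primary_key["organisation"][0])
    return data

The tuple return (data, pos_to_primary_key) mirrors what the new tests assert: the second element lets callers map tensor row positions back to the nodes' primary keys, which is needed because unsupported or null properties can cause individual properties, or whole nodes, to be dropped from the converted result.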