Skip to content

Commit

Permalink
Merge pull request #1487 from kuzudb/pyg-worker-shared-memory
Browse files Browse the repository at this point in the history
Fix process fork support for Python API
  • Loading branch information
mewim committed Apr 26, 2023
2 parents 5421c64 + 18528cd commit 3f5d5a6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 22 deletions.
28 changes: 23 additions & 5 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,23 @@ def __init__(self, database, num_threads=0):
"""

self.database = database
database.init_database()
self._connection = _kuzu.Connection(database._database, num_threads)
self.num_threads = num_threads
self._connection = None
self.init_connection()

def __getstate__(self):
state = {
"database": self.database,
"num_threads": self.num_threads,
"_connection": None
}
return state

def init_connection(self):
self.database.init_database()
if self._connection is None:
self._connection = _kuzu.Connection(
self.database._database, self.num_threads)

def set_max_threads_for_exec(self, num_threads):
"""
Expand All @@ -45,7 +60,7 @@ def set_max_threads_for_exec(self, num_threads):
num_threads : int
Maximum number of threads to use for executing queries.
"""

self.init_connection()
self._connection.set_max_threads_for_exec(num_threads)

def execute(self, query, parameters=[]):
Expand All @@ -66,7 +81,7 @@ def execute(self, query, parameters=[]):
QueryResult
Query result.
"""

self.init_connection()
prepared_statement = self.prepare(
query) if type(query) == str else query
return QueryResult(self,
Expand Down Expand Up @@ -96,6 +111,7 @@ def _get_node_property_names(self, table_name):
PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)"
LIST_START_SYMBOL = "["
LIST_END_SYMBOL = "]"
self.init_connection()
result_str = self._connection.get_node_property_names(
table_name)
results = {}
Expand Down Expand Up @@ -136,6 +152,7 @@ def _get_node_property_names(self, table_name):

def _get_node_table_names(self):
results = []
self.init_connection()
result_str = self._connection.get_node_table_names()
for (i, line) in enumerate(result_str.splitlines()):
# ignore first line
Expand All @@ -149,6 +166,7 @@ def _get_node_table_names(self):

def _get_rel_table_names(self):
results = []
self.init_connection()
result_str = self._connection.get_rel_table_names()
for i, line in enumerate(result_str.splitlines()):
if i == 0:
Expand Down Expand Up @@ -178,5 +196,5 @@ def set_query_timeout(self, timeout_in_ms):
timeout_in_ms : int
query timeout value in ms for executing queries.
"""

self.init_connection()
self._connection.set_query_timeout(timeout_in_ms)
2 changes: 1 addition & 1 deletion tools/python_api/src_py/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __getstate__(self):
state = {
"database_path": self.database_path,
"buffer_pool_size": self.buffer_pool_size,
"database": None
"_database": None
}
return state

Expand Down
8 changes: 0 additions & 8 deletions tools/python_api/src_py/torch_geometric_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ def __init__(self, db, num_threads):
num_threads = multiprocessing.cpu_count()
self.num_threads = num_threads

def __getstate__(self):
state = {
"connection": None,
"node_properties_cache": self.node_properties_cache,
"db": self.db.__getstate__()
}
return state

def __get_connection(self):
if not self.connection:
self.connection = Connection(self.db, self.num_threads)
Expand Down
8 changes: 0 additions & 8 deletions tools/python_api/src_py/torch_geometric_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@ def __init__(self, db, num_threads):
self.num_threads = num_threads
self.__populate_edge_attrs()

def __getstate__(self):
state = {
"connection": None,
"store": self.store,
"db": self.db.__getstate__()
}
return state

@staticmethod
def key(attr: EdgeAttr) -> Tuple:
return (attr.edge_type, attr.layout.value, attr.is_sorted)
Expand Down

0 comments on commit 3f5d5a6

Please sign in to comment.