Add document for PyG remote backend #1525

Merged
merged 1 commit into from
May 10, 2023
33 changes: 29 additions & 4 deletions tools/python_api/src_py/database.py
@@ -15,7 +15,7 @@ class Database:
         Set the logging level.

     get_torch_geometric_remote_backend(num_threads)
-        Get the torch_geometric remote backend.
+        Use the database as the remote backend for torch_geometric.

     """

@@ -81,10 +81,33 @@ def set_logging_level(self, level):
         self._database.set_logging_level(level)

     def get_torch_geometric_remote_backend(self, num_threads=None):
-        from .torch_geometric_feature_store import KuzuFeatureStore
-        from .torch_geometric_graph_store import KuzuGraphStore
         """
-        Get the torch_geometric remote backend.
+        Use the database as the remote backend for torch_geometric.
+
+        For the interface of the remote backend, please refer to
+        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
+        The current implementation is read-only and does not support edge
+        features. The IDs of the nodes are based on the internal IDs (i.e., node
+        offsets). For the remote node IDs to be consistent with the positions in
+        the output tensors, please ensure that no deletion has been performed
+        on the node tables.
+
+        The remote backend can also be plugged into the data loader of
+        torch_geometric, which is useful for mini-batch training. For example:
+
+        .. code-block:: python
+
+            loader_kuzu = NeighborLoader(
+                data=(feature_store, graph_store),
+                num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
+                batch_size=LOADER_BATCH_SIZE,
+                input_nodes=('paper', input_nodes),
+                num_workers=4,
+                filter_per_worker=False,
+            )
+
+        Please note that the database instance is not fork-safe, so if more than
+        one worker is used, `filter_per_worker` must be set to False.

         Parameters
         ----------
@@ -99,6 +122,8 @@ def get_torch_geometric_remote_backend(self, num_threads=None):
         graph_store : KuzuGraphStore
             Graph store compatible with torch_geometric.
         """
+        from .torch_geometric_feature_store import KuzuFeatureStore
+        from .torch_geometric_graph_store import KuzuGraphStore
         return KuzuFeatureStore(self, num_threads), KuzuGraphStore(self, num_threads)

     def _scan_node_table(self, table_name, prop_name, prop_type, dim, indices, num_threads):
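The docstring's warning about deletions can be illustrated with a toy model. The sketch below is not Kùzu's storage layout; the `table` dictionary and both helper functions are hypothetical and exist only to show why an offset stops working as a tensor position once a row is missing.

```python
# Hypothetical toy node table: node offset -> feature value.
# This is NOT how Kùzu stores data; it only models "offset == position".
table = {0: 10.0, 1: 11.0, 2: 12.0, 3: 13.0, 4: 14.0}


def dense_rows(node_table):
    """Pack the surviving rows into a dense list, ordered by offset."""
    return [node_table[offset] for offset in sorted(node_table)]


def lookup_by_offset(rows, offset):
    """Use the offset as a position, as a position-indexed consumer would."""
    return rows[offset]


# With no deletions, position 3 holds the node whose offset is 3.
before = dense_rows(table)
assert lookup_by_offset(before, 3) == table[3]

# After deleting offset 2, the dense rows shift up by one, so position 3
# now holds the node whose offset is 4.
del table[2]
after = dense_rows(table)
mismatch = lookup_by_offset(after, 3) != table[3]
print(mismatch)
```

Under this assumption, any consumer that treats node IDs as row positions would silently read the wrong features after a deletion, which is consistent with the docstring's request that node tables have no deletions.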
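The diff moves the two local imports from above the method's docstring to below it. The PR does not state why, so this is an assumption, but Python only treats a string literal as a docstring when it is the first statement in the function body. The sketch below uses two hypothetical functions to show the difference.

```python
def imports_first():
    import os  # a statement placed before the string literal
    """This string is only an expression statement, not a docstring."""
    return os.sep


def docstring_first():
    """This string is the first statement, so it becomes __doc__."""
    import os  # a local import still works after the docstring
    return os.sep


# help() and documentation generators read __doc__, so the first function
# appears undocumented even though it contains a descriptive string.
print(imports_first.__doc__)
print(docstring_first.__doc__)
```

Both functions behave identically at call time; only the visibility of the documentation changes.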