Skip to content

Commit

Permalink
Merge pull request #1525 from kuzudb/pyg-remote-backend-docs
Browse files Browse the repository at this point in the history
Add document for PyG remote backend
  • Loading branch information
mewim committed May 10, 2023
2 parents 4f02b2d + 98a13fb commit 032d5ad
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions tools/python_api/src_py/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Database:
Set the logging level.
get_torch_geometric_remote_backend(num_threads)
Get the torch_geometric remote backend.
Use the database as the remote backend for torch_geometric.
"""

Expand Down Expand Up @@ -81,10 +81,33 @@ def set_logging_level(self, level):
self._database.set_logging_level(level)

def get_torch_geometric_remote_backend(self, num_threads=None):
from .torch_geometric_feature_store import KuzuFeatureStore
from .torch_geometric_graph_store import KuzuGraphStore
"""
Get the torch_geometric remote backend.
Use the database as the remote backend for torch_geometric.
For the interface of the remote backend, please refer to
https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
The current implementation is read-only and does not support edge
features. The IDs of the nodes are based on the internal IDs (i.e., node
offsets). For the remote node IDs to be consistent with the positions in
the output tensors, please ensure that no deletion has been performed
on the node tables.
The remote backend can also be plugged into the data loader of
torch_geometric, which is useful for mini-batch training. For example:
.. code-block:: python
loader_kuzu = NeighborLoader(
data=(feature_store, graph_store),
num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
batch_size=LOADER_BATCH_SIZE,
input_nodes=('paper', input_nodes),
num_workers=4,
filter_per_worker=False,
)
Please note that the database instance is not fork-safe, so if more than
one worker is used, `filter_per_worker` must be set to False.
Parameters
----------
Expand All @@ -99,6 +122,8 @@ def get_torch_geometric_remote_backend(self, num_threads=None):
graph_store : KuzuGraphStore
Graph store compatible with torch_geometric.
"""
from .torch_geometric_feature_store import KuzuFeatureStore
from .torch_geometric_graph_store import KuzuGraphStore
return KuzuFeatureStore(self, num_threads), KuzuGraphStore(self, num_threads)

def _scan_node_table(self, table_name, prop_name, prop_type, dim, indices, num_threads):
Expand Down

0 comments on commit 032d5ad

Please sign in to comment.