Skip to content

Commit

Permalink
pyg loader with graphscope remote backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Yi-Eaaa committed Oct 25, 2024
1 parent e10b6c8 commit bba3a8d
Show file tree
Hide file tree
Showing 17 changed files with 321 additions and 1,994 deletions.
1 change: 0 additions & 1 deletion python/graphscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from graphscope.client.session import gremlin
from graphscope.client.session import has_default_session
from graphscope.client.session import interactive
from graphscope.client.session import pyg_remote_backend
from graphscope.client.session import session
from graphscope.client.session import set_option
from graphscope.framework.errors import *
Expand Down
116 changes: 14 additions & 102 deletions python/graphscope/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,8 +1331,11 @@ def graphlearn_torch(
num_clients=1,
manifest_path=None,
client_folder_path="./",
is_pyg_remote_backend=False,
):
from graphscope.learning.gl_torch_graph import GLTorchGraph
from graphscope.learning.gs_feature_store import GsFeatureStore
from graphscope.learning.gs_graph_store import GsGraphStore
from graphscope.learning.utils import fill_params_in_yaml
from graphscope.learning.utils import read_folder_files_content

Expand Down Expand Up @@ -1380,80 +1383,17 @@ def graphlearn_torch(
g = GLTorchGraph(endpoints)
self._learning_instance_dict[graph.vineyard_id] = g
graph._attach_learning_instance(g)
return g

def pyg_remote_backend(
self,
graph,
edges,
edge_weights=None,
node_features=None,
edge_features=None,
node_labels=None,
edge_dir="out",
random_node_split=None,
num_clients=1,
manifest_path=None,
client_folder_path="./",
):
from graphscope.learning.gl_torch_graph import GLTorchGraph
from graphscope.learning.gs_feature_store import GsFeatureStore
from graphscope.learning.gs_graph_store import GsGraphStore
from graphscope.learning.utils import fill_params_in_yaml
from graphscope.learning.utils import read_folder_files_content
if is_pyg_remote_backend:
feature_store = GsFeatureStore(config)
graph_store = GsGraphStore(config)
self._learning_instance_dict[graph.vineyard_id] = feature_store
self._learning_instance_dict[graph.vineyard_id] = graph_store
graph._attach_learning_instance(feature_store)
graph._attach_learning_instance(graph_store)
return g, feature_store, graph_store

handle = {
"vineyard_socket": self._engine_config["vineyard_socket"],
"vineyard_id": graph.vineyard_id,
"fragments": graph.fragments,
"num_servers": len(graph.fragments),
"num_clients": num_clients,
}
manifest_params = {
"NUM_CLIENT_NODES": handle["num_clients"],
"NUM_SERVER_NODES": handle["num_servers"],
"NUM_WORKER_REPLICAS": handle["num_clients"] - 1,
}
if manifest_path is not None:
handle["manifest"] = fill_params_in_yaml(manifest_path, manifest_params)
if client_folder_path is not None:
handle["client_content"] = read_folder_files_content(client_folder_path)

handle = base64.b64encode(
json.dumps(handle).encode("utf-8", errors="ignore")
).decode("utf-8", errors="ignore")
config = {
"edges": edges,
"edge_weights": edge_weights,
"node_features": node_features,
"edge_features": edge_features,
"node_labels": node_labels,
"edge_dir": edge_dir,
"random_node_split": random_node_split,
}
GLTorchGraph.check_params(graph.schema, config)
config = GLTorchGraph.transform_config(config)
config = base64.b64encode(
json.dumps(config).encode("utf-8", errors="ignore")
).decode("utf-8", errors="ignore")
handle, config, endpoints = self._grpc_client.create_learning_instance(
graph.vineyard_id,
handle,
config,
message_pb2.LearningBackend.GRAPHLEARN_TORCH,
)

feature_store = GsFeatureStore(
handle=handle, config=config, endpoints=endpoints, graph=graph
)
graph_store = GsGraphStore(
handle=handle, config=config, endpoints=endpoints, graph=graph
)

learning_instance = tuple([feature_store, graph_store])
self._learning_instance_dict[graph.vineyard_id] = learning_instance
graph._attach_learning_instance(learning_instance)
return feature_store, graph_store
return g

def nx(self):
if not self.eager():
Expand Down Expand Up @@ -1755,6 +1695,7 @@ def graphlearn_torch(
num_clients=1,
manifest_path=None,
client_folder_path="./",
is_pyg_remote_backend=False,
):
assert graph is not None, "graph cannot be None"
assert (
Expand All @@ -1772,34 +1713,5 @@ def graphlearn_torch(
num_clients,
manifest_path,
client_folder_path,
is_pyg_remote_backend,
) # pylint: disable=protected-access


def pyg_remote_backend(
graph,
edges,
edge_weights=None,
node_features=None,
edge_features=None,
node_labels=None,
edge_dir="out",
random_node_split=None,
num_clients=1,
manifest_path=None,
client_folder_path="./",
):
assert graph is not None, "graph cannot be None"
assert graph._session is not None, "The graph object is invalid"
return graph._session.pyg_remote_backend(
graph,
edges,
edge_weights,
node_features,
edge_features,
node_labels,
edge_dir,
random_node_split,
num_clients,
manifest_path,
client_folder_path,
)
Loading

0 comments on commit bba3a8d

Please sign in to comment.