Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ray_client] close ray connection upon client deactivation #13919

Merged
merged 18 commits into from
Feb 7, 2021
12 changes: 10 additions & 2 deletions python/ray/util/client/server/dataservicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import grpc
import sys

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from threading import Lock

import ray.core.generated.ray_client_pb2 as ray_client_pb2
Expand All @@ -17,10 +17,12 @@


class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
def __init__(self, basic_service: "RayletServicer",
richardliaw marked this conversation as resolved.
Show resolved Hide resolved
ray_connect_handler: Callable):
self.basic_service = basic_service
self._clients_lock = Lock()
self._num_clients = 0 # guarded by self._clients_lock
self.ray_connect_handler = ray_connect_handler

def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
Expand All @@ -31,6 +33,8 @@ def Datapath(self, request_iterator, context):
logger.info(f"New data connection from client {client_id}")
try:
with self._clients_lock:
if self._num_clients == 0 and not ray.is_initialized():
self.ray_connect_handler()
self._num_clients += 1
for req in request_iterator:
resp = None
Expand Down Expand Up @@ -63,9 +67,13 @@ def Datapath(self, request_iterator, context):
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)

with self._clients_lock:
self._num_clients -= 1

if self._num_clients == 0:
ray.shutdown()

def _build_connection_response(self):
with self._clients_lock:
cur_num_clients = self._num_clients
Expand Down
26 changes: 15 additions & 11 deletions python/ray/util/client/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,11 @@ def __getattr__(self, attr):
return getattr(self.grpc_server, attr)


def serve(connection_str):
def serve(connection_str, ray_connect_handler):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer()
data_servicer = DataServicer(task_servicer)
data_servicer = DataServicer(
task_servicer, ray_connect_handler=ray_connect_handler)
logs_servicer = LogstreamServicer()
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
Expand Down Expand Up @@ -477,18 +478,21 @@ def main():
help="Password for connecting to Redis")
args = parser.parse_args()
logging.basicConfig(level="INFO")
if args.redis_address:
if args.redis_password:
ray.init(
address=args.redis_address,
_redis_password=args.redis_password)

def ray_connect_handler():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make this top-level and pass in redis_address/redis_password?

if args.redis_address:
if args.redis_password:
ray.init(
address=args.redis_address,
_redis_password=args.redis_password)
else:
ray.init(address=args.redis_address)
else:
ray.init(address=args.redis_address)
else:
ray.init()
ray.init()

hostport = "%s:%d" % (args.host, args.port)
logger.info(f"Starting Ray Client server on {hostport}")
server = serve(hostport)
server = serve(hostport, ray_connect_handler)
try:
while True:
time.sleep(1000)
Expand Down
28 changes: 14 additions & 14 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ def __init__(self,
# RayletDriverStub, allowing for unary requests.
self.server = ray_client_pb2_grpc.RayletDriverStub(
self.channel)
# Now the HTTP2 channel is ready, or proxied, but the
# servicer may not be ready. Call is_initialized() and if
# it throws, the servicer is not ready. On success, the
# `ray_ready` result is checked.
ray_ready = self.is_initialized()
if ray_ready:
# Ray is ready! Break out of the retry loop
break
# Ray is not ready yet, wait a timeout
time.sleep(timeout)
# # Now the HTTP2 channel is ready, or proxied, but the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm what problem was this causing? Shouldn't is_initialized still work as before after the client connects?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, the data connection isn't added until later. How about we add a new dummy ping operation here then instead of is_initialized()? It can be anything really, is_initialized was just used as a convenient no_op.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_initialized() also had the advantage of checking if ray was initialized, which is kinda useful here. What you might want to do though is do the data connection earlier. you might need to wait for /it/ to return ready (meaning the send/receive loop thread pair is kicked off and sets a flag, or even better, returns its client_id) -- but in general, yeah, some op here is useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK i did something, take a look! happy to hear any suggestions

# # servicer may not be ready. Call is_initialized() and if
# # it throws, the servicer is not ready. On success, the
# # `ray_ready` result is checked.
# ray_ready = self.is_initialized()
# if ray_ready:
# # Ray is ready! Break out of the retry loop
# break
# # Ray is not ready yet, wait a timeout
# time.sleep(timeout)
except grpc.FutureTimeoutError:
logger.info(
f"Couldn't connect channel in {timeout} seconds, retrying")
Expand All @@ -120,10 +120,10 @@ def __init__(self,
f"retry in {timeout}s...")
timeout = backoff(timeout)

# If we made it through the loop without ray_ready it means we've used
# up our retries and should error back to the user.
if not ray_ready:
raise ConnectionError("ray client connection timeout")
# # If we made it through the loop without ray_ready it means we've used
# # up our retries and should error back to the user.
# if not ray_ready:
# raise ConnectionError("ray client connection timeout")

# Initialize the streams to finish protocol negotiation.
self.data_client = DataClient(self.channel, self._client_id,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@barakmich I just moved this into the try-except

Expand Down