Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ray_client] close ray connection upon client deactivation #13919

Merged
merged 18 commits into from
Feb 7, 2021
1 change: 1 addition & 0 deletions ci/travis/ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ test_python() {
-python/ray/tests:test_basic_3 # timeout
-python/ray/tests:test_basic_3_client_mode
-python/ray/tests:test_cli
-python/ray/tests:test_client_init # timeout
-python/ray/tests:test_failure
-python/ray/tests:test_global_gc
-python/ray/tests:test_job
Expand Down
4 changes: 2 additions & 2 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ py_test_module_list(
"test_basic_3.py",
"test_cancel.py",
"test_cli.py",
"test_client.py",
"test_client_init.py",
"test_component_failures_2.py",
"test_component_failures_3.py",
"test_error_ray_not_initialized.py",
Expand Down Expand Up @@ -80,9 +82,7 @@ py_test_module_list(
"test_asyncio.py",
"test_autoscaler.py",
"test_autoscaler_yaml.py",
"test_client_init.py",
"test_client_metadata.py",
"test_client.py",
"test_client_references.py",
"test_client_terminate.py",
"test_command_runner.py",
Expand Down
260 changes: 138 additions & 122 deletions python/ray/tests/test_client_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,130 +38,146 @@ def get(self):
return self.val


def test_basic_preregister():
@pytest.fixture
def init_and_serve():
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
yield server_handle
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)


@pytest.fixture
def init_and_serve_lazy():
cluster = ray.cluster_utils.Cluster()
cluster.add_node(num_cpus=1, num_gpus=0)
address = cluster.address

def connect():
ray.init(address=address)

server_handle = ray_client_server.serve("localhost:50051", connect)
yield server_handle
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)


def test_basic_preregister(init_and_serve):
from ray.util.client import ray
server, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
finally:
ray.disconnect()
ray_client_server.shutdown_with_server(server)
time.sleep(2)


def test_num_clients():
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()


def test_num_clients(init_and_serve_lazy):
richardliaw marked this conversation as resolved.
Show resolved Hide resolved
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
server = server_handle.grpc_server
try:
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
assert info2["num_clients"] == 2, info2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)

api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
assert info3["num_clients"] == 1, info3

# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
assert isinstance(info3["protocol_version"], str), info3
api3.disconnect()
finally:
ray_client_server.shutdown_with_server(server)
time.sleep(2)


def test_python_version():

server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
assert info1["python_version"] == ".".join(
[str(x) for x in list(sys.version_info)[:3]])
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)


def test_protocol_version():
def get_job_id(api):
return api.get_runtime_context().worker.current_job_id

server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join(
[str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just moved this stuff into the fixture

api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
job_id_1 = get_job_id(api1)
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
job_id_2 = get_job_id(api2)
assert info2["num_clients"] == 2, info2

assert job_id_1 == job_id_2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)

api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
job_id_3 = get_job_id(api3)
assert info3["num_clients"] == 1, info3
assert job_id_1 != job_id_3

# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
assert isinstance(info3["protocol_version"], str), info3
api3.disconnect()


def test_python_version(init_and_serve):
server_handle = init_and_serve
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
assert info1["python_version"] == ".".join(
[str(x) for x in list(sys.version_info)[:3]])
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()


def test_protocol_version(init_and_serve):
server_handle = init_and_serve
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()


if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
4 changes: 2 additions & 2 deletions python/ray/tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self):
assert len(actor_table) == 1

job_table = ray.jobs()
assert len(job_table) == 3 # dash, ray client server
assert len(job_table) == 2 # dash

# Kill the driver process.
p.kill()
Expand Down Expand Up @@ -79,7 +79,7 @@ def value(self):
assert len(actor_table) == 1

job_table = ray.jobs()
assert len(job_table) == 3 # dash, ray client server
assert len(job_table) == 2 # dash

# Kill the driver process.
p.kill()
Expand Down
15 changes: 13 additions & 2 deletions python/ray/util/client/server/dataservicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import grpc
import sys

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from threading import Lock

import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import CURRENT_PROTOCOL_VERSION
from ray._private.client_mode_hook import disable_client_hook

if TYPE_CHECKING:
from ray.util.client.server.server import RayletServicer
Expand All @@ -17,10 +18,12 @@


class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
def __init__(self, basic_service: "RayletServicer",
richardliaw marked this conversation as resolved.
Show resolved Hide resolved
ray_connect_handler: Callable):
self.basic_service = basic_service
self._clients_lock = Lock()
self._num_clients = 0 # guarded by self._clients_lock
self.ray_connect_handler = ray_connect_handler

def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
Expand All @@ -31,6 +34,9 @@ def Datapath(self, request_iterator, context):
logger.info(f"New data connection from client {client_id}")
try:
with self._clients_lock:
with disable_client_hook():
if self._num_clients == 0 and not ray.is_initialized():
self.ray_connect_handler()
self._num_clients += 1
for req in request_iterator:
resp = None
Expand Down Expand Up @@ -63,9 +69,14 @@ def Datapath(self, request_iterator, context):
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)

with self._clients_lock:
self._num_clients -= 1

with disable_client_hook():
if self._num_clients == 0:
ray.shutdown()

def _build_connection_response(self):
with self._clients_lock:
cur_num_clients = self._num_clients
Expand Down
Loading