From edf24580a62c4c9ba7ccb2abf35620c0093a13fa Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Tue, 16 Feb 2021 14:32:23 -0800 Subject: [PATCH] [ray_client]: Set gRPC max message size to 4GiB (#14063) * [ray_client]: Set gRPC max message size to 4GiB Change-Id: Id4d6887cdd90dd761dd25248f10f104701462667 * reduce size Change-Id: I71625ed3cffd9d8b3d7d3d7a981bb4dda00ed0a1 * Update test_basic_2.py * Update test_advanced.py * Update test_basic.py Co-authored-by: Eric Liang --- python/ray/tests/test_advanced.py | 4 ++-- python/ray/tests/test_basic.py | 3 +-- python/ray/tests/test_basic_2.py | 7 ++----- python/ray/util/client/common.py | 10 ++++++++++ python/ray/util/client/server/server.py | 8 +++++++- python/ray/util/client/worker.py | 12 ++++++++++-- 6 files changed, 32 insertions(+), 12 deletions(-) diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index 6df746fdcd911..f34fd601a80e5 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -25,7 +25,7 @@ # issue https://github.com/ray-project/ray/issues/7105 -@pytest.mark.skipif(client_test_enabled(), reason="message size") +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_internal_free(shutdown_only): ray.init(num_cpus=1) @@ -493,7 +493,7 @@ def join(self): ray.get(actor.join.remote()) == "ok" -@pytest.mark.skipif(client_test_enabled(), reason="message size") +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_wait_makes_object_local(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index f6ba1509c4022..93368d7d12ffa 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -31,7 +31,6 @@ def f(): # https://github.com/ray-project/ray/issues/7263 -@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_grpc_message_size(shutdown_only): ray.init(num_cpus=1) 
@@ -256,7 +255,7 @@ def foo(): assert without_options != with_options -@pytest.mark.skipif(client_test_enabled(), reason="message size") +@pytest.mark.skipif(client_test_enabled(), reason="internal api") @pytest.mark.parametrize( "ray_start_cluster_head", [{ "num_cpus": 0, diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py index 21fabc4ba55a0..db2234ed33234 100644 --- a/python/ray/tests/test_basic_2.py +++ b/python/ray/tests/test_basic_2.py @@ -193,7 +193,6 @@ def j(): assert ray.get(ray.get(h.remote(i))) == i -@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_call_matrix(shutdown_only): ray.init(object_store_memory=1000 * 1024 * 1024) @@ -319,7 +318,6 @@ def runner(f): assert delta < 10, "did not skip slow value" -@pytest.mark.skipif(client_test_enabled(), reason="message size") @pytest.mark.parametrize( "ray_start_cluster", [{ "num_cpus": 1, @@ -340,7 +338,7 @@ def g(x): assert ray.get(x) == 100 -@pytest.mark.skipif(client_test_enabled(), reason="message size") +@pytest.mark.skipif(client_test_enabled(), reason="init issue") def test_system_config_when_connecting(ray_start_cluster): config = {"object_timeout_milliseconds": 200} cluster = ray.cluster_utils.Cluster() @@ -463,8 +461,7 @@ def f(self, x): assert ray.get(obj_ref) == 2 -@pytest.mark.skipif( - client_test_enabled(), reason="internal api and message size") +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_actor_large_objects(ray_start_regular_shared): @ray.remote class Actor: diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py index cd4c57f972750..35f998db0b079 100644 --- a/python/ray/util/client/common.py +++ b/python/ray/util/client/common.py @@ -13,6 +13,16 @@ from typing import Optional from typing import Union +# TODO: Instead of just making the max message size large, the right thing to +# do is to split up the bytes representation of serialized data into multiple +# messages and reconstruct 
them on either end. That said, since clients are +# drivers and really just feed initial things in and final results out (when +# not going to S3 or similar), a large limit will suffice for many use +# cases. +# +# Currently, this is 2GiB - 1, the max for a signed 32-bit int. +GRPC_MAX_MESSAGE_SIZE = (2 * 1024 * 1024 * 1024) - 1 + class ClientBaseRef: def __init__(self, id: bytes): diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 571646bd0f86f..f19fc92bb8d73 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -18,6 +18,7 @@ import time import inspect import json +from ray.util.client.common import GRPC_MAX_MESSAGE_SIZE from ray.util.client.server.server_pickler import convert_from_arg from ray.util.client.server.server_pickler import dumps_from_server from ray.util.client.server.server_pickler import loads_from_client @@ -437,7 +438,12 @@ def default_connect_handler(): return ray.init() ray_connect_handler = ray_connect_handler or default_connect_handler - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE), + ]) task_servicer = RayletServicer() data_servicer = DataServicer( task_servicer, ray_connect_handler=ray_connect_handler) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index fbdcfe13c2a48..39965c2a77a44 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -29,6 +29,7 @@ from ray.util.client.common import ClientRemoteFunc from ray.util.client.common import ClientActorRef from ray.util.client.common import ClientObjectRef +from ray.util.client.common import GRPC_MAX_MESSAGE_SIZE from ray.util.client.dataclient import DataClient from ray.util.client.logsclient import LogstreamClient @@ 
-72,11 +73,18 @@ def __init__(self, self._conn_state = grpc.ChannelConnectivity.IDLE self._client_id = make_client_id() self._converted: Dict[str, ClientStub] = {} + + grpc_options = [ + ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE), + ] if secure: credentials = grpc.ssl_channel_credentials() - self.channel = grpc.secure_channel(conn_str, credentials) + self.channel = grpc.secure_channel( + conn_str, credentials, options=grpc_options) else: - self.channel = grpc.insecure_channel(conn_str) + self.channel = grpc.insecure_channel( + conn_str, options=grpc_options) self.channel.subscribe(self._on_channel_state_change)