Skip to content

Commit

Permalink
[ray_client]: Set gRPC max message size to 4GiB (ray-project#14063)
Browse files Browse the repository at this point in the history
* [ray_client]: Set gRPC max message size to 4GiB

Change-Id: Id4d6887cdd90dd761dd25248f10f104701462667

* reduce size

Change-Id: I71625ed3cffd9d8b3d7d3d7a981bb4dda00ed0a1

* Update test_basic_2.py

* Update test_advanced.py

* Update test_basic.py

Co-authored-by: Eric Liang <ekhliang@gmail.com>
  • Loading branch information
barakmich and ericl authored Feb 16, 2021
1 parent 33316d4 commit edf2458
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 12 deletions.
4 changes: 2 additions & 2 deletions python/ray/tests/test_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


# issue https://github.com/ray-project/ray/issues/7105
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_internal_free(shutdown_only):
ray.init(num_cpus=1)

Expand Down Expand Up @@ -493,7 +493,7 @@ def join(self):
ray.get(actor.join.remote()) == "ok"


@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_wait_makes_object_local(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
Expand Down
3 changes: 1 addition & 2 deletions python/ray/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def f():


# https://github.com/ray-project/ray/issues/7263
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_grpc_message_size(shutdown_only):
ray.init(num_cpus=1)

Expand Down Expand Up @@ -256,7 +255,7 @@ def foo():
assert without_options != with_options


@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"num_cpus": 0,
Expand Down
7 changes: 2 additions & 5 deletions python/ray/tests/test_basic_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def j():
assert ray.get(ray.get(h.remote(i))) == i


@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_call_matrix(shutdown_only):
ray.init(object_store_memory=1000 * 1024 * 1024)

Expand Down Expand Up @@ -319,7 +318,6 @@ def runner(f):
assert delta < 10, "did not skip slow value"


@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize(
"ray_start_cluster", [{
"num_cpus": 1,
Expand All @@ -340,7 +338,7 @@ def g(x):
assert ray.get(x) == 100


@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.skipif(client_test_enabled(), reason="init issue")
def test_system_config_when_connecting(ray_start_cluster):
config = {"object_timeout_milliseconds": 200}
cluster = ray.cluster_utils.Cluster()
Expand Down Expand Up @@ -463,8 +461,7 @@ def f(self, x):
assert ray.get(obj_ref) == 2


@pytest.mark.skipif(
client_test_enabled(), reason="internal api and message size")
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_large_objects(ray_start_regular_shared):
@ray.remote
class Actor:
Expand Down
10 changes: 10 additions & 0 deletions python/ray/util/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
from typing import Optional
from typing import Union

# TODO: Instead of just making the max message size large, the right thing to
# do is to split up the bytes representation of serialized data into multiple
# messages and reconstruct them on either end. That said, since clients are
# drivers and really just feed initial things in and final results out, (when
# not going to S3 or similar) then a large limit will suffice for many use
# cases.
#
# Currently, this is 2GiB, the max for a signed int.
GRPC_MAX_MESSAGE_SIZE = (2 * 1024 * 1024 * 1024) - 1


class ClientBaseRef:
def __init__(self, id: bytes):
Expand Down
8 changes: 7 additions & 1 deletion python/ray/util/client/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import time
import inspect
import json
from ray.util.client.common import GRPC_MAX_MESSAGE_SIZE
from ray.util.client.server.server_pickler import convert_from_arg
from ray.util.client.server.server_pickler import dumps_from_server
from ray.util.client.server.server_pickler import loads_from_client
Expand Down Expand Up @@ -437,7 +438,12 @@ def default_connect_handler():
return ray.init()

ray_connect_handler = ray_connect_handler or default_connect_handler
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
])
task_servicer = RayletServicer()
data_servicer = DataServicer(
task_servicer, ray_connect_handler=ray_connect_handler)
Expand Down
12 changes: 10 additions & 2 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ray.util.client.common import ClientRemoteFunc
from ray.util.client.common import ClientActorRef
from ray.util.client.common import ClientObjectRef
from ray.util.client.common import GRPC_MAX_MESSAGE_SIZE
from ray.util.client.dataclient import DataClient
from ray.util.client.logsclient import LogstreamClient

Expand Down Expand Up @@ -72,11 +73,18 @@ def __init__(self,
self._conn_state = grpc.ChannelConnectivity.IDLE
self._client_id = make_client_id()
self._converted: Dict[str, ClientStub] = {}

grpc_options = [
("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
]
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
self.channel = grpc.secure_channel(
conn_str, credentials, options=grpc_options)
else:
self.channel = grpc.insecure_channel(conn_str)
self.channel = grpc.insecure_channel(
conn_str, options=grpc_options)

self.channel.subscribe(self._on_channel_state_change)

Expand Down

0 comments on commit edf2458

Please sign in to comment.