Skip to content

Commit

Permalink
Rename socket_timeout to timeout for compatibility
Browse files Browse the repository at this point in the history
Signed-off-by: aiudirog <aiudirog@gmail.com>
  • Loading branch information
aiudirog committed Jan 19, 2020
1 parent 82d5a79 commit 5bd4ef4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
59 changes: 57 additions & 2 deletions tests/test_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# import uvloop
import threading
import random
from unittest.mock import patch

# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

Expand Down Expand Up @@ -151,7 +152,7 @@ async def client(self, timeout: int = 3000000):
addressbook.AddressBookService,
trans_factory=self.TRANSPORT_FACTORY,
proto_factory=self.PROTOCOL_FACTORY,
socket_timeout=timeout,
timeout=timeout,
**self.client_kwargs(),
)

Expand Down Expand Up @@ -246,7 +247,7 @@ async def client_with_url(self, timeout: int = 3000):
addressbook.AddressBookService,
trans_factory=self.TRANSPORT_FACTORY,
proto_factory=self.PROTOCOL_FACTORY,
socket_timeout=timeout,
timeout=timeout,
**kw,
)

Expand Down Expand Up @@ -304,3 +305,57 @@ async def test_client_connect_timeout():
connect_timeout=1000
)
await c.hello('test')


class TestDeprecatedTimeoutKwarg:
"""
Replace TAsyncSocket with a Mock object that raises a RuntimeError
when called. This allows us to check that timeout vs. socket_timeout
arguments are properly handled without actually creating the client.
This class should be removed when the socket_timeout argument is removed.
"""
def setup(self):
# Create and apply a fresh patch for each test.
self.async_sock = patch(
'thriftpy2.contrib.aio.rpc.TAsyncSocket',
side_effect=RuntimeError,
).__enter__()

def teardown_(self):
self.async_sock.__exit__() # Clean up patch

@pytest.mark.asyncio
async def test_no_timeout_given(self):
await self._make_client()
assert self._given_timeout() == 3000 # Default value

@pytest.mark.asyncio
async def test_timeout_given(self):
await self._make_client(timeout=1234)
assert self._given_timeout() == 1234

@pytest.mark.asyncio
async def test_socket_timeout_given(self):
await self._make_client(warning=DeprecationWarning, socket_timeout=555)
assert self._given_timeout() == 555

@staticmethod
async def _make_client(warning=None, **kwargs):
"""
Helper method to create the client and check that the proper warning
is emitted (if any) and that the patch is properly applied by
consuming the RuntimeError.
"""
with pytest.warns(warning),\
pytest.raises(RuntimeError): # Consume error
await make_aio_client(addressbook.AddressBookService, **kwargs)

def _given_timeout(self):
"""Get the timeout provided to TAsyncSocket."""
try:
self.async_sock.assert_called_once()
except AttributeError: # Python 3.5
assert self.async_sock.call_count == 1
_args, kwargs = self.async_sock.call_args
return kwargs['socket_timeout']
25 changes: 16 additions & 9 deletions thriftpy2/contrib/aio/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,38 @@
from .server import TAsyncServer



@asyncio.coroutine
def make_client(service, host='localhost', port=9090, unix_socket=None,
proto_factory=TAsyncBinaryProtocolFactory(),
trans_factory=TAsyncBufferedTransportFactory(),
socket_timeout=3000, connect_timeout=None,
timeout=3000, connect_timeout=None,
cafile=None, ssl_context=None,
certfile=None, keyfile=None,
validate=True, url=''):
validate=True, url='',
socket_timeout=None):
if socket_timeout is not None:
warnings.warn(
"The 'socket_timeout' argument is deprecated. "
"Please use 'timeout' instead.",
DeprecationWarning,
)
timeout = socket_timeout
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TAsyncSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
socket_timeout=timeout)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket = TAsyncSocket(
host, port,
socket_timeout=socket_timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate)
socket = TAsyncSocket(
host, port,
socket_timeout=timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate)
else:
raise ValueError("Either host/port or unix_socket or url must be provided.")

Expand Down

0 comments on commit 5bd4ef4

Please sign in to comment.