diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index e88cdef7bd..f122faf2e7 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -278,12 +278,14 @@ async def test_tls_specific(tcp): """ async def handle_comm(comm): - assert comm.peer_address.startswith("tls://" + host) - check_tls_extra(comm.extra_info) - msg = await comm.read() - msg["op"] = "pong" - await comm.write(msg) - await comm.close() + try: + assert comm.peer_address.startswith("tls://" + host) + check_tls_extra(comm.extra_info) + msg = await comm.read() + msg["op"] = "pong" + await comm.write(msg) + finally: + await comm.close() server_ctx = get_server_ssl_context() client_ctx = get_client_ssl_context() @@ -298,15 +300,17 @@ async def handle_comm(comm): async def client_communicate(key, delay=0): addr = "%s:%d" % (host, port) comm = await connect(listener.contact_address, ssl_context=client_ctx) - assert comm.peer_address == "tls://" + addr - check_tls_extra(comm.extra_info) - await comm.write({"op": "ping", "data": key}) - if delay: - await asyncio.sleep(delay) - msg = await comm.read() - assert msg == {"op": "pong", "data": key} - l.append(key) - await comm.close() + try: + assert comm.peer_address == "tls://" + addr + check_tls_extra(comm.extra_info) + await comm.write({"op": "ping", "data": key}) + if delay: + await asyncio.sleep(delay) + msg = await comm.read() + assert msg == {"op": "pong", "data": key} + l.append(key) + finally: + await comm.close() await client_communicate(key=1234) @@ -370,13 +374,15 @@ async def check_inproc_specific(run_client): N_MSGS = 3 async def handle_comm(comm): - assert comm.peer_address.startswith("inproc://" + addr_head) - client_addresses.add(comm.peer_address) - for i in range(N_MSGS): - msg = await comm.read() - msg["op"] = "pong" - await comm.write(msg) - await comm.close() + try: + assert comm.peer_address.startswith("inproc://" + addr_head) + client_addresses.add(comm.peer_address) + for i in range(N_MSGS): + msg = await comm.read() + msg["op"] = "pong" + await comm.write(msg) + finally: + await comm.close() async with inproc.InProcListener(listener_addr, handle_comm) as listener: assert ( @@ -389,17 +395,19 @@ async def handle_comm(comm): async def client_communicate(key, delay=0): comm = await connect(listener.contact_address) - assert comm.peer_address == "inproc://" + listener_addr - for i in range(N_MSGS): - await comm.write({"op": "ping", "data": key}) - if delay: - await asyncio.sleep(delay) - msg = await comm.read() - assert msg == {"op": "pong", "data": key} - l.append(key) - with pytest.raises(CommClosedError): - await comm.read() - await comm.close() + try: + assert comm.peer_address == "inproc://" + listener_addr + for i in range(N_MSGS): + await comm.write({"op": "ping", "data": key}) + if delay: + await asyncio.sleep(delay) + msg = await comm.read() + assert msg == {"op": "pong", "data": key} + l.append(key) + with pytest.raises(CommClosedError): + await comm.read() + finally: + await comm.close() client_communicate = partial(run_client, client_communicate) @@ -499,18 +507,19 @@ async def check_client_server( """ async def handle_comm(comm): - scheme, loc = parse_address(comm.peer_address) - assert scheme == bound_scheme - - msg = await comm.read() - assert msg["op"] == "ping" - msg["op"] = "pong" - await comm.write(msg) + try: + scheme, loc = parse_address(comm.peer_address) + assert scheme == bound_scheme - msg = await comm.read() - assert msg["op"] == "foobar" + msg = await comm.read() + assert msg["op"] == "ping" + msg["op"] = "pong" + await comm.write(msg) - await comm.close() + msg = await comm.read() + assert msg["op"] == "foobar" + finally: + await comm.close() # Arbitrary connection args should be ignored listen_args = listen_args or {"xxx": "bar"} @@ -541,16 +550,18 @@ async def handle_comm(comm): async def client_communicate(key, delay=0): comm = await connect(listener.contact_address, **connect_args) - assert comm.peer_address == listener.contact_address + try: + assert comm.peer_address == listener.contact_address - await comm.write({"op": "ping", "data": key}) - await comm.write({"op": "foobar"}) - if delay: - await asyncio.sleep(delay) - msg = await comm.read() - assert msg == {"op": "pong", "data": key} - l.append(key) - await comm.close() + await comm.write({"op": "ping", "data": key}) + await comm.write({"op": "foobar"}) + if delay: + await asyncio.sleep(delay) + msg = await comm.read() + assert msg == {"op": "pong", "data": key} + l.append(key) + finally: + await comm.close() await client_communicate(key=1234) @@ -692,9 +703,11 @@ async def test_tls_reject_certificate(tcp): bad_serv_ctx = get_server_ssl_context(*bad_cert_key) async def handle_comm(comm): - scheme, loc = parse_address(comm.peer_address) - assert scheme == "tls" - await comm.close() + try: + scheme, loc = parse_address(comm.peer_address) + assert scheme == "tls" + finally: + await comm.close() # Listener refuses a connector not signed by the CA listener = await listen("tls://", handle_comm, ssl_context=serv_ctx) @@ -1031,7 +1044,8 @@ async def handle_comm(comm): q.put_nowait(exc) else: q.put_nowait(msg) - await comm.close() + finally: + await comm.close() async with listen(addr, handle_comm, deserialize=deserialize) as listener: comm = await connect(listener.contact_address) @@ -1049,16 +1063,20 @@ async def check_connector_deserialize(addr, deserialize, in_value, check_out): done = asyncio.Event() async def handle_comm(comm): - await comm.write(in_value) - await done.wait() - await comm.close() + try: + await comm.write(in_value) + await done.wait() + finally: + await comm.close() async with listen(addr, handle_comm) as listener: comm = await connect(listener.contact_address, deserialize=deserialize) - out_value = await comm.read() - done.set() - await comm.close() + try: + out_value = await comm.read() + done.set() + finally: + await comm.close() check_out(out_value)