diff --git a/locust/exception.py b/locust/exception.py index 3d6a825bff..a6925a941c 100644 --- a/locust/exception.py +++ b/locust/exception.py @@ -37,3 +37,10 @@ class RescheduleTaskImmediately(Exception): """ When raised in a Locust task, another locust task will be rescheduled immediately """ + +class RPCError(Exception): + """ + Exception that shows bad or broken network. + + When raised from zmqrpc, RPC should be reestablished. + """ \ No newline at end of file diff --git a/locust/rpc/zmqrpc.py b/locust/rpc/zmqrpc.py index ec3f1c48b0..2b89ab6728 100644 --- a/locust/rpc/zmqrpc.py +++ b/locust/rpc/zmqrpc.py @@ -1,7 +1,9 @@ import zmq.green as zmq - from .protocol import Message from locust.util.exception_handler import retry +from locust.exception import RPCError +import zmq.error as zmqerr +import msgpack.exceptions as msgerr class BaseSocket(object): def __init__(self, sock_type): @@ -13,37 +15,56 @@ def __init__(self, sock_type): @retry() def send(self, msg): - self.socket.send(msg.serialize()) + try: + self.socket.send(msg.serialize(), zmq.NOBLOCK) + except zmqerr.ZMQError as e: + raise RPCError("ZMQ sent failure") from e @retry() def send_to_client(self, msg): - self.socket.send_multipart([msg.node_id.encode(), msg.serialize()]) + try: + self.socket.send_multipart([msg.node_id.encode(), msg.serialize()]) + except zmqerr.ZMQError as e: + raise RPCError("ZMQ sent failure") from e - @retry() def recv(self): - data = self.socket.recv() - msg = Message.unserialize(data) + try: + data = self.socket.recv() + msg = Message.unserialize(data) + except msgerr.ExtraData as e: + raise RPCError("ZMQ interrupted message") from e + except zmqerr.ZMQError as e: + raise RPCError("ZMQ network broken") from e return msg - @retry() def recv_from_client(self): - data = self.socket.recv_multipart() - addr = data[0].decode() - msg = Message.unserialize(data[1]) + try: + data = self.socket.recv_multipart() + addr = data[0].decode() + msg = Message.unserialize(data[1]) + except (UnicodeDecodeError, msgerr.ExtraData) as e: + raise RPCError("ZMQ interrupted message") from e + except zmqerr.ZMQError as e: + raise RPCError("ZMQ network broken") from e return addr, msg + def close(self): + self.socket.close() + class Server(BaseSocket): def __init__(self, host, port): BaseSocket.__init__(self, zmq.ROUTER) if port == 0: self.port = self.socket.bind_to_random_port("tcp://%s" % host) else: - self.socket.bind("tcp://%s:%i" % (host, port)) - self.port = port + try: + self.socket.bind("tcp://%s:%i" % (host, port)) + self.port = port + except zmqerr.ZMQError as e: + raise RPCError("Socket bind failure: %s" % (e) ) class Client(BaseSocket): def __init__(self, host, port, identity): BaseSocket.__init__(self, zmq.DEALER) self.socket.setsockopt(zmq.IDENTITY, identity.encode()) self.socket.connect("tcp://%s:%i" % (host, port)) - \ No newline at end of file diff --git a/locust/runners.py b/locust/runners.py index cd20e49c65..72e8cbac29 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -14,6 +14,8 @@ from .rpc import Message, rpc from .stats import RequestStats, setup_distributed_stats_event_listeners +from .exception import RPCError + logger = logging.getLogger(__name__) @@ -22,6 +24,7 @@ CPU_MONITOR_INTERVAL = 5.0 HEARTBEAT_INTERVAL = 1 HEARTBEAT_LIVENESS = 3 +FALLBACK_INTERVAL = 5 class LocustRunner(object): @@ -50,6 +53,7 @@ def on_request_failure(request_type, name, response_time, response_length, excep self.environment.events.request_success.add_listener(on_request_success) self.environment.events.request_failure.add_listener(on_request_failure) + self.connection_broken = False # register listener that resets stats when hatching is complete def on_hatch_complete(user_count): @@ -303,6 +307,8 @@ def __init__(self, *args, master_bind_host, master_bind_port, **kwargs): super().__init__(*args, **kwargs) self.worker_cpu_warning_emitted = False self.target_user_count = None + self.master_bind_host = master_bind_host + self.master_bind_port = master_bind_port class WorkerNodesDict(dict): def get_by_state(self, state): @@ -408,6 +414,9 @@ def quit(self): def heartbeat_worker(self): while True: gevent.sleep(HEARTBEAT_INTERVAL) + if self.connection_broken: + self.reset_connection() + continue for client in self.clients.all: if client.heartbeat < 0 and client.state != STATE_MISSING: logger.info('Worker %s failed to send heartbeat, setting state to missing.' % str(client.id)) @@ -416,9 +425,24 @@ def heartbeat_worker(self): else: client.heartbeat -= 1 + def reset_connection(self): + logger.info("Reset connection to slave") + try: + self.server.close() + self.server = rpc.Server(self.master_bind_host, self.master_bind_port) + except RPCError as e: + logger.error("Temporay failure when resetting connection: %s, will retry later." % ( e ) ) + def client_listener(self): while True: - client_id, msg = self.server.recv_from_client() + try: + client_id, msg = self.server.recv_from_client() + except RPCError as e: + logger.error("RPCError found when receiving from client: %s" % ( e ) ) + self.connection_broken = True + gevent.sleep(FALLBACK_INTERVAL) + continue + self.connection_broken = False msg.node_id = client_id if msg.type == "client_ready": id = msg.node_id @@ -471,7 +495,8 @@ class WorkerLocustRunner(DistributedLocustRunner): def __init__(self, *args, master_host, master_port, **kwargs): super().__init__(*args, **kwargs) self.client_id = socket.gethostname() + "_" + uuid4().hex - + self.master_host = master_host + self.master_port = master_port self.client = rpc.Client(master_host, master_port, self.client_id) self.greenlet.spawn(self.heartbeat).link_exception(callback=self.noop) self.greenlet.spawn(self.worker).link_exception(callback=self.noop) @@ -503,12 +528,28 @@ def on_locust_error(locust_instance, exception, tb): def heartbeat(self): while True: - self.client.send(Message('heartbeat', {'state': self.worker_state, 'current_cpu_usage': self.current_cpu_usage}, self.client_id)) + try: + self.client.send(Message('heartbeat', {'state': self.worker_state, 'current_cpu_usage': self.current_cpu_usage}, self.client_id)) + except RPCError as e: + logger.error("RPCError found when sending heartbeat: %s" % ( e ) ) + self.reset_connection() gevent.sleep(HEARTBEAT_INTERVAL) + def reset_connection(self): + logger.info("Reset connection to master") + try: + self.client.close() + self.client = rpc.Client(self.master_host, self.master_port, self.client_id) + except RPCError as e: + logger.error("Temporary failure when resetting connection: %s, will retry later." % ( e ) ) + def worker(self): while True: - msg = self.client.recv() + try: + msg = self.client.recv() + except RPCError as e: + logger.error("RPCError found when receiving from master: %s" % ( e ) ) + continue if msg.type == "hatch": self.worker_state = STATE_HATCHING self.client.send(Message("hatching", None, self.client_id)) @@ -535,10 +576,8 @@ def stats_reporter(self): while True: try: self._send_stats() - except: - logger.error("Connection lost to master server. Aborting...") - break - + except RPCError as e: + logger.error("Temporary connection lost to master server: %s, will retry later." % (e)) gevent.sleep(WORKER_REPORT_INTERVAL) def _send_stats(self): diff --git a/locust/test/test_runners.py b/locust/test/test_runners.py index fc2c6e9dae..0d7d061f3a 100644 --- a/locust/test/test_runners.py +++ b/locust/test/test_runners.py @@ -9,7 +9,7 @@ from locust.main import create_environment from locust.core import Locust, TaskSet, task from locust.env import Environment -from locust.exception import LocustError, StopLocust +from locust.exception import LocustError, RPCError, StopLocust from locust.rpc import Message from locust.runners import LocustRunner, LocalLocustRunner, MasterLocustRunner, WorkerNode, \ WorkerLocustRunner, STATE_INIT, STATE_HATCHING, STATE_RUNNING, STATE_MISSING @@ -17,6 +17,8 @@ from locust.test.testcases import LocustTestCase from locust.wait_time import between, constant +NETWORK_BROKEN = "network broken" +UNHANDLED_EXCEPTION = "unhandled exception" def mocked_rpc(): class MockedRpcServerClient(object): @@ -30,10 +32,15 @@ def __init__(self, *args, **kwargs): def mocked_send(cls, message): cls.queue.put(message.serialize()) sleep(0) - + def recv(self): results = self.queue.get() - return Message.unserialize(results) + msg = Message.unserialize(results) + if msg.data == NETWORK_BROKEN: + raise RPCError() + if msg.data == UNHANDLED_EXCEPTION: + raise HeyAnException() + return msg def send(self, message): self.outbox.append(message) @@ -44,8 +51,15 @@ def send_to_client(self, message): def recv_from_client(self): results = self.queue.get() msg = Message.unserialize(results) + if msg.data == NETWORK_BROKEN: + raise RPCError() + if msg.data == UNHANDLED_EXCEPTION: + raise HeyAnException() return msg.node_id, msg + def close(self): + raise RPCError() + return MockedRpcServerClient @@ -62,10 +76,13 @@ def __init__(self): self.heartbeat_interval = 1 self.stop_timeout = None self.step_load = True + self.connection_broken = False def reset_stats(self): pass +class HeyAnException(Exception): + pass class TestLocustRunner(LocustTestCase): def assert_locust_class_distribution(self, expected_distribution, classes): @@ -291,6 +308,7 @@ def setUp(self): super(TestMasterRunner, self).setUp() #self._worker_report_event_handlers = [h for h in events.worker_report._handlers] self.environment.options = mocked_options() + class MyTestLocust(Locust): pass @@ -595,9 +613,6 @@ def test_spawn_locusts_in_stepload_mode(self): self.assertEqual(10, num_clients, "Total number of locusts that would have been spawned for second step is not 10") def test_exception_in_task(self): - class HeyAnException(Exception): - pass - class MyLocust(Locust): @task def will_error(self): @@ -619,8 +634,6 @@ def will_error(self): def test_exception_is_catched(self): """ Test that exceptions are stored, and execution continues """ - class HeyAnException(Exception): - pass class MyTaskSet(TaskSet): def __init__(self, *a, **kw): @@ -658,6 +671,19 @@ class MyLocust(Locust): self.assertTrue("HeyAnException" in exception["traceback"]) self.assertEqual(2, exception["count"]) + def test_master_reset_connection(self): + """ Test that connection will be reset when network issues found """ + with mock.patch("locust.rpc.rpc.Server", mocked_rpc()) as server: + master = self.get_runner() + server.mocked_send(Message("client_ready", NETWORK_BROKEN, "fake_client")) + sleep(3) + assert master.connection_broken == True + server.mocked_send(Message("client_ready", None, "fake_client")) + sleep(3) + assert master.connection_broken == False + server.mocked_send(Message("client_ready", UNHANDLED_EXCEPTION, "fake_client")) + sleep(3) + assert master.connection_broken == False class TestWorkerLocustRunner(LocustTestCase): def setUp(self): @@ -777,7 +803,6 @@ def my_task(self): self.assertEqual(9, len(worker.locusts)) worker.quit() - class TestMessageSerializing(unittest.TestCase): def test_message_serialize(self): msg = Message("client_ready", None, "my_id") diff --git a/locust/test/test_zmqrpc.py b/locust/test/test_zmqrpc.py index b1bd45a7e6..96c3f5ac94 100644 --- a/locust/test/test_zmqrpc.py +++ b/locust/test/test_zmqrpc.py @@ -3,7 +3,7 @@ import zmq from locust.rpc import zmqrpc, Message from locust.test.testcases import LocustTestCase - +from locust.exception import RPCError class ZMQRPC_tests(LocustTestCase): def setUp(self): @@ -12,8 +12,8 @@ def setUp(self): self.client = zmqrpc.Client('localhost', self.server.port, 'identity') def tearDown(self): - self.server.socket.close() - self.client.socket.close() + self.server.close() + self.client.close() super(ZMQRPC_tests, self).tearDown() def test_constructor(self): @@ -42,5 +42,15 @@ def test_client_recv(self): def test_client_retry(self): server = zmqrpc.Server('127.0.0.1', 0) server.socket.close() - with self.assertRaises(zmq.error.ZMQError): + with self.assertRaises(RPCError): server.recv_from_client() + + def test_rpc_error(self): + server = zmqrpc.Server('127.0.0.1', 5557) + with self.assertRaises(RPCError): + server = zmqrpc.Server('127.0.0.1', 5557) + server.close() + with self.assertRaises(RPCError): + server.send_to_client(Message('test', 'message', 'identity')) + + \ No newline at end of file