From a8c0d7d8c588f3980303358298870f2ea394ab93 Mon Sep 17 00:00:00 2001
From: Jonathan McCall
Date: Wed, 6 Feb 2019 15:12:17 -0500
Subject: [PATCH] Add heartbeat to detect down slaves (#927)

* Replace zmq sockets with one DEALER-ROUTER socket

  The PUSH and PULL sockets being used caused hatch messages to get routed to
  slaves that may have become unresponsive or crashed. This change includes the
  client id in the messages sent out from the master, which ensures that hatch
  messages go to slaves that are READY or RUNNING. This should also fix issue
  #911, where slaves were not receiving the stop message. I think these issues
  are a result of the PUSH-PULL sockets using a round-robin approach. (A
  standalone sketch of the DEALER-ROUTER addressing pattern follows this change
  list.)

* Remove client_id parameter from send_multipart method

* Add heartbeat worker to server and client

  The server checks whether clients have expired and, if they have, updates
  their status to "missing". The client has a worker that sends a heartbeat at
  a regular interval. The heartbeat also relays the slave state back to the
  master so that the two stay in sync. (A minimal sketch of this liveness
  bookkeeping appears at the end of this patch.)

* Use new clients.all property in heartbeat worker

* Fix reporting of stopped state

  Wait until all slaves report in as ready before declaring the master stopped.

* Fix tests after changing ZMQ sockets to DEALER-ROUTER

* Change heartbeat log msg to info so that it does not appear in tests

* Add tests for zmqrpc.py

* Remove commented imports, add note about sleep

* Support the str/unicode difference in py2 vs py3

* Ensure failed zmqrpc tests clean up bound sockets

* Create a throwaway variable for the identity from the ZMQ message

  I think this looks better than using msg[1].

* Replace usage of parse_options in tests with mock options

  Using parse_options during test setup can conflict with test runners like
  pytest: it swallows the options meant for the test runner and instead treats
  them as options passed to the test itself.

* Set coverage concurrency to gevent

  Coverage breaks with gevent and does not fully report green threads as
  tested. Setting concurrency in .coveragerc fixes the issue:
  https://bitbucket.org/ned/coveragepy/issues/149/coverage-gevent-looks-broken

* Add test that shows the master heartbeat worker marks slaves as missing

* Add assertions to test_zmqrpc.py

* Use unittest assertions

* Change assertion value to a bytes object

* Add cmdline options for heartbeat liveness and interval

* Add new option heartbeat_liveness to test_runners mock options

* Ensure the SlaveNode class uses the heartbeat_liveness default or the passed value

* Ensure hatch data can be updated for slaves currently hatching

* Add test for start hatching accepted slave states

  Checks that start_hatching sends messages to ready, running, and hatching
  slaves.

* Remove unneeded imports of mock
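The following is a minimal, standalone pyzmq sketch (not Locust code; the
endpoint, port, and "slave-1" identity are made up for illustration) of the
ROUTER-DEALER addressing this patch relies on: the ROUTER socket prefixes each
inbound message with the sending peer's identity frame, and sending a reply
with that same identity frame routes it to that specific peer instead of
round-robining the way PUSH-PULL did.

    # Illustrative sketch only: identity-based routing with ROUTER/DEALER.
    import zmq

    ctx = zmq.Context()

    master = ctx.socket(zmq.ROUTER)             # master binds one ROUTER socket
    master.bind("tcp://127.0.0.1:5599")

    slave = ctx.socket(zmq.DEALER)              # each slave connects a DEALER socket
    slave.setsockopt(zmq.IDENTITY, b"slave-1")  # stable identity acts as the routable address
    slave.connect("tcp://127.0.0.1:5599")

    slave.send_multipart([b"client_ready"])     # ROUTER receives [identity, payload]
    identity, payload = master.recv_multipart()

    master.send_multipart([identity, b"hatch"]) # addressed send: only this slave receives it
    print(slave.recv_multipart())               # -> [b'hatch']

    slave.close()
    master.close()
    ctx.term()

With PUSH-PULL there is no identity frame, so the master has no way to choose
which slave a "hatch" or "stop" message lands on.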
---
 .coveragerc                 |  1 +
 locust/main.py              | 16 ++++++++
 locust/rpc/zmqrpc.py        | 41 ++++++++++++-----------
 locust/runners.py           | 64 +++++++++++++++++++++++++--------
 locust/test/test_runners.py | 68 ++++++++++++++++++++++++++++++-------
 locust/test/test_zmqrpc.py  | 32 +++++++++++++++++
 6 files changed, 176 insertions(+), 46 deletions(-)
 create mode 100644 locust/test/test_zmqrpc.py

diff --git a/.coveragerc b/.coveragerc
index 401ef22d8f..a2cf1c7931 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,6 +1,7 @@
 [run]
 branch = True
 source = locust
+concurrency = gevent
 
 [report]
 exclude_lines =
diff --git a/locust/main.py b/locust/main.py
index 1aea2dddd6..f9a02c91d5 100644
--- a/locust/main.py
+++ b/locust/main.py
@@ -128,6 +128,22 @@ def parse_options():
         help="Port that locust master should bind to. Only used when running with --master. Defaults to 5557. Note that Locust will also use this port + 1, so by default the master node will bind to 5557 and 5558."
     )
 
+    parser.add_option(
+        '--heartbeat-liveness',
+        action='store',
+        type='int',
+        default=3,
+        help="set number of seconds before failed heartbeat from slave"
+    )
+
+    parser.add_option(
+        '--heartbeat-interval',
+        action='store',
+        type='int',
+        default=1,
+        help="set number of seconds delay between slave heartbeats to master"
+    )
+
     parser.add_option(
         '--expect-slaves',
         action='store',
diff --git a/locust/rpc/zmqrpc.py b/locust/rpc/zmqrpc.py
index 07583cb0a7..164d3cb262 100644
--- a/locust/rpc/zmqrpc.py
+++ b/locust/rpc/zmqrpc.py
@@ -4,30 +4,35 @@
 
 class BaseSocket(object):
+    def __init__(self, sock_type):
+        context = zmq.Context()
+        self.socket = context.socket(sock_type)
 
     def send(self, msg):
-        self.sender.send(msg.serialize())
-    
+        self.socket.send(msg.serialize())
+
+    def send_to_client(self, msg):
+        self.socket.send_multipart([msg.node_id.encode(), msg.serialize()])
+
     def recv(self):
-        data = self.receiver.recv()
-        return Message.unserialize(data)
+        data = self.socket.recv()
+        msg = Message.unserialize(data)
+        return msg
 
+    def recv_from_client(self):
+        data = self.socket.recv_multipart()
+        addr = data[0]
+        msg = Message.unserialize(data[1])
+        return addr, msg
 
 class Server(BaseSocket):
     def __init__(self, host, port):
-        context = zmq.Context()
-        self.receiver = context.socket(zmq.PULL)
-        self.receiver.bind("tcp://%s:%i" % (host, port))
-        
-        self.sender = context.socket(zmq.PUSH)
-        self.sender.bind("tcp://%s:%i" % (host, port+1))
-        
+        BaseSocket.__init__(self, zmq.ROUTER)
+        self.socket.bind("tcp://%s:%i" % (host, port))
 
 class Client(BaseSocket):
-    def __init__(self, host, port):
-        context = zmq.Context()
-        self.receiver = context.socket(zmq.PULL)
-        self.receiver.connect("tcp://%s:%i" % (host, port+1))
-        
-        self.sender = context.socket(zmq.PUSH)
-        self.sender.connect("tcp://%s:%i" % (host, port))
+    def __init__(self, host, port, identity):
+        BaseSocket.__init__(self, zmq.DEALER)
+        self.socket.setsockopt(zmq.IDENTITY, identity.encode())
+        self.socket.connect("tcp://%s:%i" % (host, port))
+        
\ No newline at end of file
diff --git a/locust/runners.py b/locust/runners.py
index dd19bb1011..f23d62fb9e 100644
--- a/locust/runners.py
+++ b/locust/runners.py
@@ -23,7 +23,7 @@
 # global locust runner singleton
 locust_runner = None
 
-STATE_INIT, STATE_HATCHING, STATE_RUNNING, STATE_CLEANUP, STATE_STOPPED = ["ready", "hatching", "running", "cleanup", "stopped"]
+STATE_INIT, STATE_HATCHING, STATE_RUNNING, STATE_CLEANUP, STATE_STOPPED, STATE_MISSING = ["ready", "hatching", "running", "cleanup", "stopped", "missing"]
 
 SLAVE_REPORT_INTERVAL = 3.0
 
@@ -213,25 +213,32 @@ def __init__(self, locust_classes, options):
         self.master_port = options.master_port
         self.master_bind_host = options.master_bind_host
         self.master_bind_port = options.master_bind_port
+        self.heartbeat_liveness = options.heartbeat_liveness
+        self.heartbeat_interval = options.heartbeat_interval
     
     def noop(self, *args, **kwargs):
         """ Used to link() greenlets to in order to be compatible with gevent 1.0 """
         pass
 
 
 class SlaveNode(object):
-    def __init__(self, id, state=STATE_INIT):
+    def __init__(self, id, state=STATE_INIT, heartbeat_liveness=3):
         self.id = id
         self.state = state
         self.user_count = 0
+        self.heartbeat = heartbeat_liveness
 
 
 class MasterLocustRunner(DistributedLocustRunner):
     def __init__(self, *args, **kwargs):
         super(MasterLocustRunner, self).__init__(*args, **kwargs)
-        
+
         class SlaveNodesDict(dict):
             def get_by_state(self, state):
                 return [c for c in six.itervalues(self) if c.state == state]
             
+            @property
+            def all(self):
+                return six.itervalues(self)
+
             @property
             def ready(self):
                 return self.get_by_state(STATE_INIT)
@@ -247,6 +254,7 @@ def running(self):
         self.clients = SlaveNodesDict()
         self.server = rpc.Server(self.master_bind_host, self.master_bind_port)
         self.greenlet = Group()
+        self.greenlet.spawn(self.heartbeat_worker).link_exception(callback=self.noop)
         self.greenlet.spawn(self.client_listener).link_exception(callback=self.noop)
 
         # listener that gathers info on how many locust users the slaves has spawned
@@ -268,7 +276,7 @@ def user_count(self):
         return sum([c.user_count for c in six.itervalues(self.clients)])
     
     def start_hatching(self, locust_count, hatch_rate):
-        num_slaves = len(self.clients.ready) + len(self.clients.running)
+        num_slaves = len(self.clients.ready) + len(self.clients.running) + len(self.clients.hatching)
         if not num_slaves:
             logger.warning("You are running in distributed mode but have no slave servers connected. "
                            "Please connect slaves prior to swarming.")
@@ -286,7 +294,7 @@ def start_hatching(self, locust_count, hatch_rate):
             self.exceptions = {}
             events.master_start_hatching.fire()
         
-        for client in six.itervalues(self.clients):
+        for client in (self.clients.ready + self.clients.running + self.clients.hatching):
             data = {
                 "hatch_rate":slave_hatch_rate,
                 "num_clients":slave_num_clients,
@@ -298,36 +306,49 @@ def start_hatching(self, locust_count, hatch_rate):
                 data["num_clients"] += 1
                 remaining -= 1
             
-            self.server.send(Message("hatch", data, None))
+            self.server.send_to_client(Message("hatch", data, client.id))
         
         self.stats.start_time = time()
         self.state = STATE_HATCHING
     
     def stop(self):
-        for client in self.clients.hatching + self.clients.running:
-            self.server.send(Message("stop", None, None))
+        for client in self.clients.all:
+            self.server.send_to_client(Message("stop", None, client.id))
         events.master_stop_hatching.fire()
     
     def quit(self):
-        for client in six.itervalues(self.clients):
-            self.server.send(Message("quit", None, None))
+        for client in self.clients.all:
+            self.server.send_to_client(Message("quit", None, client.id))
         self.greenlet.kill(block=True)
     
+    def heartbeat_worker(self):
+        while True:
+            gevent.sleep(self.heartbeat_interval)
+            for client in self.clients.all:
+                if client.heartbeat < 0 and client.state != STATE_MISSING:
+                    logger.info('Slave %s failed to send heartbeat, setting state to missing.' % str(client.id))
+                    client.state = STATE_MISSING
+                else:
+                    client.heartbeat -= 1
+
     def client_listener(self):
         while True:
-            msg = self.server.recv()
+            client_id, msg = self.server.recv_from_client()
+            msg.node_id = client_id
             if msg.type == "client_ready":
                 id = msg.node_id
-                self.clients[id] = SlaveNode(id)
+                self.clients[id] = SlaveNode(id, heartbeat_liveness=self.heartbeat_liveness)
                 logger.info("Client %r reported as ready. Currently %i clients ready to swarm." % (id, len(self.clients.ready)))
                 ## emit a warning if the slave's clock seem to be out of sync with our clock
                 #if abs(time() - msg.data["time"]) > 5.0:
                 #    warnings.warn("The slave node's clock seem to be out of sync. For the statistics to be correct the different locust servers need to have synchronized clocks.")
             elif msg.type == "client_stopped":
                 del self.clients[msg.node_id]
-                if len(self.clients.hatching + self.clients.running) == 0:
-                    self.state = STATE_STOPPED
                 logger.info("Removing %s client from running clients" % (msg.node_id))
+            elif msg.type == "heartbeat":
+                if msg.node_id in self.clients:
+                    self.clients[msg.node_id].heartbeat = self.heartbeat_liveness
+                    self.clients[msg.node_id].state = msg.data['state']
             elif msg.type == "stats":
                 events.slave_report.fire(client_id=msg.node_id, data=msg.data)
             elif msg.type == "hatching":
@@ -345,6 +366,9 @@ def client_listener(self):
             elif msg.type == "exception":
                 self.log_exception(msg.node_id, msg.data["msg"], msg.data["traceback"])
 
+            if not self.state == STATE_INIT and all(map(lambda x: x.state == STATE_INIT, self.clients.all)):
+                self.state = STATE_STOPPED
+
     @property
     def slave_count(self):
         return len(self.clients.ready) + len(self.clients.hatching) + len(self.clients.running)
@@ -354,16 +378,19 @@ def __init__(self, *args, **kwargs):
         super(SlaveLocustRunner, self).__init__(*args, **kwargs)
         self.client_id = socket.gethostname() + "_" + uuid4().hex
 
-        self.client = rpc.Client(self.master_host, self.master_port)
+        self.client = rpc.Client(self.master_host, self.master_port, self.client_id)
         self.greenlet = Group()
+        self.greenlet.spawn(self.heartbeat).link_exception(callback=self.noop)
         self.greenlet.spawn(self.worker).link_exception(callback=self.noop)
         self.client.send(Message("client_ready", None, self.client_id))
+        self.slave_state = STATE_INIT
         self.greenlet.spawn(self.stats_reporter).link_exception(callback=self.noop)
         
         # register listener for when all locust users have hatched, and report it to the master node
         def on_hatch_complete(user_count):
             self.client.send(Message("hatch_complete", {"count":user_count}, self.client_id))
+            self.slave_state = STATE_RUNNING
         events.hatch_complete += on_hatch_complete
         
         # register listener that adds the current number of spawned locusts to the report that is sent to the master node
@@ -382,10 +409,16 @@ def on_locust_error(locust_instance, exception, tb):
             self.client.send(Message("exception", {"msg" : str(exception), "traceback" : formatted_tb}, self.client_id))
         events.locust_error += on_locust_error
     
+    def heartbeat(self):
+        while True:
+            self.client.send(Message('heartbeat', {'state': self.slave_state}, self.client_id))
+            gevent.sleep(self.heartbeat_interval)
+
     def worker(self):
         while True:
             msg = self.client.recv()
             if msg.type == "hatch":
+                self.slave_state = STATE_HATCHING
                 self.client.send(Message("hatching", None, self.client_id))
                 job = msg.data
                 self.hatch_rate = job["hatch_rate"]
@@ -396,6 +429,7 @@ def worker(self):
                 self.stop()
                 self.client.send(Message("client_stopped", None, self.client_id))
                 self.client.send(Message("client_ready", None, self.client_id))
+                self.slave_state = STATE_INIT
             elif msg.type == "quit":
                 logger.info("Got quit message from master, shutting down...")
                 self.stop()
diff --git a/locust/test/test_runners.py b/locust/test/test_runners.py
index f700833e6a..8a51c75c2c 100644
--- a/locust/test/test_runners.py
+++ b/locust/test/test_runners.py
@@ -8,13 +8,11 @@
 from locust import events
 from locust.core import Locust, TaskSet, task
 from locust.exception import LocustError
-from locust.main import parse_options
 from locust.rpc import Message
-from locust.runners import LocalLocustRunner, MasterLocustRunner
+from locust.runners import LocalLocustRunner, MasterLocustRunner, SlaveNode, STATE_INIT, STATE_HATCHING, STATE_RUNNING, STATE_MISSING
 from locust.stats import global_stats, RequestStats
 from locust.test.testcases import LocustTestCase
 
-
 def mocked_rpc_server():
     class MockedRpcServer(object):
         queue = Queue()
@@ -35,21 +33,37 @@ def recv(self):
         def send(self, message):
             self.outbox.append(message.serialize())
     
+        def send_to_client(self, message):
+            self.outbox.append([message.node_id, message.serialize()])
+
+        def recv_from_client(self):
+            results = self.queue.get()
+            msg = Message.unserialize(results)
+            return msg.node_id, msg
+
     return MockedRpcServer
 
+class mocked_options(object):
+    def __init__(self):
+        self.hatch_rate = 5
+        self.num_clients = 5
+        self.host = '/'
+        self.master_host = 'localhost'
+        self.master_port = 5557
+        self.master_bind_host = '*'
+        self.master_bind_port = 5557
+        self.heartbeat_liveness = 3
+        self.heartbeat_interval = 0.01
+
+    def reset_stats(self):
+        pass
 
 class TestMasterRunner(LocustTestCase):
     def setUp(self):
         global_stats.reset_all()
         self._slave_report_event_handlers = [h for h in events.slave_report._handlers]
+        self.options = mocked_options()
         
-        parser, _, _ = parse_options()
-        args = [
-            "--clients", "10",
-            "--hatch-rate", "10"
-        ]
-        opts, _ = parser.parse_args(args)
-        self.options = opts
-        
     def tearDown(self):
         events.slave_report._handlers = self._slave_report_event_handlers
@@ -90,7 +104,18 @@ class MyTestLocust(Locust):
             server.mocked_send(Message("stats", data, "fake_client"))
             s = master.stats.get("/", "GET")
             self.assertEqual(700, s.median_response_time)
-    
+
+    def test_master_marks_downed_slaves_as_missing(self):
+        class MyTestLocust(Locust):
+            pass
+
+        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
+            master = MasterLocustRunner(MyTestLocust, self.options)
+            server.mocked_send(Message("client_ready", None, "fake_client"))
+            sleep(0.1)
+            # print(master.clients['fake_client'].__dict__)
+            assert master.clients['fake_client'].state == STATE_MISSING
+
     def test_master_total_stats(self):
         class MyTestLocust(Locust):
             pass
@@ -166,6 +191,23 @@ class MyTestLocust(Locust):
             self.assertEqual(30, master.stats.total.get_current_response_time_percentile(0.5))
             self.assertEqual(3000, master.stats.total.get_current_response_time_percentile(0.95))
 
+    def test_sends_hatch_data_to_ready_running_hatching_slaves(self):
+        '''Sends hatch job to running, ready, or hatching slaves'''
+        class MyTestLocust(Locust):
+            pass
+
+        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
+            master = MasterLocustRunner(MyTestLocust, self.options)
+            master.clients[1] = SlaveNode(1)
+            master.clients[2] = SlaveNode(2)
+            master.clients[3] = SlaveNode(3)
+            master.clients[1].state = STATE_INIT
+            master.clients[2].state = STATE_HATCHING
+            master.clients[3].state = STATE_RUNNING
+            master.start_hatching(5,5)
+
+            self.assertEqual(3, len(server.outbox))
+
     def test_spawn_zero_locusts(self):
         class MyTaskSet(TaskSet):
             @task
@@ -207,7 +249,7 @@ class MyTestLocust(Locust):
             self.assertEqual(5, len(server.outbox))
             
             num_clients = 0
-            for msg in server.outbox:
+            for _, msg in server.outbox:
                 num_clients += Message.unserialize(msg).data["num_clients"]
             
             self.assertEqual(7, num_clients, "Total number of locusts that would have been spawned is not 7")
@@ -225,7 +267,7 @@ class MyTestLocust(Locust):
             self.assertEqual(5, len(server.outbox))
             
             num_clients = 0
-            for msg in server.outbox:
+            for _, msg in server.outbox:
                 num_clients += Message.unserialize(msg).data["num_clients"]
             
             self.assertEqual(2, num_clients, "Total number of locusts that would have been spawned is not 2")
diff --git a/locust/test/test_zmqrpc.py b/locust/test/test_zmqrpc.py
new file mode 100644
index 0000000000..de077e70e5
--- /dev/null
+++ b/locust/test/test_zmqrpc.py
@@ -0,0 +1,32 @@
+import unittest
+from time import sleep
+import zmq
+from locust.rpc import zmqrpc, Message
+
+PORT = 5557
+
+class ZMQRPC_tests(unittest.TestCase):
+    def setUp(self):
+        self.server = zmqrpc.Server('*', PORT)
+        self.client = zmqrpc.Client('localhost', PORT, 'identity')
+
+    def tearDown(self):
+        self.server.socket.close()
+        self.client.socket.close()
+
+    def test_client_send(self):
+        self.client.send(Message('test', 'message', 'identity'))
+        addr, msg = self.server.recv_from_client()
+        self.assertEqual(addr, b'identity')
+        self.assertEqual(msg.type, 'test')
+        self.assertEqual(msg.data, 'message')
+
+    def test_client_recv(self):
+        sleep(0.01)
+        # We have to wait for the client to finish connecting
+        # before sending a msg to it.
+        self.server.send_to_client(Message('test', 'message', 'identity'))
+        msg = self.client.recv()
+        self.assertEqual(msg.type, 'test')
+        self.assertEqual(msg.data, 'message')
+        self.assertEqual(msg.node_id, 'identity')
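As referenced in the commit message, here is a minimal, framework-free sketch
of the liveness bookkeeping the master performs. It is illustrative only: the
names Slave, on_heartbeat, and tick are hypothetical, and the real logic lives
in MasterLocustRunner.heartbeat_worker and client_listener in the diff above
(which also run under gevent rather than a plain loop).

    # Liveness counter sketch: each heartbeat refills the counter, each
    # interval tick drains it, and a drained counter marks the slave missing.
    HEARTBEAT_LIVENESS = 3          # mirrors --heartbeat-liveness
    STATE_MISSING = "missing"

    class Slave:
        def __init__(self, node_id, liveness=HEARTBEAT_LIVENESS):
            self.node_id = node_id
            self.state = "ready"
            self.heartbeat = liveness   # missed intervals allowed before "missing"

    def on_heartbeat(slave, reported_state, liveness=HEARTBEAT_LIVENESS):
        # A heartbeat resets the counter and syncs the slave's self-reported state.
        slave.heartbeat = liveness
        slave.state = reported_state

    def tick(slaves):
        # Called once per --heartbeat-interval on the master side.
        for slave in slaves:
            if slave.heartbeat < 0 and slave.state != STATE_MISSING:
                slave.state = STATE_MISSING
            else:
                slave.heartbeat -= 1

    if __name__ == "__main__":
        s = Slave("slave-1")
        for _ in range(HEARTBEAT_LIVENESS + 2):   # no heartbeats arrive
            tick([s])
        print(s.state)                            # -> "missing"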