Skip to content

Commit

Permalink
Merge pull request #1280 from delulu/ensureconnection
Browse files Browse the repository at this point in the history
ensure the connection between master and slave in heartbeat
  • Loading branch information
cyberw authored Apr 5, 2020
2 parents 2a0a6ef + 0f9070b commit 8685a4b
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 34 deletions.
7 changes: 7 additions & 0 deletions locust/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ class RescheduleTaskImmediately(Exception):
"""
When raised in a Locust task, another locust task will be rescheduled immediately
"""

class RPCError(Exception):
    """
    Exception that signals a bad or broken network connection.

    When raised from zmqrpc, the RPC connection should be re-established.
    """
47 changes: 34 additions & 13 deletions locust/rpc/zmqrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import zmq.green as zmq

from .protocol import Message
from locust.util.exception_handler import retry
from locust.exception import RPCError
import zmq.error as zmqerr
import msgpack.exceptions as msgerr

class BaseSocket(object):
def __init__(self, sock_type):
Expand All @@ -13,37 +15,56 @@ def __init__(self, sock_type):

@retry()
def send(self, msg):
self.socket.send(msg.serialize())
try:
self.socket.send(msg.serialize(), zmq.NOBLOCK)
except zmqerr.ZMQError as e:
raise RPCError("ZMQ sent failure") from e

@retry()
def send_to_client(self, msg):
self.socket.send_multipart([msg.node_id.encode(), msg.serialize()])
try:
self.socket.send_multipart([msg.node_id.encode(), msg.serialize()])
except zmqerr.ZMQError as e:
raise RPCError("ZMQ sent failure") from e

@retry()
def recv(self):
data = self.socket.recv()
msg = Message.unserialize(data)
try:
data = self.socket.recv()
msg = Message.unserialize(data)
except msgerr.ExtraData as e:
raise RPCError("ZMQ interrupted message") from e
except zmqerr.ZMQError as e:
raise RPCError("ZMQ network broken") from e
return msg

@retry()
def recv_from_client(self):
data = self.socket.recv_multipart()
addr = data[0].decode()
msg = Message.unserialize(data[1])
try:
data = self.socket.recv_multipart()
addr = data[0].decode()
msg = Message.unserialize(data[1])
except (UnicodeDecodeError, msgerr.ExtraData) as e:
raise RPCError("ZMQ interrupted message") from e
except zmqerr.ZMQError as e:
raise RPCError("ZMQ network broken") from e
return addr, msg

def close(self):
    """Close the socket held by this object."""
    sock = self.socket
    sock.close()

class Server(BaseSocket):
def __init__(self, host, port):
BaseSocket.__init__(self, zmq.ROUTER)
if port == 0:
self.port = self.socket.bind_to_random_port("tcp://%s" % host)
else:
self.socket.bind("tcp://%s:%i" % (host, port))
self.port = port
try:
self.socket.bind("tcp://%s:%i" % (host, port))
self.port = port
except zmqerr.ZMQError as e:
raise RPCError("Socket bind failure: %s" % (e) )

class Client(BaseSocket):
    """DEALER-type socket that connects to a TCP endpoint under a fixed identity."""

    def __init__(self, host, port, identity):
        """
        Create the socket, tag it with ``identity`` and connect to ``host:port``.
        """
        super().__init__(zmq.DEALER)
        # The identity has to be set before connecting, so keep this order.
        encoded_identity = identity.encode()
        self.socket.setsockopt(zmq.IDENTITY, encoded_identity)
        endpoint = "tcp://%s:%i" % (host, port)
        self.socket.connect(endpoint)

55 changes: 47 additions & 8 deletions locust/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .rpc import Message, rpc
from .stats import RequestStats, setup_distributed_stats_event_listeners

from .exception import RPCError

logger = logging.getLogger(__name__)


Expand All @@ -22,6 +24,7 @@
CPU_MONITOR_INTERVAL = 5.0
HEARTBEAT_INTERVAL = 1
HEARTBEAT_LIVENESS = 3
FALLBACK_INTERVAL = 5


class LocustRunner(object):
Expand Down Expand Up @@ -50,6 +53,7 @@ def on_request_failure(request_type, name, response_time, response_length, excep

self.environment.events.request_success.add_listener(on_request_success)
self.environment.events.request_failure.add_listener(on_request_failure)
self.connection_broken = False

# register listener that resets stats when hatching is complete
def on_hatch_complete(user_count):
Expand Down Expand Up @@ -303,6 +307,8 @@ def __init__(self, *args, master_bind_host, master_bind_port, **kwargs):
super().__init__(*args, **kwargs)
self.worker_cpu_warning_emitted = False
self.target_user_count = None
self.master_bind_host = master_bind_host
self.master_bind_port = master_bind_port

class WorkerNodesDict(dict):
def get_by_state(self, state):
Expand Down Expand Up @@ -408,6 +414,9 @@ def quit(self):
def heartbeat_worker(self):
while True:
gevent.sleep(HEARTBEAT_INTERVAL)
if self.connection_broken:
self.reset_connection()
continue
for client in self.clients.all:
if client.heartbeat < 0 and client.state != STATE_MISSING:
logger.info('Worker %s failed to send heartbeat, setting state to missing.' % str(client.id))
Expand All @@ -416,9 +425,24 @@ def heartbeat_worker(self):
else:
client.heartbeat -= 1

def reset_connection(self):
    """
    Close the current RPC server and bind a new one on the same host and port.

    An RPCError raised while resetting is logged instead of propagated, so the
    heartbeat loop can call this method again on a later iteration.
    """
    logger.info("Reset connection to slave")
    try:
        self.server.close()
        self.server = rpc.Server(self.master_bind_host, self.master_bind_port)
    except RPCError as e:
        # Message typo ("Temporay") corrected; arguments are passed lazily to the logger.
        logger.error("Temporary failure when resetting connection: %s, will retry later.", e)

def client_listener(self):
while True:
client_id, msg = self.server.recv_from_client()
try:
client_id, msg = self.server.recv_from_client()
except RPCError as e:
logger.error("RPCError found when receiving from client: %s" % ( e ) )
self.connection_broken = True
gevent.sleep(FALLBACK_INTERVAL)
continue
self.connection_broken = False
msg.node_id = client_id
if msg.type == "client_ready":
id = msg.node_id
Expand Down Expand Up @@ -471,7 +495,8 @@ class WorkerLocustRunner(DistributedLocustRunner):
def __init__(self, *args, master_host, master_port, **kwargs):
super().__init__(*args, **kwargs)
self.client_id = socket.gethostname() + "_" + uuid4().hex

self.master_host = master_host
self.master_port = master_port
self.client = rpc.Client(master_host, master_port, self.client_id)
self.greenlet.spawn(self.heartbeat).link_exception(callback=self.noop)
self.greenlet.spawn(self.worker).link_exception(callback=self.noop)
Expand Down Expand Up @@ -503,12 +528,28 @@ def on_locust_error(locust_instance, exception, tb):

def heartbeat(self):
while True:
self.client.send(Message('heartbeat', {'state': self.worker_state, 'current_cpu_usage': self.current_cpu_usage}, self.client_id))
try:
self.client.send(Message('heartbeat', {'state': self.worker_state, 'current_cpu_usage': self.current_cpu_usage}, self.client_id))
except RPCError as e:
logger.error("RPCError found when sending heartbeat: %s" % ( e ) )
self.reset_connection()
gevent.sleep(HEARTBEAT_INTERVAL)

def reset_connection(self):
    """
    Close the current RPC client and create a new one for the same master.

    An RPCError raised while resetting is logged instead of propagated; in that
    case ``self.client`` is left pointing at the previous client object.
    """
    logger.info("Reset connection to master")
    try:
        self.client.close()
        fresh_client = rpc.Client(self.master_host, self.master_port, self.client_id)
    except RPCError as err:
        logger.error("Temporary failure when resetting connection: %s, will retry later.", err)
    else:
        self.client = fresh_client

def worker(self):
while True:
msg = self.client.recv()
try:
msg = self.client.recv()
except RPCError as e:
logger.error("RPCError found when receiving from master: %s" % ( e ) )
continue
if msg.type == "hatch":
self.worker_state = STATE_HATCHING
self.client.send(Message("hatching", None, self.client_id))
Expand All @@ -535,10 +576,8 @@ def stats_reporter(self):
while True:
try:
self._send_stats()
except:
logger.error("Connection lost to master server. Aborting...")
break

except RPCError as e:
logger.error("Temporary connection lost to master server: %s, will retry later." % (e))
gevent.sleep(WORKER_REPORT_INTERVAL)

def _send_stats(self):
Expand Down
43 changes: 34 additions & 9 deletions locust/test/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
from locust.main import create_environment
from locust.core import Locust, TaskSet, task
from locust.env import Environment
from locust.exception import LocustError, StopLocust
from locust.exception import LocustError, RPCError, StopLocust
from locust.rpc import Message
from locust.runners import LocustRunner, LocalLocustRunner, MasterLocustRunner, WorkerNode, \
WorkerLocustRunner, STATE_INIT, STATE_HATCHING, STATE_RUNNING, STATE_MISSING
from locust.stats import RequestStats
from locust.test.testcases import LocustTestCase
from locust.wait_time import between, constant

NETWORK_BROKEN = "network broken"
UNHANDLED_EXCEPTION = "unhandled exception"

def mocked_rpc():
class MockedRpcServerClient(object):
Expand All @@ -30,10 +32,15 @@ def __init__(self, *args, **kwargs):
def mocked_send(cls, message):
cls.queue.put(message.serialize())
sleep(0)

def recv(self):
results = self.queue.get()
return Message.unserialize(results)
msg = Message.unserialize(results)
if msg.data == NETWORK_BROKEN:
raise RPCError()
if msg.data == UNHANDLED_EXCEPTION:
raise HeyAnException()
return msg

def send(self, message):
self.outbox.append(message)
Expand All @@ -44,8 +51,15 @@ def send_to_client(self, message):
def recv_from_client(self):
results = self.queue.get()
msg = Message.unserialize(results)
if msg.data == NETWORK_BROKEN:
raise RPCError()
if msg.data == UNHANDLED_EXCEPTION:
raise HeyAnException()
return msg.node_id, msg

def close(self):
raise RPCError()

return MockedRpcServerClient


Expand All @@ -62,10 +76,13 @@ def __init__(self):
self.heartbeat_interval = 1
self.stop_timeout = None
self.step_load = True
self.connection_broken = False

def reset_stats(self):
pass

class HeyAnException(Exception):
    """Exception raised by the tests and RPC mocks to simulate an unhandled error."""

class TestLocustRunner(LocustTestCase):
def assert_locust_class_distribution(self, expected_distribution, classes):
Expand Down Expand Up @@ -291,6 +308,7 @@ def setUp(self):
super(TestMasterRunner, self).setUp()
#self._worker_report_event_handlers = [h for h in events.worker_report._handlers]
self.environment.options = mocked_options()

class MyTestLocust(Locust):
pass

Expand Down Expand Up @@ -595,9 +613,6 @@ def test_spawn_locusts_in_stepload_mode(self):
self.assertEqual(10, num_clients, "Total number of locusts that would have been spawned for second step is not 10")

def test_exception_in_task(self):
class HeyAnException(Exception):
pass

class MyLocust(Locust):
@task
def will_error(self):
Expand All @@ -619,8 +634,6 @@ def will_error(self):

def test_exception_is_catched(self):
""" Test that exceptions are stored, and execution continues """
class HeyAnException(Exception):
pass

class MyTaskSet(TaskSet):
def __init__(self, *a, **kw):
Expand Down Expand Up @@ -658,6 +671,19 @@ class MyLocust(Locust):
self.assertTrue("HeyAnException" in exception["traceback"])
self.assertEqual(2, exception["count"])

def test_master_reset_connection(self):
    """ Test that connection will be reset when network issues found """
    with mock.patch("locust.rpc.rpc.Server", mocked_rpc()) as server:
        master = self.get_runner()

        # The mock raises RPCError for this payload: the master must flag the link as broken.
        server.mocked_send(Message("client_ready", NETWORK_BROKEN, "fake_client"))
        sleep(3)
        self.assertTrue(master.connection_broken)

        # A message that is received normally clears the flag again.
        server.mocked_send(Message("client_ready", None, "fake_client"))
        sleep(3)
        self.assertFalse(master.connection_broken)

        # The mock raises a non-RPC exception here: it must not be reported as a broken link.
        server.mocked_send(Message("client_ready", UNHANDLED_EXCEPTION, "fake_client"))
        sleep(3)
        self.assertFalse(master.connection_broken)

class TestWorkerLocustRunner(LocustTestCase):
def setUp(self):
Expand Down Expand Up @@ -777,7 +803,6 @@ def my_task(self):
self.assertEqual(9, len(worker.locusts))
worker.quit()


class TestMessageSerializing(unittest.TestCase):
def test_message_serialize(self):
msg = Message("client_ready", None, "my_id")
Expand Down
18 changes: 14 additions & 4 deletions locust/test/test_zmqrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import zmq
from locust.rpc import zmqrpc, Message
from locust.test.testcases import LocustTestCase

from locust.exception import RPCError

class ZMQRPC_tests(LocustTestCase):
def setUp(self):
Expand All @@ -12,8 +12,8 @@ def setUp(self):
self.client = zmqrpc.Client('localhost', self.server.port, 'identity')

def tearDown(self):
self.server.socket.close()
self.client.socket.close()
self.server.close()
self.client.close()
super(ZMQRPC_tests, self).tearDown()

def test_constructor(self):
Expand Down Expand Up @@ -42,5 +42,15 @@ def test_client_recv(self):
def test_client_retry(self):
server = zmqrpc.Server('127.0.0.1', 0)
server.socket.close()
with self.assertRaises(zmq.error.ZMQError):
with self.assertRaises(RPCError):
server.recv_from_client()

def test_rpc_error(self):
    """ Test that bind and send failures surface as RPCError """
    # Port 0 lets the server pick a free port, so the test does not depend on
    # a fixed port being unused on the machine running it.
    server = zmqrpc.Server('127.0.0.1', 0)
    try:
        # Binding a second server to the port already in use must fail.
        with self.assertRaises(RPCError):
            zmqrpc.Server('127.0.0.1', server.port)
    finally:
        # Always release the port, even when the assertion above fails.
        server.close()
    # Sending through the closed socket must fail with RPCError as well.
    with self.assertRaises(RPCError):
        server.send_to_client(Message('test', 'message', 'identity'))


0 comments on commit 8685a4b

Please sign in to comment.