diff --git a/locust/contrib/fasthttp.py b/locust/contrib/fasthttp.py index e80b288fbe..449486f9d4 100644 --- a/locust/contrib/fasthttp.py +++ b/locust/contrib/fasthttp.py @@ -105,17 +105,29 @@ def __init__(self, environment): ) +def insecure_ssl_context_factory(): + context = gevent.ssl.create_default_context() + context.check_hostname = False + context.verify_mode = gevent.ssl.CERT_NONE + return context + + class FastHttpSession(object): auth_header = None - def __init__(self, environment: Environment, base_url: str, **kwargs): + def __init__(self, environment: Environment, base_url: str, insecure=True, **kwargs): self.environment = environment self.base_url = base_url self.cookiejar = CookieJar() + if insecure: + ssl_context_factory = insecure_ssl_context_factory + else: + ssl_context_factory = gevent.ssl.create_default_context self.client = LocustUserAgent( cookiejar=self.cookiejar, - ssl_options={"cert_reqs": gevent.ssl.CERT_NONE}, - **kwargs + ssl_context_factory=ssl_context_factory, + insecure=insecure, + **kwargs, ) # Check for basic authentication diff --git a/locust/test/test_fasthttp.py b/locust/test/test_fasthttp.py index b6ef617eed..2374bf014a 100644 --- a/locust/test/test_fasthttp.py +++ b/locust/test/test_fasthttp.py @@ -1,11 +1,13 @@ import socket import gevent +from tempfile import NamedTemporaryFile from locust.user import task, TaskSet from locust.contrib.fasthttp import FastHttpSession, FastHttpUser from locust.exception import CatchResponseError, InterruptTaskSet, ResponseError from locust.main import is_user_class -from .testcases import WebserverTestCase +from .testcases import WebserverTestCase, LocustTestCase +from .util import create_tls_cert class TestFastHttpSession(WebserverTestCase): @@ -493,3 +495,34 @@ class MyUser(FastHttpUser): r.failure("Manual fail") self.assertEqual(0, self.num_success) self.assertEqual(1, self.num_failures) + + +class TestFastHttpSsl(LocustTestCase): + def setUp(self): + super().setUp() + tls_cert, tls_key = create_tls_cert("127.0.0.1") + self.tls_cert_file = NamedTemporaryFile() + self.tls_key_file = NamedTemporaryFile() + with open(self.tls_cert_file.name, 'w') as f: + f.write(tls_cert.decode()) + with open(self.tls_key_file.name, 'w') as f: + f.write(tls_key.decode()) + + self.web_ui = self.environment.create_web_ui( + "127.0.0.1", 0, + tls_cert=self.tls_cert_file.name, + tls_key=self.tls_key_file.name, + ) + gevent.sleep(0.01) + self.web_port = self.web_ui.server.server_port + + def tearDown(self): + super().tearDown() + self.web_ui.stop() + + def test_ssl_request_insecure(self): + s = FastHttpSession(self.environment, "https://127.0.0.1:%i" % self.web_port, insecure=True) + r = s.get("/") + self.assertEqual(200, r.status_code) + self.assertIn("Locust", r.content.decode("utf-8")) +