From 44d72431e38c20102c5d2162e5a4d3d6e9618835 Mon Sep 17 00:00:00 2001 From: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> Date: Tue, 23 Nov 2021 15:02:00 +0530 Subject: [PATCH] Async `get_events`, `handle_event`, `handle_readables`, `handle_writables` (#769) * Asynchronous `handle_event` and `LocalExecutor` thread * Bail out on first task completion * mypy * Add `helper/benchmark.sh` and fix threaded mode, which must now use asyncio (reduced performance of threaded) * Print open file diff from `benchmark.sh` * Add `--local-executor` flag, disabled by default for now until tests are updated * Async `handle_readables` and `handle_writables` for `HttpProtocolHandlerPlugin` interface (doesn't impact proxy/web plugins for now) * Async `get_events` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address tests after async changes * mypy and flake8 * spelldoc * `check.py` and trailing comma * Rename to `_assertions.py` * Add missing `pytest-mock` and `pytest-asyncio` deps * Add `pytest-mock` to `pylint` deps * Correct use of `parameterize` and add `PT007` to flake8 ignores * Fix mypy hints broken for `< Python3.9` * Remove usage of `asynccontextmanager`, which is not available for all Python versions that `proxy.py` supports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix for pre-Python-3.9 versions * `AsyncTask` APIs `set_name` and `get_name` are not available on all supported versions * Install setuptools via `lib-dep` until we recommend editable install * Deprecate support for `Python 3.6` * Use recommendation suggested here https://github.com/abhinavsingh/proxy.py/pull/769\#discussion_r753840929 * Address recommendation here https://github.com/abhinavsingh/proxy.py/pull/769\#discussion_r753841906 * Make `Threadless` agnostic of `multiprocessing.Process` * Acceptors must dispatch to local executor in non-blocking fashion * No daemon for executor processes and fix shutdown logic * Only return fds from `_selected_events`, not all events data * Refactor logic * Prefix private methods with `_` * `work_queue` and not `client_queue` * Turn `Threadless` into an abstract executor. Introduce `RemoteExecutor` * Make `LocalExecutor` agnostic of `threading.Thread` * `LocalExecutor` now implements `Threadless` * `get_events` and `get_descriptors` must now return int and not sock. `Threadless` now avoids repeated register/unregister and instead makes use of `selectors.modify` * Fix `main` tests * Apply suggestions from code review Co-authored-by: Sviatoslav Sydorenko * Apply code review recommendations manually * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert back to `Any` and use `addr or None` * Address `flake8` * Update tests to use `fileno` * Fix doc build * Fix doc spell, use "tear down" and not "teardown" * Doc updates * Add back support for `Python 3.6` * Acceptors don't need loop initialization * On Python 3.6 `asyncio.new_event_loop()` is necessary * Make doc happy * `--threaded` needs a new event loop for 3.7 too * Always use `asyncio.new_event_loop()` for threaded mode. Added e2e integration tests (subprocess & curl) for all modes.
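As an illustration of the modes touched by this change, the new local executor mode can be exercised end-to-end roughly like this (a sketch assembled from flags that appear in this patch's `Makefile` and `helper/benchmark.sh`; the limits and file paths are example values, not requirements):

```console
❯ python -m proxy \
    --enable-web-server \
    --plugin proxy.plugin.WebServerPlugin \
    --local-executor \
    --backlog 65536 \
    --open-file-limit 65536 \
    --pid-file /tmp/proxy.pid \
    --log-file /dev/null
❯ ./helper/benchmark.sh
```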
* Lint fixes Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sviatoslav Sydorenko --- .vscode/settings.json | 1 + Makefile | 11 +- README.md | 102 +++-- check.py | 24 +- docs/conf.py | 14 +- examples/web_scraper.py | 5 +- helper/benchmark.sh | 77 ++++ helper/monitor_open_files.sh | 2 +- proxy/common/_compat.py | 18 +- proxy/common/backports.py | 82 ++++ proxy/common/constants.py | 1 + proxy/common/flag.py | 10 +- proxy/common/utils.py | 69 --- proxy/core/acceptor/__init__.py | 4 + proxy/core/acceptor/acceptor.py | 118 +++-- proxy/core/acceptor/executors.py | 33 +- proxy/core/acceptor/local.py | 56 +++ proxy/core/acceptor/remote.py | 62 +++ proxy/core/acceptor/threadless.py | 354 +++++++++------ proxy/core/acceptor/work.py | 5 +- proxy/core/base/tcp_server.py | 38 +- proxy/core/base/tcp_tunnel.py | 21 +- proxy/core/base/tcp_upstream.py | 26 +- proxy/http/handler.py | 167 ++++--- proxy/http/plugin.py | 8 +- proxy/http/proxy/plugin.py | 5 +- proxy/http/proxy/server.py | 29 +- proxy/http/server/plugin.py | 5 +- proxy/http/server/web.py | 14 +- proxy/http/websocket/client.py | 3 +- proxy/proxy.py | 40 +- requirements-testing.txt | 2 + .../exceptions/test_http_proxy_auth_failed.py | 79 ++-- tests/http/test_http_proxy.py | 60 ++- .../http/test_http_proxy_tls_interception.py | 68 ++- tests/http/test_protocol_handler.py | 308 +++++++------ tests/http/test_web_server.py | 417 ++++++++++-------- tests/integration/test_integration.py | 47 +- tests/plugin/test_http_proxy_plugins.py | 184 ++++---- ...ttp_proxy_plugins_with_tls_interception.py | 166 ++++--- tests/test_assertions.py | 26 ++ tests/test_main.py | 8 +- tox.ini | 1 + 43 files changed, 1675 insertions(+), 1095 deletions(-) create mode 100755 helper/benchmark.sh create mode 100644 proxy/common/backports.py create mode 100644 proxy/core/acceptor/local.py create mode 100644 proxy/core/acceptor/remote.py create mode 100644 tests/test_assertions.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 05ba4c5890..e956ff2ea9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -39,6 +39,7 @@ "python.linting.pylintEnabled": true, "python.linting.pylintArgs": ["--generate-members"], "python.linting.flake8Enabled": true, + "python.linting.flake8Args": ["--config", ".flake8"], "python.linting.mypyEnabled": true, "python.formatting.provider": "autopep8", "autoDocstring.docstringFormat": "sphinx" diff --git a/Makefile b/Makefile index f04ce9e26f..de614a3304 100644 --- a/Makefile +++ b/Makefile @@ -27,9 +27,7 @@ endif .PHONY: container container-run container-release .PHONY: devtools dashboard dashboard-clean -all: - echo $(IMAGE_TAG) - # lib-test +all: lib-test https-certificates: # Generate server key @@ -94,7 +92,8 @@ lib-dep: -r requirements.txt \ -r requirements-testing.txt \ -r requirements-release.txt \ - -r requirements-tunnel.txt \ + -r requirements-tunnel.txt && \ + pip install "setuptools>=42" lib-lint: python -m tox -e lint @@ -128,6 +127,7 @@ lib-coverage: $(OPEN) htmlcov/index.html lib-profile: + ulimit -n 65536 && \ sudo py-spy record \ -o profile.svg \ -t -F -s -- \ @@ -137,6 +137,9 @@ lib-profile: --disable-http-proxy \ --enable-web-server \ --plugin proxy.plugin.WebServerPlugin \ + --local-executor \ + --backlog 65536 \ + --open-file-limit 65536 --log-file /dev/null devtools: diff --git a/README.md b/README.md index 6da3423661..abcb665e31 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ - [Setup Local Environment](#setup-local-environment) - [Setup 
Git Hooks](#setup-git-hooks) - [Sending a Pull Request](#sending-a-pull-request) +- [Benchmarks](#benchmarks) - [Flags](#flags) - [Changelog](#changelog) - [v2.x](#v2x) @@ -126,36 +127,56 @@ ```console # On Macbook Pro 2019 / 2.4 GHz 8-Core Intel Core i9 / 32 GB RAM - ❯ hey -n 10000 -c 100 http://localhost:8899/http-route-example - - Summary: - Total: 0.3248 secs - Slowest: 0.1007 secs - Fastest: 0.0002 secs - Average: 0.0028 secs - Requests/sec: 30784.7958 - - Total data: 190000 bytes - Size/request: 19 bytes - - Response time histogram: - 0.000 [1] | - 0.010 [9533] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ - 0.020 [384] |■■ - - Latency distribution: - 10% in 0.0004 secs - 25% in 0.0007 secs - 50% in 0.0013 secs - 75% in 0.0029 secs - 90% in 0.0057 secs - 95% in 0.0097 secs - 99% in 0.0185 secs - - Status code distribution: - [200] 10000 responses + ❯ ./helper/benchmark.sh + CONCURRENCY: 100 workers, TOTAL REQUESTS: 100000 req, QPS: 5000 req/sec, TIMEOUT: 1 sec + + Summary: + Total: 3.1560 secs + Slowest: 0.0375 secs + Fastest: 0.0006 secs + Average: 0.0031 secs + Requests/sec: 31685.9140 + + Total data: 1900000 bytes + Size/request: 19 bytes + + Response time histogram: + 0.001 [1] | + 0.004 [91680] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ + 0.008 [7929] |■■■ + 0.012 [263] | + 0.015 [29] | + 0.019 [8] | + 0.023 [23] | + 0.026 [15] | + 0.030 [27] | + 0.034 [16] | + 0.037 [9] | + + + Latency distribution: + 10% in 0.0022 secs + 25% in 0.0025 secs + 50% in 0.0029 secs + 75% in 0.0034 secs + 90% in 0.0041 secs + 95% in 0.0048 secs + 99% in 0.0066 secs + + Details (average, fastest, slowest): + DNS+dialup: 0.0000 secs, 0.0006 secs, 0.0375 secs + DNS-lookup: 0.0000 secs, 0.0000 secs, 0.0000 secs + req write: 0.0000 secs, 0.0000 secs, 0.0046 secs + resp wait: 0.0030 secs, 0.0006 secs, 0.0320 secs + resp read: 0.0000 secs, 0.0000 secs, 0.0029 secs + + Status code distribution: + [200] 100000 responses ``` + PS: `proxy.py` and benchmark tools are running on the same machine during the above load test. + Checkout the repo and try it for yourself. See [Benchmarks](#benchmarks) for more details. + - Lightweight - Uses only `~5-20MB` RAM - No external dependency other than standard Python library @@ -1977,13 +1998,21 @@ Every pull request is tested using GitHub actions. See [GitHub workflow](https://github.com/abhinavsingh/proxy.py/tree/develop/.github/workflows) for list of tests. +# Benchmarks + +Simply run the following command from repo root to start benchmark + +```console +❯ ./helper/benchmark.sh +``` + # Flags ```console ❯ proxy -h usage: -m [-h] [--enable-events] [--enable-conn-pool] [--threadless] - [--threaded] [--num-workers NUM_WORKERS] [--backlog BACKLOG] - [--hostname HOSTNAME] [--port PORT] + [--threaded] [--num-workers NUM_WORKERS] [--local-executor] + [--backlog BACKLOG] [--hostname HOSTNAME] [--port PORT] [--unix-socket-path UNIX_SOCKET_PATH] [--num-acceptors NUM_ACCEPTORS] [--version] [--log-level LOG_LEVEL] [--log-file LOG_FILE] [--log-format LOG_FORMAT] @@ -2009,7 +2038,7 @@ usage: -m [-h] [--enable-events] [--enable-conn-pool] [--threadless] [--filtered-url-regex-config FILTERED_URL_REGEX_CONFIG] [--cloudflare-dns-mode CLOUDFLARE_DNS_MODE] -proxy.py v2.3.2 +proxy.py v2.3.2.dev193+g87ff921.d20211121 options: -h, --help show this help message and exit @@ -2026,6 +2055,13 @@ options: handle each client connection. --num-workers NUM_WORKERS Defaults to number of CPU cores. + --local-executor Default: False. Disabled by default. 
When enabled + acceptors will make use of local (same process) + executor instead of distributing load across remote + (other process) executors. Enable this option to + achieve CPU affinity between acceptors and executors, + instead of using underlying OS kernel scheduling + algorithm. --backlog BACKLOG Default: 100. Maximum number of pending connections to proxy server --hostname HOSTNAME Default: ::1. Server IP address. @@ -2155,6 +2191,10 @@ https://github.com/abhinavsingh/proxy.py/issues/new # Changelog +## v2.4.0 + +- No longer support `Python 3.6` due to `asyncio.run` usage in the core. + ## v2.x - No longer ~~a single file module~~. diff --git a/check.py b/check.py index 71bb51cf63..038758a30f 100644 --- a/check.py +++ b/check.py @@ -53,16 +53,18 @@ sys.exit(1) # Update README.md flags section to match current library --help output -# lib_help = subprocess.check_output( -# ['python', '-m', 'proxy', '-h'] -# ) -# with open('README.md', 'rb+') as f: -# c = f.read() -# pre_flags, post_flags = c.split(b'# Flags') -# help_text, post_changelog = post_flags.split(b'# Changelog') -# f.seek(0) -# f.write(pre_flags + b'# Flags\n\n```console\n\xe2\x9d\xaf proxy -h\n' + lib_help + b'```' + -# b'\n# Changelog' + post_changelog) +lib_help = subprocess.check_output( + ['python', '-m', 'proxy', '-h'], +) +with open('README.md', 'rb+') as f: + c = f.read() + pre_flags, post_flags = c.split(b'# Flags') + help_text, post_changelog = post_flags.split(b'# Changelog') + f.seek(0) + f.write( + pre_flags + b'# Flags\n\n```console\n\xe2\x9d\xaf proxy -h\n' + lib_help + b'```' + + b'\n\n# Changelog' + post_changelog, + ) # Version is also hardcoded in README.md flags section readme_version_cmd = 'cat README.md | grep "proxy.py v" | tail -2 | head -1 | cut -d " " -f 2 | cut -c2-' @@ -72,7 +74,7 @@ # Doesn't contain "v" prefix readme_version = readme_version_output.decode().strip() -if readme_version != lib_version[1:].split('-')[0]: +if readme_version != lib_version: print( 'Version mismatch found. 
{0} (readme) vs {1} (lib).'.format( readme_version, lib_version, diff --git a/docs/conf.py b/docs/conf.py index 20e36ac5ca..68fb0974c3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -241,11 +241,11 @@ nitpicky = True _any_role = 'any' +_py_obj_role = 'py:obj' _py_class_role = 'py:class' nitpick_ignore = [ (_any_role, ''), (_any_role, '__init__'), - (_any_role, '--threadless'), (_any_role, 'Client'), (_any_role, 'event_queue'), (_any_role, 'fd_queue'), @@ -256,8 +256,10 @@ (_any_role, 'HttpParser.state'), (_any_role, 'HttpProtocolHandler'), (_any_role, 'multiprocessing.Manager'), - (_any_role, 'work_klass'), (_any_role, 'proxy.core.base.tcp_upstream.TcpUpstreamConnectionHandler'), + (_any_role, 'work_klass'), + (_py_class_role, '_asyncio.Task'), + (_py_class_role, 'asyncio.events.AbstractEventLoop'), (_py_class_role, 'CacheStore'), (_py_class_role, 'HttpParser'), (_py_class_role, 'HttpProtocolHandlerPlugin'), @@ -268,11 +270,17 @@ (_py_class_role, 'paramiko.channel.Channel'), (_py_class_role, 'proxy.http.parser.parser.T'), (_py_class_role, 'proxy.plugin.cache.store.base.CacheStore'), + (_py_class_role, 'proxy.core.pool.AcceptorPool'), + (_py_class_role, 'proxy.core.executors.ThreadlessPool'), + (_py_class_role, 'proxy.core.acceptor.threadless.T'), + (_py_class_role, 'queue.Queue[Any]'), (_py_class_role, 'TcpClientConnection'), (_py_class_role, 'TcpServerConnection'), (_py_class_role, 'unittest.case.TestCase'), (_py_class_role, 'unittest.result.TestResult'), (_py_class_role, 'UUID'), - (_py_class_role, 'WebsocketFrame'), (_py_class_role, 'Url'), + (_py_class_role, 'WebsocketFrame'), + (_py_class_role, 'Work'), + (_py_obj_role, 'proxy.core.acceptor.threadless.T'), ] diff --git a/examples/web_scraper.py b/examples/web_scraper.py index 0cd638ade7..4b925876c5 100644 --- a/examples/web_scraper.py +++ b/examples/web_scraper.py @@ -9,7 +9,6 @@ :license: BSD, see LICENSE for more details. """ import time -import socket from typing import Dict @@ -40,11 +39,11 @@ class WebScraper(Work): only PUBSUB protocol. """ - def get_events(self) -> Dict[socket.socket, int]: + async def get_events(self) -> Dict[int, int]: """Return sockets and events (read or write) that we are interested in.""" return {} - def handle_events( + async def handle_events( self, readables: Readables, writables: Writables, diff --git a/helper/benchmark.sh b/helper/benchmark.sh new file mode 100755 index 0000000000..653605c43c --- /dev/null +++ b/helper/benchmark.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# +# proxy.py +# ~~~~~~~~ +# ⚡⚡⚡ Fast, Lightweight, Programmable, TLS interception capable +# proxy server for Application debugging, testing and development. +# +# :copyright: (c) 2013-present by Abhinav Singh and contributors. +# :license: BSD, see LICENSE for more details. +# +usage() { + echo "Usage: ./helper/benchmark.sh" + echo "You must run this script from proxy.py repo root." 
+} + +DIRNAME=$(dirname "$0") +if [ "$DIRNAME" != "./helper" ]; then + usage + exit 1 +fi + +BASENAME=$(basename "$0") +if [ "$BASENAME" != "benchmark.sh" ]; then + usage + exit 1 +fi + +PWD=$(pwd) +if [ $(basename $PWD) != "proxy.py" ]; then + usage + exit 1 +fi + +TIMEOUT=1 +QPS=20000 +CONCURRENCY=100 +TOTAL_REQUESTS=100000 +OPEN_FILE_LIMIT=65536 +BACKLOG=$OPEN_FILE_LIMIT +PID_FILE=/tmp/proxy.pid + +ulimit -n $OPEN_FILE_LIMIT + +# time python -m \ +# proxy \ +# --enable-web-server \ +# --plugin proxy.plugin.WebServerPlugin \ +# --backlog $BACKLOG \ +# --open-file-limit $OPEN_FILE_LIMIT \ +# --pid-file $PID_FILE \ +# --log-file /dev/null + +PID=$(cat $PID_FILE) +if [[ -z "$PID" ]]; then + echo "Either pid file doesn't exist or no pid found in the pid file" + exit 1 +fi +ADDR=$(lsof -Pan -p $PID -i | grep -v COMMAND | awk '{ print $9 }') + +PRE_RUN_OPEN_FILES=$(./helper/monitor_open_files.sh) + +echo "CONCURRENCY: $CONCURRENCY workers, TOTAL REQUESTS: $TOTAL_REQUESTS req, QPS: $QPS req/sec, TIMEOUT: $TIMEOUT sec" +hey \ + -n $TOTAL_REQUESTS \ + -c $CONCURRENCY \ + -q $QPS \ + -t $TIMEOUT \ + http://$ADDR/http-route-example + +POST_RUN_OPEN_FILES=$(./helper/monitor_open_files.sh) + +echo $output + +echo "Open files diff:" +diff <( echo "$PRE_RUN_OPEN_FILES" ) <( echo "$POST_RUN_OPEN_FILES" ) + +# while true; do netstat -ant | grep .8899 | awk '{print $6}' | sort | uniq -c | sort -n; sleep 1; done diff --git a/helper/monitor_open_files.sh b/helper/monitor_open_files.sh index 7bfa48c631..7a8caa0eb1 100755 --- a/helper/monitor_open_files.sh +++ b/helper/monitor_open_files.sh @@ -1,5 +1,5 @@ #!/bin/bash - +# # proxy.py # ~~~~~~~~ # ⚡⚡⚡ Fast, Lightweight, Programmable, TLS interception capable diff --git a/proxy/common/_compat.py b/proxy/common/_compat.py index eaaddfb9d5..c3ec75e411 100644 --- a/proxy/common/_compat.py +++ b/proxy/common/_compat.py @@ -1,9 +1,19 @@ -"""Compatibility code for using Proxy.py across various versions of Python. +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + Compatibility code for using Proxy.py across various versions of Python. -.. spelling:: + .. spelling:: - compat - py + compat + py """ import platform diff --git a/proxy/common/backports.py b/proxy/common/backports.py new file mode 100644 index 0000000000..4734287073 --- /dev/null +++ b/proxy/common/backports.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time + +from typing import Any + + +class cached_property: + """Decorator for read-only properties evaluated only once within TTL period. + It can be used to create a cached property like this:: + + import random + + # the class containing the property must be a new-style class + class MyClass: + # create property whose value is cached for ten minutes + @cached_property(ttl=600) + def randint(self): + # will only be evaluated every 10 min. at maximum.
+ return random.randint(0, 100) + + The value is cached in the '_cached_properties' attribute of the object instance that + has the property getter method wrapped by this decorator. The '_cached_properties' + attribute value is a dictionary which has a key for every property of the + object which is wrapped by this decorator. Each entry in the cache is + created only when the property is accessed for the first time and is a + two-element tuple with the last computed property value and the last time + it was updated in seconds since the epoch. + + The default time-to-live (TTL) is 300 seconds (5 minutes). Set the TTL to + zero for the cached value to never expire. + + To expire a cached property value manually just do:: + del instance._cached_properties[<property name>] + + Adapted from https://wiki.python.org/moin/PythonDecoratorLibrary#Cached_Properties + © 2011 Christopher Arndt, MIT License. + + NOTE: We need this class only because Python's in-built `cached_property` + is only available for 3.8+. Hence, we must get rid of this class once + proxy.py no longer supports versions older than 3.8. + + .. spelling:: + + backports + getter + Arndt + """ + + def __init__(self, ttl: float = 300.0): + self.ttl = ttl + + def __call__(self, fget: Any, doc: Any = None) -> 'cached_property': + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + self.__module__ = fget.__module__ + return self + + def __get__(self, inst: Any, owner: Any) -> Any: + now = time.time() + try: + value, last_update = inst._cached_properties[self.__name__] + if self.ttl > 0 and now - last_update > self.ttl: # noqa: WPS333 + raise AttributeError + except (KeyError, AttributeError): + value = self.fget(inst) + try: + cache = inst._cached_properties + except AttributeError: + cache, inst._cached_properties = {}, {} + finally: + cache[self.__name__] = (value, now) + return value diff --git a/proxy/common/constants.py b/proxy/common/constants.py index bf5773f3f5..e679a2e8ef 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -100,6 +100,7 @@ def _env_threadless_compliant() -> bool: DEFAULT_STATIC_SERVER_DIR = os.path.join(PROXY_PY_DIR, "public") DEFAULT_MIN_COMPRESSION_LIMIT = 20 # In bytes DEFAULT_THREADLESS = _env_threadless_compliant() +DEFAULT_LOCAL_EXECUTOR = False DEFAULT_TIMEOUT = 10.0 DEFAULT_VERSION = False DEFAULT_HTTP_PORT = 80 diff --git a/proxy/common/flag.py b/proxy/common/flag.py index 1c52438149..d4b7da733d 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -333,7 +333,15 @@ def initialize( ) args.pid_file = cast( Optional[str], opts.get( - 'pid_file', args.pid_file, + 'pid_file', + args.pid_file, + ), + ) + args.local_executor = cast( + bool, + opts.get( + 'local_executor', + args.local_executor, ), ) diff --git a/proxy/common/utils.py b/proxy/common/utils.py index f289fc49b2..b25b680522 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -17,7 +17,6 @@ """ import sys import ssl -import time import socket import logging import functools @@ -291,71 +290,3 @@ def set_open_file_limit(soft_limit: int) -> None: logger.debug( 'Open file soft limit set to %d', soft_limit, ) - - -class cached_property: - """Decorator for read-only properties evaluated only once within TTL period.
- It can be used to create a cached property like this:: - - import random - - # the class containing the property must be a new-style class - class MyClass: - # create property whose value is cached for ten minutes - @cached_property(ttl=600) - def randint(self): - # will only be evaluated every 10 min. at maximum. - return random.randint(0, 100) - - The value is cached in the '_cached_properties' attribute of the object instance that - has the property getter method wrapped by this decorator. The '_cached_properties' - attribute value is a dictionary which has a key for every property of the - object which is wrapped by this decorator. Each entry in the cache is - created only when the property is accessed for the first time and is a - two-element tuple with the last computed property value and the last time - it was updated in seconds since the epoch. - - The default time-to-live (TTL) is 300 seconds (5 minutes). Set the TTL to - zero for the cached value to never expire. - - To expire a cached property value manually just do:: - del instance._cached_properties[] - - Adopted from https://wiki.python.org/moin/PythonDecoratorLibrary#Cached_Properties - © 2011 Christopher Arndt, MIT License. - - NOTE: We need this function only because Python in-built are only available - for 3.8+. Hence, we must get rid of this function once proxy.py no longer - support version older than 3.8. - - .. spelling:: - - getter - Arndt - """ - - def __init__(self, ttl: float = 300.0): - self.ttl = ttl - - def __call__(self, fget: Any, doc: Any = None) -> 'cached_property': - self.fget = fget - self.__doc__ = doc or fget.__doc__ - self.__name__ = fget.__name__ - self.__module__ = fget.__module__ - return self - - def __get__(self, inst: Any, owner: Any) -> Any: - now = time.time() - try: - value, last_update = inst._cached_properties[self.__name__] - if self.ttl > 0 and now - last_update > self.ttl: # noqa: WPS333 - raise AttributeError - except (KeyError, AttributeError): - value = self.fget(inst) - try: - cache = inst._cached_properties - except AttributeError: - cache, inst._cached_properties = {}, {} - finally: - cache[self.__name__] = (value, now) - return value diff --git a/proxy/core/acceptor/__init__.py b/proxy/core/acceptor/__init__.py index 73da204d72..577e2022f5 100644 --- a/proxy/core/acceptor/__init__.py +++ b/proxy/core/acceptor/__init__.py @@ -19,6 +19,8 @@ from .pool import AcceptorPool from .work import Work from .threadless import Threadless +from .remote import RemoteExecutor +from .local import LocalExecutor from .executors import ThreadlessPool from .listener import Listener @@ -27,6 +29,8 @@ 'AcceptorPool', 'Work', 'Threadless', + 'RemoteExecutor', + 'LocalExecutor', 'ThreadlessPool', 'Listener', ] diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 164fdba085..3fa011d6ff 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -13,11 +13,12 @@ acceptor pre """ +import queue import socket import logging import argparse -import threading import selectors +import threading import multiprocessing import multiprocessing.synchronize @@ -25,49 +26,48 @@ from multiprocessing.reduction import recv_handle from typing import List, Optional, Tuple +from typing import Any # noqa: W0611 pylint: disable=unused-import -from proxy.core.acceptor.executors import ThreadlessPool +from ...common.flag import flags +from ...common.utils import is_threadless +from ...common.logger import Logger +from ...common.constants import DEFAULT_LOCAL_EXECUTOR 
from ..event import EventQueue -from ...common.utils import is_threadless -from ...common.logger import Logger +from .local import LocalExecutor +from .executors import ThreadlessPool logger = logging.getLogger(__name__) +flags.add_argument( + '--local-executor', + action='store_true', + default=DEFAULT_LOCAL_EXECUTOR, + help='Default: ' + ('True' if DEFAULT_LOCAL_EXECUTOR else 'False') + '. ' + + 'Disabled by default. When enabled acceptors will make use of ' + + 'local (same process) executor instead of distributing load across ' + + 'remote (other process) executors. Enable this option to achieve CPU affinity between ' + + 'acceptors and executors, instead of using underlying OS kernel scheduling algorithm.', +) + + class Acceptor(multiprocessing.Process): """Work acceptor process. On start-up, `Acceptor` accepts a file descriptor which will be used to - accept new work. File descriptor is accepted over a `fd_queue` which is - closed immediately after receiving the descriptor. + accept new work. File descriptor is accepted over a `fd_queue`. `Acceptor` goes on to listen for new work over the received server socket. By default, `Acceptor` will spawn a new thread to handle each work. - However, when `--threadless` option is enabled, `Acceptor` process - will also pre-spawns a - :class:`~proxy.core.acceptor.threadless.Threadless` process during - start-up. Accepted work is passed to these - :class:`~proxy.core.acceptor.threadless.Threadless` processes. - `Acceptor` process shares accepted work with a - :class:`~proxy.core.acceptor.threadless.Threadless` process over - it's dedicated pipe. - - TODO(abhinavsingh): Open questions:: - - 1. Instead of starting - :class:`~proxy.core.acceptor.threadless.Threadless` process, - can we work with a - :class:`~proxy.core.acceptor.threadless.Threadless` thread? - 2. What are the performance implications of sharing fds between - threads vs processes? - 3. How much performance degradation happens when acceptor and - threadless processes are running on separate CPU cores? - 4. Can we ensure both acceptor and threadless process are pinned to - the same CPU core? - + However, when the ``--threadless`` option is enabled without ``--local-executor``, + `Acceptor` process will also pre-spawn a + :class:`~proxy.core.acceptor.threadless.Threadless` process during start-up. + Accepted work is delegated to these :class:`~proxy.core.acceptor.threadless.Threadless` + processes. `Acceptor` process shares accepted work with a + :class:`~proxy.core.acceptor.threadless.Threadless` process over its dedicated pipe. """ def __init__( @@ -101,8 +101,26 @@ def __init__( # File descriptor used to accept new work # Currently, a socket fd is assumed.
self.sock: Optional[socket.socket] = None - # Incremented every time work() is called - self._total: int = 0 + # Internals + self._total: Optional[int] = None + self._local_work_queue: Optional['queue.Queue[Any]'] = None + self._local: Optional[LocalExecutor] = None + self._lthread: Optional[threading.Thread] = None + + def accept(self, events: List[Tuple[selectors.SelectorKey, int]]) -> None: + for _, mask in events: + if mask & selectors.EVENT_READ: + if self.sock is not None: + conn, addr = self.sock.accept() + logging.debug( + 'Accepting new work#{0}'.format(conn.fileno()), + ) + work = (conn, addr or None) + if self.flags.local_executor: + assert self._local_work_queue + self._local_work_queue.put_nowait(work) + else: + self._work(*work) def run_once(self) -> None: if self.selector is not None: @@ -113,12 +131,7 @@ def run_once(self) -> None: try: if self.lock.acquire(block=False): locked = True - for _, mask in events: - if mask & selectors.EVENT_READ: - if self.sock is not None: - conn, addr = self.sock.accept() - addr = None if addr == '' else addr - self._work(conn, addr) + self.accept(events) except BlockingIOError: pass finally: @@ -142,6 +155,8 @@ def run(self) -> None: type=socket.SOCK_STREAM, ) try: + if self.flags.local_executor: + self._start_local() self.selector.register(self.sock, selectors.EVENT_READ) while not self.running.is_set(): self.run_once() @@ -149,10 +164,30 @@ def run(self) -> None: pass finally: self.selector.unregister(self.sock) + if self.flags.local_executor: + self._stop_local() self.sock.close() logger.debug('Acceptor#%d shutdown', self.idd) + def _start_local(self) -> None: + assert self.sock + self._local_work_queue = queue.Queue() + self._local = LocalExecutor( + work_queue=self._local_work_queue, + flags=self.flags, + event_queue=self.event_queue, + ) + self._lthread = threading.Thread(target=self._local.run) + self._lthread.daemon = True + self._lthread.start() + + def _stop_local(self) -> None: + if self._lthread is not None and self._local_work_queue is not None: + self._local_work_queue.put(False) + self._lthread.join() + def _work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None: + self._total = self._total or 0 if is_threadless(self.flags.threadless, self.flags.threaded): # Index of worker to which this work should be dispatched # Use round-robin strategy by default. 
@@ -173,20 +208,21 @@ def _work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None: ) thread.start() logger.debug( - 'Dispatched work#{0}.{1} to worker#{2}'.format( - self.idd, self._total, index, + 'Dispatched work#{0}.{1}.{2} to worker#{3}'.format( + conn.fileno(), self.idd, self._total, index, ), ) else: _, thread = ThreadlessPool.start_threaded_work( self.flags, - conn, addr, + conn, + addr, event_queue=self.event_queue, publisher_id=self.__class__.__name__, ) logger.debug( - 'Started work#{0}.{1} in thread#{2}'.format( - self.idd, self._total, thread.ident, + 'Started work#{0}.{1}.{2} in thread#{3}'.format( + conn.fileno(), self.idd, self._total, thread.ident, ), ) self._total += 1 diff --git a/proxy/core/acceptor/executors.py b/proxy/core/acceptor/executors.py index 3d40841b96..065e78bded 100644 --- a/proxy/core/acceptor/executors.py +++ b/proxy/core/acceptor/executors.py @@ -24,7 +24,7 @@ from typing import Any, Optional, List, Tuple from .work import Work -from .threadless import Threadless +from .remote import RemoteExecutor from ..connection import TcpClientConnection from ..event import EventQueue, eventNames @@ -76,12 +76,6 @@ class ThreadlessPool: If necessary, start multiple threadless pool with different work classes. - - TODO: We could optimize multiple-work-type scenario - by making Threadless class constructor independent of ``work_klass``. - We could then relay the ``work_klass`` during work delegation. - This will also make ThreadlessPool constructor agnostic - of ``work_klass``. """ def __init__( @@ -96,7 +90,8 @@ def __init__( self.work_pids: List[int] = [] self.work_locks: List[multiprocessing.synchronize.Lock] = [] # List of threadless workers - self._workers: List[Threadless] = [] + self._workers: List[RemoteExecutor] = [] + self._processes: List[multiprocessing.Process] = [] def __enter__(self) -> 'ThreadlessPool': self.setup() @@ -183,24 +178,28 @@ def _start_worker(self, index: int) -> None: self.work_locks.append(multiprocessing.Lock()) pipe = multiprocessing.Pipe() self.work_queues.append(pipe[0]) - w = Threadless( - client_queue=pipe[1], + w = RemoteExecutor( + work_queue=pipe[1], flags=self.flags, event_queue=self.event_queue, ) self._workers.append(w) - w.start() - assert w.pid - self.work_pids.append(w.pid) - logger.debug('Started threadless#%d process#%d', index, w.pid) + p = multiprocessing.Process(target=w.run) + # p.daemon = True + self._processes.append(p) + p.start() + assert p.pid + self.work_pids.append(p.pid) + logger.debug('Started threadless#%d process#%d', index, p.pid) def _shutdown_workers(self) -> None: """Pop a running threadless worker and clean it up.""" for index in range(self.flags.num_workers): self._workers[index].running.set() - for index in range(self.flags.num_workers): - pid = self._workers[index].pid - self._workers[index].join() + for _ in range(self.flags.num_workers): + pid = self.work_pids[-1] + self._processes.pop().join() + self._workers.pop() self.work_pids.pop() self.work_queues.pop().close() logger.debug('Stopped threadless process#%d', pid) diff --git a/proxy/core/acceptor/local.py b/proxy/core/acceptor/local.py new file mode 100644 index 0000000000..632497b541 --- /dev/null +++ b/proxy/core/acceptor/local.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. 
+ :license: BSD, see LICENSE for more details. + + .. spelling:: + + acceptor + teardown +""" +import queue +import logging +import asyncio +import contextlib + +from typing import Optional +from typing import Any # noqa: W0611 pylint: disable=unused-import + +from .threadless import Threadless + +logger = logging.getLogger(__name__) + + +class LocalExecutor(Threadless['queue.Queue[Any]']): + """A threadless executor implementation which uses a queue to receive new work.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + if self._loop is None: + self._loop = asyncio.new_event_loop() + return self._loop + + def work_queue_fileno(self) -> Optional[int]: + return None + + def receive_from_work_queue(self) -> bool: + with contextlib.suppress(queue.Empty): + work = self.work_queue.get(block=False) + if isinstance(work, bool) and work is False: + return True + assert isinstance(work, tuple) + conn, addr = work + # NOTE: Here we are assuming to receive a connection object + # and not a fileno because we are a LocalExecutor. + fileno = conn.fileno() + self.work_on_tcp_conn(fileno=fileno, addr=addr, conn=conn) + return False diff --git a/proxy/core/acceptor/remote.py b/proxy/core/acceptor/remote.py new file mode 100644 index 0000000000..76f8877d21 --- /dev/null +++ b/proxy/core/acceptor/remote.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + acceptor +""" +import asyncio +import logging + +from typing import Optional, Any + +from multiprocessing import connection +from multiprocessing.reduction import recv_handle + +from .threadless import Threadless + +logger = logging.getLogger(__name__) + + +class RemoteExecutor(Threadless[connection.Connection]): + """A threadless executor implementation which receives work over a connection. + + NOTE: RemoteExecutor uses ``recv_handle`` to accept file descriptors. + + TODO: Refactor and abstract ``recv_handle`` part so that a threaded + remote executor can also accept work over a connection. Currently, + remote executors must be running in a process. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + if self._loop is None: + self._loop = asyncio.get_event_loop_policy().get_event_loop() + return self._loop + + def work_queue_fileno(self) -> Optional[int]: + return self.work_queue.fileno() + + def close_work_queue(self) -> None: + self.work_queue.close() + + def receive_from_work_queue(self) -> bool: + # Acceptor will not send address for + # unix socket domain environments. 
+ addr = None + if not self.flags.unix_socket_path: + addr = self.work_queue.recv() + fileno = recv_handle(self.work_queue) + self.work_on_tcp_conn(fileno=fileno, addr=addr) + return False diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index b0c4246af8..fd7ee98add 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -12,35 +12,36 @@ acceptor """ -import argparse import os +import ssl import socket import logging import asyncio +import argparse import selectors -import contextlib import multiprocessing -from multiprocessing import connection -from multiprocessing.reduction import recv_handle -from typing import Dict, Optional, Tuple, List, Generator, Any +from abc import abstractmethod, ABC +from typing import Dict, Optional, Tuple, List, Set, Generic, TypeVar, Union -from .work import Work +from ...common.logger import Logger +from ...common.types import Readables, Writables +from ...common.constants import DEFAULT_SELECTOR_SELECT_TIMEOUT from ..connection import TcpClientConnection -from ..event import EventQueue, eventNames +from ..event import eventNames, EventQueue -from ...common.logger import Logger -from ...common.types import Readables, Writables -from ...common.constants import DEFAULT_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT +from .work import Work + +T = TypeVar('T') logger = logging.getLogger(__name__) -class Threadless(multiprocessing.Process): - """Work executor process. +class Threadless(ABC, Generic[T]): + """Work executor base class. - Threadless process provides an event loop, which is shared across + Threadless provides an event loop, which is shared across multiple :class:`~proxy.core.acceptor.work.Work` instances to handle work. @@ -57,96 +58,73 @@ class Threadless(multiprocessing.Process): implements :class:`~proxy.core.acceptor.work.Work` protocol. It expects a client connection as work payload and hooks into the threadless event loop to handle the client connection. - """ def __init__( self, - client_queue: connection.Connection, + work_queue: T, flags: argparse.Namespace, event_queue: Optional[EventQueue] = None, ) -> None: super().__init__() - self.client_queue = client_queue + self.work_queue = work_queue self.flags = flags self.event_queue = event_queue self.running = multiprocessing.Event() self.works: Dict[int, Work] = {} self.selector: Optional[selectors.DefaultSelector] = None - self.loop: Optional[asyncio.AbstractEventLoop] = None + # If we remove single quotes for typing hint below, + # runtime exceptions will occur for < Python 3.9. + # + # Ref https://github.com/abhinavsingh/proxy.py/runs/4279055360?check_suite_focus=true + self.unfinished: Set['asyncio.Task[bool]'] = set() + self.registered_events_by_work_ids: Dict[ + # work_id + int, + # fileno, mask + Dict[int, int], + ] = {} + self.wait_timeout: float = DEFAULT_SELECTOR_SELECT_TIMEOUT - @contextlib.contextmanager - def selected_events(self) -> Generator[ - Tuple[Readables, Writables], - None, None, - ]: - assert self.selector is not None - events: Dict[socket.socket, int] = {} - for work in self.works.values(): - worker_events = work.get_events() - events.update(worker_events) - for fd in worker_events: - # Can throw ValueError: Invalid file descriptor: -1 - # - # A guard within Work classes may not help here due to - # asynchronous nature. Hence, threadless will handle - # ValueError exceptions raised by selector.register - # for invalid fd. 
- self.selector.register(fd, worker_events[fd]) - ev = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) - readables = [] - writables = [] - for key, mask in ev: - if mask & selectors.EVENT_READ: - readables.append(key.fileobj) - if mask & selectors.EVENT_WRITE: - writables.append(key.fileobj) - yield (readables, writables) - for fd in events: - self.selector.unregister(fd) - - async def handle_events( - self, fileno: int, - readables: Readables, - writables: Writables - ) -> bool: - return self.works[fileno].handle_events(readables, writables) - - # TODO: Use correct future typing annotations - async def wait_for_tasks( - self, tasks: Dict[int, Any] - ) -> None: - for work_id in tasks: - # TODO: Resolving one handle_events here can block - # resolution of other tasks. This can happen when handle_events - # is slow. - # - # Instead of sequential await, a better option would be to await on - # list of async handle_events. This will allow all handlers to run - # concurrently without blocking each other. - try: - teardown = await asyncio.wait_for(tasks[work_id], DEFAULT_TIMEOUT) - if teardown: - self.cleanup(work_id) - except asyncio.TimeoutError: - self.cleanup(work_id) - - def fromfd(self, fileno: int) -> socket.socket: - return socket.fromfd( - fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, - type=socket.SOCK_STREAM, - ) + @property + @abstractmethod + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + raise NotImplementedError() + + @abstractmethod + def receive_from_work_queue(self) -> bool: + """Work queue is ready to receive new work. + + Receive it and call ``work_on_tcp_conn``. + + Return True to tear down the loop.""" + raise NotImplementedError() + + @abstractmethod + def work_queue_fileno(self) -> Optional[int]: + """If work queue must be selected before calling + ``receive_from_work_queue`` then implementation must + return work queue fd.""" + raise NotImplementedError() + + def close_work_queue(self) -> None: + """Only called if ``work_queue_fileno`` returns an integer. + If an fd is select-able for work queue, make sure + to close the work queue fd now.""" + pass # pragma: no cover - def accept_client(self) -> None: - # Acceptor will not send address for - # unix socket domain environments. - addr = None - if not self.flags.unix_socket_path: - addr = self.client_queue.recv() - fileno = recv_handle(self.client_queue) + def work_on_tcp_conn( + self, + fileno: int, + addr: Optional[Tuple[str, int]] = None, + conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None, + ) -> None: self.works[fileno] = self.flags.work_klass( - TcpClientConnection(conn=self.fromfd(fileno), addr=addr), + TcpClientConnection( + conn=conn or self._fromfd(fileno), + addr=addr, + ), flags=self.flags, event_queue=self.event_queue, ) @@ -162,70 +140,200 @@ def accept_client(self) -> None: 'Exception occurred during initialization', exc_info=e, ) - self.cleanup(fileno) + self._cleanup(fileno) - def cleanup_inactive(self) -> None: + async def _selected_events(self) -> Tuple[ + Dict[int, Tuple[Readables, Writables]], + bool, + ]: + """For each work, collects events they are interested in. + Calls select for events of interest. """ + assert self.selector is not None + for work_id in self.works: + worker_events = await self.works[work_id].get_events() + # NOTE: Current assumption is that multiple works will not + # be interested in the same fd. Descriptors of interests + # returned by work must be unique. 
+ # + # TODO: Ideally we must diff and unregister descriptors not + # returned as of interest within this _selected_events call + # but which still exist in registered_events_by_work_ids + for fileno in worker_events: + if work_id not in self.registered_events_by_work_ids: + self.registered_events_by_work_ids[work_id] = {} + mask = worker_events[fileno] + if fileno in self.registered_events_by_work_ids[work_id]: + oldmask = self.registered_events_by_work_ids[work_id][fileno] + if mask != oldmask: + self.selector.modify( + fileno, events=mask, + data=work_id, + ) + self.registered_events_by_work_ids[work_id][fileno] = mask + logger.debug( + 'fd#{0} modified for mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) + else: + # Can throw ValueError: Invalid file descriptor: -1 + # + # A guard within Work classes may not help here due to + # asynchronous nature. Hence, threadless will handle + # ValueError exceptions raised by selector.register + # for invalid fd. + self.selector.register( + fileno, events=mask, + data=work_id, + ) + self.registered_events_by_work_ids[work_id][fileno] = mask + logger.debug( + 'fd#{0} registered for mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) + selected = self.selector.select( + timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT, + ) + # Keys are work_id and values are 2-tuple indicating + # readables & writables that work_id is interested in + # and are ready for IO. + work_by_ids: Dict[int, Tuple[Readables, Writables]] = {} + new_work_available = False + wqfileno = self.work_queue_fileno() + if wqfileno is None: + new_work_available = True + for key, mask in selected: + if wqfileno is not None and key.fileobj == wqfileno: + assert mask & selectors.EVENT_READ + new_work_available = True + continue + if key.data not in work_by_ids: + work_by_ids[key.data] = ([], []) + if mask & selectors.EVENT_READ: + work_by_ids[key.data][0].append(key.fileobj) + if mask & selectors.EVENT_WRITE: + work_by_ids[key.data][1].append(key.fileobj) + return (work_by_ids, new_work_available) + + async def _wait_for_tasks( + self, + pending: Set['asyncio.Task[bool]'], + ) -> None: + finished, self.unfinished = await asyncio.wait( + pending, + timeout=self.wait_timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in finished: + if task.result(): + self._cleanup(task._work_id) # type: ignore + # self.cleanup(int(task.get_name())) + + def _fromfd(self, fileno: int) -> socket.socket: + return socket.fromfd( + fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, + type=socket.SOCK_STREAM, + ) + + # TODO: Use cached property to avoid executing repeatedly + # within a second interval. Note that our selector timeout + # is 0.1 second, which can unnecessarily result in cleanup + # checks within a second boundary.
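To make the new executor contract concrete, here is a minimal sketch of a work class written against the async, fileno-based interface introduced by this patch. `Work`, `Readables` and `Writables` are real imports from this change set; the echo behaviour and the class name are illustrative only:

```python
# Illustrative sketch, not part of this patch: an echo work class
# implementing the new async, fileno-based Work interface.
import selectors
from typing import Dict

from proxy.common.types import Readables, Writables
from proxy.core.acceptor import Work


class EchoWork(Work):
    """Echoes whatever the client sends back to the client."""

    async def get_events(self) -> Dict[int, int]:
        # Descriptors of interest are now plain filenos, not socket objects.
        events = {self.work.connection.fileno(): selectors.EVENT_READ}
        if self.work.has_buffer():
            # Also wait for writability while we have pending data.
            events[self.work.connection.fileno()] |= selectors.EVENT_WRITE
        return events

    async def handle_events(
            self,
            readables: Readables,
            writables: Writables,
    ) -> bool:
        """Return True to tear down this work."""
        if self.work.connection.fileno() in writables:
            self.work.flush()
        if self.work.connection.fileno() in readables:
            data = self.work.recv(self.flags.client_recvbuf_size)
            if data is None:
                return True  # client closed the connection, tear down
            self.work.queue(data)  # queue received bytes back to the client
        return False
```

Such a class can be plugged in as ``work_klass`` and should run unmodified under both `LocalExecutor` and `RemoteExecutor`, since both drive work through the shared `Threadless` event loop.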
+ def _cleanup_inactive(self) -> None: inactive_works: List[int] = [] for work_id in self.works: if self.works[work_id].is_inactive(): inactive_works.append(work_id) for work_id in inactive_works: - self.cleanup(work_id) + self._cleanup(work_id) - def cleanup(self, work_id: int) -> None: - # TODO: HttpProtocolHandler.shutdown can call flush which may block + # TODO: HttpProtocolHandler.shutdown can call flush which may block + def _cleanup(self, work_id: int) -> None: + if work_id in self.registered_events_by_work_ids: + assert self.selector + for fileno in self.registered_events_by_work_ids[work_id]: + logger.debug( + 'fd#{0} unregistered by work#{1}'.format( + fileno, work_id, + ), + ) + self.selector.unregister(fileno) + self.registered_events_by_work_ids[work_id].clear() + del self.registered_events_by_work_ids[work_id] self.works[work_id].shutdown() del self.works[work_id] - os.close(work_id) + if self.work_queue_fileno() is not None: + os.close(work_id) - def run_once(self) -> None: + def _create_tasks( + self, + work_by_ids: Dict[int, Tuple[Readables, Writables]], + ) -> Set['asyncio.Task[bool]']: + assert self.loop + tasks: Set['asyncio.Task[bool]'] = set() + for work_id in work_by_ids: + task = self.loop.create_task( + self.works[work_id].handle_events(*work_by_ids[work_id]), + ) + task._work_id = work_id # type: ignore[attr-defined] + # task.set_name(work_id) + tasks.add(task) + return tasks + + async def _run_once(self) -> bool: assert self.loop is not None - with self.selected_events() as (readables, writables): - if len(readables) == 0 and len(writables) == 0: - # Remove and shutdown inactive connections - self.cleanup_inactive() - return - # Note that selector from now on is idle, - # until all the logic below completes. - # - # This is where one process per CPU architecture shines, - # as other threadless processes can continue process work - # within their context. + work_by_ids, new_work_available = await self._selected_events() + # Accept new work if available # + # TODO: We must use a work klass to handle + # client_queue fd itself a.k.a. accept_client + # will become handle_readables. + if new_work_available: + teardown = self.receive_from_work_queue() + if teardown: + return teardown + if len(work_by_ids) == 0: + self._cleanup_inactive() + return False # Invoke Threadless.handle_events - # - # TODO: Only send readable / writables that client originally - # registered. 
- tasks = {} - for fileno in self.works: - tasks[fileno] = self.loop.create_task( - self.handle_events(fileno, readables, writables), - ) - # Accepted client connection from Acceptor - if self.client_queue in readables: - self.accept_client() - # Wait for Threadless.handle_events to complete - self.loop.run_until_complete(self.wait_for_tasks(tasks)) + self.unfinished.update(self._create_tasks(work_by_ids)) + # logger.debug('Executing {0} works'.format(len(self.unfinished))) + await self._wait_for_tasks(self.unfinished) + # logger.debug( + # 'Done executing works, {0} pending, {1} registered'.format( + # len(self.unfinished), len(self.registered_events_by_work_ids), + # ), + # ) # Remove and shutdown inactive workers - self.cleanup_inactive() + self._cleanup_inactive() + return False def run(self) -> None: Logger.setup( self.flags.log_file, self.flags.log_level, self.flags.log_format, ) + wqfileno = self.work_queue_fileno() try: self.selector = selectors.DefaultSelector() - self.selector.register(self.client_queue, selectors.EVENT_READ) - self.loop = asyncio.get_event_loop_policy().get_event_loop() + if wqfileno is not None: + self.selector.register( + wqfileno, + selectors.EVENT_READ, + data=wqfileno, + ) + assert self.loop while not self.running.is_set(): # logger.debug('Working on {0} works'.format(len(self.works))) - self.run_once() + if self.loop.run_until_complete(self._run_once()): + break except KeyboardInterrupt: pass finally: assert self.selector is not None - self.selector.unregister(self.client_queue) - self.client_queue.close() + if wqfileno is not None: + self.selector.unregister(wqfileno) + self.close_work_queue() assert self.loop is not None self.loop.close() diff --git a/proxy/core/acceptor/work.py b/proxy/core/acceptor/work.py index ea05a8056b..11b5deecc6 100644 --- a/proxy/core/acceptor/work.py +++ b/proxy/core/acceptor/work.py @@ -13,7 +13,6 @@ acceptor """ import argparse -import socket from abc import ABC, abstractmethod from uuid import uuid4, UUID @@ -43,12 +42,12 @@ def __init__( self.work = work @abstractmethod - def get_events(self) -> Dict[socket.socket, int]: + async def get_events(self) -> Dict[int, int]: """Return sockets and events (read or write) that we are interested in.""" return {} # pragma: no cover @abstractmethod - def handle_events( + async def handle_events( self, readables: Readables, writables: Writables, diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py index 4db61463cd..ce1d476116 100644 --- a/proxy/core/base/tcp_server.py +++ b/proxy/core/base/tcp_server.py @@ -12,7 +12,6 @@ tcp """ -import socket import logging import selectors @@ -50,44 +49,41 @@ class BaseTcpServerHandler(Work): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.must_flush_before_shutdown = False - if self.flags.unix_socket_path: # pragma: no cover - logger.debug( - 'Connection accepted from {0}'.format(self.work.address), - ) - else: - logger.debug( - 'Connection accepted from {0}'.format(self.work.address), - ) + logger.debug( + 'Work#%d accepted from %s', + self.work.connection.fileno(), + self.work.address, + ) @abstractmethod def handle_data(self, data: memoryview) -> Optional[bool]: """Optionally return True to close client connection.""" pass # pragma: no cover - def get_events(self) -> Dict[socket.socket, int]: + async def get_events(self) -> Dict[int, int]: events = {} # We always want to read from client # Register for EVENT_READ events if self.must_flush_before_shutdown is False: - 
events[self.work.connection] = selectors.EVENT_READ + events[self.work.connection.fileno()] = selectors.EVENT_READ # If there is pending buffer for client # also register for EVENT_WRITE events if self.work.has_buffer(): - if self.work.connection in events: - events[self.work.connection] |= selectors.EVENT_WRITE + if self.work.connection.fileno() in events: + events[self.work.connection.fileno()] |= selectors.EVENT_WRITE else: - events[self.work.connection] = selectors.EVENT_WRITE + events[self.work.connection.fileno()] = selectors.EVENT_WRITE return events - def handle_events( + async def handle_events( self, readables: Readables, writables: Writables, ) -> bool: """Return True to shutdown work.""" - teardown = self.handle_writables( + teardown = await self.handle_writables( writables, - ) or self.handle_readables(readables) + ) or await self.handle_readables(readables) if teardown: logger.debug( 'Shutting down client {0} connection'.format( @@ -96,9 +92,9 @@ def handle_events( ) return teardown - def handle_writables(self, writables: Writables) -> bool: + async def handle_writables(self, writables: Writables) -> bool: teardown = False - if self.work.connection in writables and self.work.has_buffer(): + if self.work.connection.fileno() in writables and self.work.has_buffer(): logger.debug( 'Flushing buffer to client {0}'.format(self.work.address), ) @@ -109,9 +105,9 @@ def handle_writables(self, writables: Writables) -> bool: self.must_flush_before_shutdown = False return teardown - def handle_readables(self, readables: Readables) -> bool: + async def handle_readables(self, readables: Readables) -> bool: teardown = False - if self.work.connection in readables: + if self.work.connection.fileno() in readables: data = self.work.recv(self.flags.client_recvbuf_size) if data is None: logger.debug( diff --git a/proxy/core/base/tcp_tunnel.py b/proxy/core/base/tcp_tunnel.py index 9e4b7c156f..3c17ce8e31 100644 --- a/proxy/core/base/tcp_tunnel.py +++ b/proxy/core/base/tcp_tunnel.py @@ -12,7 +12,6 @@ tcp """ -import socket import logging import selectors @@ -65,32 +64,32 @@ def shutdown(self) -> None: self.upstream.close() super().shutdown() - def get_events(self) -> Dict[socket.socket, int]: + async def get_events(self) -> Dict[int, int]: # Get default client events - ev: Dict[socket.socket, int] = super().get_events() + ev: Dict[int, int] = await super().get_events() # Read from server if we are connected if self.upstream and self.upstream._conn is not None: - ev[self.upstream.connection] = selectors.EVENT_READ + ev[self.upstream.connection.fileno()] = selectors.EVENT_READ # If there is pending buffer for server # also register for EVENT_WRITE events if self.upstream and self.upstream.has_buffer(): - if self.upstream.connection in ev: - ev[self.upstream.connection] |= selectors.EVENT_WRITE + if self.upstream.connection.fileno() in ev: + ev[self.upstream.connection.fileno()] |= selectors.EVENT_WRITE else: - ev[self.upstream.connection] = selectors.EVENT_WRITE + ev[self.upstream.connection.fileno()] = selectors.EVENT_WRITE return ev - def handle_events( + async def handle_events( self, readables: Readables, writables: Writables, ) -> bool: # Handle client events - do_shutdown: bool = super().handle_events(readables, writables) + do_shutdown: bool = await super().handle_events(readables, writables) if do_shutdown: return do_shutdown # Handle server events - if self.upstream and self.upstream.connection in readables: + if self.upstream and self.upstream.connection.fileno() in readables: data = 
self.upstream.recv() if data is None: # Server closed connection @@ -98,7 +97,7 @@ def handle_events( return True # tunnel data to client self.work.queue(data) - if self.upstream and self.upstream.connection in writables: + if self.upstream and self.upstream.connection.fileno() in writables: self.upstream.flush() return False diff --git a/proxy/core/base/tcp_upstream.py b/proxy/core/base/tcp_upstream.py index e04ec5e90d..1de1e4a910 100644 --- a/proxy/core/base/tcp_upstream.py +++ b/proxy/core/base/tcp_upstream.py @@ -8,12 +8,10 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -from abc import ABC, abstractmethod - import ssl -import socket import logging +from abc import ABC, abstractmethod from typing import Tuple, List, Optional, Any from ...common.types import Readables, Writables @@ -24,12 +22,13 @@ class TcpUpstreamConnectionHandler(ABC): """:class:`~proxy.core.base.TcpUpstreamConnectionHandler` can - be used to insert an upstream server connection lifecycle within - asynchronous proxy.py lifecycle. + be used to insert an upstream server connection lifecycle. Call `initialize_upstream` to initialize the upstream connection object. Then, directly use ``self.upstream`` object within your class. + See :class:`~proxy.plugin.proxy_pool.ProxyPoolPlugin` for example usage. + .. spelling:: tcp @@ -67,20 +66,25 @@ def handle_upstream_data(self, raw: memoryview) -> None: def initialize_upstream(self, addr: str, port: int) -> None: self.upstream = TcpServerConnection(addr, port) - def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]: + def get_descriptors(self) -> Tuple[List[int], List[int]]: if not self.upstream: return [], [] - return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else [] + return [self.upstream.connection.fileno()], \ + [self.upstream.connection.fileno()] \ + if self.upstream.has_buffer() \ + else [] def read_from_descriptors(self, r: Readables) -> bool: - if self.upstream and self.upstream.connection in r: + if self.upstream and \ + self.upstream.connection.fileno() in r: try: raw = self.upstream.recv(self.server_recvbuf_size) if raw is not None: self.total_size += len(raw) self.handle_upstream_data(raw) else: - return True # Teardown because upstream proxy closed the connection + # Tear down because upstream proxy closed the connection + return True except ssl.SSLWantReadError: logger.info('Upstream SSLWantReadError, will retry') return False @@ -90,7 +94,9 @@ def read_from_descriptors(self, r: Readables) -> bool: return False def write_to_descriptors(self, w: Writables) -> bool: - if self.upstream and self.upstream.connection in w and self.upstream.has_buffer(): + if self.upstream and \ + self.upstream.connection.fileno() in w and \ + self.upstream.has_buffer(): try: self.upstream.flush() except ssl.SSLWantWriteError: diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 0af90ad53a..1551b298f1 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -16,11 +16,11 @@ import time import errno import socket +import asyncio import logging import selectors -import contextlib -from typing import Tuple, List, Union, Optional, Generator, Dict, Any +from typing import Tuple, List, Union, Optional, Dict, Any from .plugin import HttpProtocolHandlerPlugin from .parser import HttpParser, httpParserStates, httpParserTypes @@ -64,11 +64,6 @@ 'data sent or received by the client.', ) -SelectedEventsGeneratorType = Generator[ - 
Tuple[Readables, Writables], - None, None, -] - class HttpProtocolHandler(BaseTcpServerHandler): """HTTP, HTTPS, HTTP2, WebSockets protocol handler. @@ -124,19 +119,16 @@ def shutdown(self) -> None: # Flush pending buffer in threaded mode only. # For threadless mode, BaseTcpServerHandler implements # the must_flush_before_shutdown logic automagically. - if self.selector: + if self.selector and self.work.has_buffer(): self._flush() - # Invoke plugin.on_client_connection_close for plugin in self.plugins.values(): plugin.on_client_connection_close() - logger.debug( 'Closing client connection %r ' 'at address %s has buffer %s' % (self.work.connection, self.work.address, self.work.has_buffer()), ) - conn = self.work.connection # Unwrap if wrapped before shutdown. if self._encryption_enabled() and \ @@ -151,53 +143,49 @@ def shutdown(self) -> None: logger.debug('Client connection closed') super().shutdown() - def get_events(self) -> Dict[socket.socket, int]: + async def get_events(self) -> Dict[int, int]: # Get default client events - events: Dict[socket.socket, int] = super().get_events() + events: Dict[int, int] = await super().get_events() # HttpProtocolHandlerPlugin.get_descriptors for plugin in self.plugins.values(): plugin_read_desc, plugin_write_desc = plugin.get_descriptors() - for r in plugin_read_desc: - if r not in events: - events[r] = selectors.EVENT_READ + for rfileno in plugin_read_desc: + if rfileno not in events: + events[rfileno] = selectors.EVENT_READ else: - events[r] |= selectors.EVENT_READ - for w in plugin_write_desc: - if w not in events: - events[w] = selectors.EVENT_WRITE + events[rfileno] |= selectors.EVENT_READ + for wfileno in plugin_write_desc: + if wfileno not in events: + events[wfileno] = selectors.EVENT_WRITE else: - events[w] |= selectors.EVENT_WRITE + events[wfileno] |= selectors.EVENT_WRITE return events # We override super().handle_events and never call it - def handle_events( + async def handle_events( self, readables: Readables, writables: Writables, ) -> bool: """Returns True if proxy must tear down.""" # Flush buffer for ready to write sockets - teardown = self.handle_writables(writables) + teardown = await self.handle_writables(writables) if teardown: return True - # Invoke plugin.write_to_descriptors for plugin in self.plugins.values(): - teardown = plugin.write_to_descriptors(writables) + teardown = await plugin.write_to_descriptors(writables) if teardown: return True - # Read from ready to read sockets - teardown = self.handle_readables(readables) + teardown = await self.handle_readables(readables) if teardown: return True - # Invoke plugin.read_from_descriptors for plugin in self.plugins.values(): - teardown = plugin.read_from_descriptors(readables) + teardown = await plugin.read_from_descriptors(readables) if teardown: return True - return False def handle_data(self, data: memoryview) -> Optional[bool]: @@ -208,7 +196,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]: try: # HttpProtocolHandlerPlugin.on_client_data - # Can raise HttpProtocolException to teardown the connection + # Can raise HttpProtocolException to tear down the connection for plugin in self.plugins.values(): optional_data = plugin.on_client_data(data) if optional_data is None: @@ -248,8 +236,8 @@ def handle_data(self, data: memoryview) -> Optional[bool]: return True return False - def handle_writables(self, writables: Writables) -> bool: - if self.work.connection in writables and self.work.has_buffer(): + async def handle_writables(self, writables: Writables) -> bool: 
+ if self.work.connection.fileno() in writables and self.work.has_buffer(): logger.debug('Client is ready for writes, flushing buffer') self.last_activity = time.time() @@ -265,7 +253,7 @@ def handle_writables(self, writables: Writables) -> bool: try: # Call super() for client flush - teardown = super().handle_writables(writables) + teardown = await super().handle_writables(writables) if teardown: return True except BrokenPipeError: @@ -278,12 +266,12 @@ def handle_writables(self, writables: Writables) -> bool: return True return False - def handle_readables(self, readables: Readables) -> bool: - if self.work.connection in readables: + async def handle_readables(self, readables: Readables) -> bool: + if self.work.connection.fileno() in readables: logger.debug('Client is ready for reads, reading') self.last_activity = time.time() try: - teardown = super().handle_readables(readables) + teardown = await super().handle_readables(readables) if teardown: return teardown except ssl.SSLWantReadError: # Try again later @@ -305,40 +293,6 @@ def handle_readables(self, readables: Readables) -> bool: return True return False - ## - # run() is here to maintain backward compatibility for threaded mode - ## - - def run(self) -> None: - """run() method is not used when in --threadless mode. - - This is here just to maintain backward compatibility with threaded mode. - """ - try: - self.initialize() - while True: - # Teardown if client buffer is empty and connection is inactive - if self.is_inactive(): - logger.debug( - 'Client buffer is empty and maximum inactivity has reached ' - 'between client and server connection, tearing down...', - ) - break - teardown = self._run_once() - if teardown: - break - except KeyboardInterrupt: # pragma: no cover - pass - except ssl.SSLError as e: - logger.exception('ssl.SSLError', exc_info=e) - except Exception as e: - logger.exception( - 'Exception while handling connection %r' % - self.work.connection, exc_info=e, - ) - finally: - self.shutdown() - ## # Internal methods ## @@ -360,10 +314,12 @@ def _optionally_wrap_socket( conn = wrap_socket(conn, self.flags.keyfile, self.flags.certfile) return conn - @contextlib.contextmanager - def _selected_events(self) -> SelectedEventsGeneratorType: + # FIXME: Returning events is only necessary because we cannot use an async context manager + # for < Python 3.8. For this reason, this method is no longer a context manager and the caller + # is responsible for unregistering the descriptors.
+ async def _selected_events(self) -> Tuple[Dict[int, int], Readables, Writables]: assert self.selector - events = self.get_events() + events = await self.get_events() for fd in events: self.selector.register(fd, events[fd]) ev = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) @@ -374,27 +330,18 @@ def _selected_events(self) -> SelectedEventsGeneratorType: readables.append(key.fileobj) if mask & selectors.EVENT_WRITE: writables.append(key.fileobj) - yield (readables, writables) - for fd in events: - self.selector.unregister(fd) - - def _run_once(self) -> bool: - with self._selected_events() as (readables, writables): - teardown = self.handle_events(readables, writables) - if teardown: - return True - return False + return (events, readables, writables) def _flush(self) -> None: assert self.selector - if not self.work.has_buffer(): - return + logger.debug('Flushing pending data') try: self.selector.register( self.work.connection, selectors.EVENT_WRITE, ) while self.work.has_buffer(): + logger.debug('Waiting for client read ready') ev: List[ Tuple[selectors.SelectorKey, int] ] = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) @@ -408,3 +355,51 @@ def _flush(self) -> None: return time.time() - self.last_activity + + ## + # run() and _run_once() are here to maintain backward compatibility + # with threaded mode. These methods are only called when running + # in threaded mode. + ## + + def run(self) -> None: + """run() method is not used when in --threadless mode. + + This is here just to maintain backward compatibility with threaded mode. + """ + loop = asyncio.new_event_loop() + try: + self.initialize() + while True: + # Tear down if client buffer is empty and connection is inactive + if self.is_inactive(): + logger.debug( + 'Client buffer is empty and maximum inactivity has been reached ' + 'between client and server connection, tearing down...', + ) + break + if loop.run_until_complete(self._run_once()): + break + except KeyboardInterrupt: # pragma: no cover + pass + except ssl.SSLError as e: + logger.exception('ssl.SSLError', exc_info=e) + except Exception as e: + logger.exception( + 'Exception while handling connection %r' % + self.work.connection, exc_info=e, + ) + finally: + self.shutdown() + loop.close() + + async def _run_once(self) -> bool: + events, readables, writables = await self._selected_events() + try: + return await self.handle_events(readables, writables) + finally: + assert self.selector + # TODO: Like Threadless we should not unregister + # work fds repeatedly. + for fd in events: + self.selector.unregister(fd) diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py index f5d49ef192..d5510b5c2b 100644 --- a/proxy/http/plugin.py +++ b/proxy/http/plugin.py @@ -71,15 +71,13 @@ def name(self) -> str: return self.__class__.__name__ @abstractmethod - def get_descriptors( - self, - ) -> Tuple[List[socket.socket], List[socket.socket]]: + def get_descriptors(self) -> Tuple[List[int], List[int]]: """Implementations must return a list of descriptors that they wish to read from and write into.""" return [], [] # pragma: no cover @abstractmethod - def write_to_descriptors(self, w: Writables) -> bool: + async def write_to_descriptors(self, w: Writables) -> bool: """Implementations must now write/flush data over the socket. Note that buffer management is in-built into the connection classes.
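# ----------------------------------------------------------------------
# [Editor's note] A minimal sketch (not part of this patch) of what a
# conforming HttpProtocolHandlerPlugin implementation looks like after
# the interface change above: descriptors are plain ints and the
# read/write callbacks are coroutines. `ExamplePlugin` and its
# `self.upstream` attribute are hypothetical, and the remaining abstract
# callbacks (on_client_data, on_request_complete, ...) are omitted.

from typing import List, Tuple

from proxy.common.types import Readables, Writables
from proxy.http.plugin import HttpProtocolHandlerPlugin


class ExamplePlugin(HttpProtocolHandlerPlugin):

    def get_descriptors(self) -> Tuple[List[int], List[int]]:
        # Descriptors are now fileno() ints, not socket objects
        if not self.upstream:
            return [], []
        fd = self.upstream.connection.fileno()
        return [fd], [fd] if self.upstream.has_buffer() else []

    async def write_to_descriptors(self, w: Writables) -> bool:
        # Membership checks compare ints; flush only when our fd is ready
        if self.upstream and self.upstream.connection.fileno() in w:
            self.upstream.flush()
        return False  # True would tear down the connection

    async def read_from_descriptors(self, r: Readables) -> bool:
        if self.upstream and self.upstream.connection.fileno() in r:
            raw = self.upstream.recv()
            if raw is None:
                return True  # Upstream closed, tear down
        return False
# ----------------------------------------------------------------------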
@@ -90,7 +88,7 @@ def write_to_descriptors(self, w: Writables) -> bool: return False # pragma: no cover @abstractmethod - def read_from_descriptors(self, r: Readables) -> bool: + async def read_from_descriptors(self, r: Readables) -> bool: """Implementations must now read data over the socket.""" return False # pragma: no cover diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index 94eed56934..81392d1d63 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -12,7 +12,6 @@ http """ -import socket import argparse from abc import ABC @@ -61,9 +60,7 @@ def name(self) -> str: # Since 3.4.0 # # @abstractmethod - def get_descriptors( - self, - ) -> Tuple[List[socket.socket], List[socket.socket]]: + def get_descriptors(self) -> Tuple[List[int], List[int]]: return [], [] # pragma: no cover # @abstractmethod diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index e7c1e9d34f..b8181d663e 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -177,28 +177,24 @@ def tls_interception_enabled(self) -> bool: self.flags.ca_signing_key_file is not None and \ self.flags.ca_cert_file is not None - def get_descriptors( - self, - ) -> Tuple[List[socket.socket], List[socket.socket]]: + def get_descriptors(self) -> Tuple[List[int], List[int]]: if not self.request.has_host(): return [], [] - - r: List[socket.socket] = [] - w: List[socket.socket] = [] + r: List[int] = [] + w: List[int] = [] if ( self.upstream and not self.upstream.closed and self.upstream.connection ): - r.append(self.upstream.connection) + r.append(self.upstream.connection.fileno()) if ( self.upstream and not self.upstream.closed and self.upstream.has_buffer() and self.upstream.connection ): - w.append(self.upstream.connection) - + w.append(self.upstream.connection.fileno()) # TODO(abhinavsingh): We need to keep a mapping of plugin and # descriptors registered by them, so that within write/read blocks # we can invoke the right plugin callbacks. @@ -206,7 +202,6 @@ def get_descriptors( plugin_read_desc, plugin_write_desc = plugin.get_descriptors() r.extend(plugin_read_desc) w.extend(plugin_write_desc) - return r, w def _close_and_release(self) -> bool: @@ -218,8 +213,8 @@ def _close_and_release(self) -> bool: self.upstream = None return True - def write_to_descriptors(self, w: Writables) -> bool: - if (self.upstream and self.upstream.connection not in w) or not self.upstream: + async def write_to_descriptors(self, w: Writables) -> bool: + if (self.upstream and self.upstream.connection.fileno() not in w) or not self.upstream: # Currently, we just call the write/read block of each plugin. It is # the plugin's responsibility to ignore this callback if the passed descriptors # don't contain the descriptor it registered.
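# ----------------------------------------------------------------------
# [Editor's note] The switch from socket objects to fileno() ints above
# works because the stdlib selectors API accepts an integer file
# descriptor anywhere it accepts a file object. A self-contained sketch:

import socket
import selectors

sel = selectors.DefaultSelector()
a, b = socket.socketpair()
sel.register(a.fileno(), selectors.EVENT_READ)  # an int fd is a valid fileobj
b.send(b'ping')
for key, mask in sel.select(timeout=1):
    # key.fileobj is the same int we registered above
    if key.fileobj == a.fileno() and mask & selectors.EVENT_READ:
        print(a.recv(4))  # b'ping'
sel.unregister(a.fileno())
a.close()
b.close()
# ----------------------------------------------------------------------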
@@ -230,7 +225,7 @@ def write_to_descriptors(self, w: Writables) -> bool: elif self.request.has_host() and \ self.upstream and not self.upstream.closed and \ self.upstream.has_buffer() and \ - self.upstream.connection in w: + self.upstream.connection.fileno() in w: logger.debug('Server is write ready, flushing buffer') try: self.upstream.flush() @@ -251,11 +246,11 @@ def write_to_descriptors(self, w: Writables) -> bool: return self._close_and_release() return False - def read_from_descriptors(self, r: Readables) -> bool: + async def read_from_descriptors(self, r: Readables) -> bool: if ( self.upstream and not self.upstream.closed and - self.upstream.connection not in r + self.upstream.connection.fileno() not in r ) or not self.upstream: # Currently, we just call the write/read block of each plugin. It is # the plugin's responsibility to ignore this callback if the passed descriptors @@ -267,7 +262,7 @@ def read_from_descriptors(self, r: Readables) -> bool: elif self.request.has_host() \ and self.upstream \ and not self.upstream.closed \ - and self.upstream.connection in r: + and self.upstream.connection.fileno() in r: logger.debug('Server is ready for reads, reading...') try: raw = self.upstream.recv(self.flags.server_recvbuf_size) @@ -445,7 +440,7 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: # self.access_log() return chunk - # Can return None to teardown connection + # Can return None to tear down connection def on_client_data(self, raw: memoryview) -> Optional[memoryview]: if not self.request.has_host(): return raw diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py index 11c2e3ec76..c9a8fe2da0 100644 --- a/proxy/http/server/plugin.py +++ b/proxy/http/server/plugin.py @@ -12,7 +12,6 @@ http """ -import socket import argparse from uuid import UUID @@ -60,9 +59,7 @@ def name(self) -> str: # Since 3.4.0 # # @abstractmethod - def get_descriptors( - self, - ) -> Tuple[List[socket.socket], List[socket.socket]]: + def get_descriptors(self) -> Tuple[List[int], List[int]]: return [], [] # pragma: no cover # @abstractmethod diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index d9ff773d88..129c088ffb 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -25,7 +25,7 @@ from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER from ...common.constants import DEFAULT_MIN_COMPRESSION_LIMIT, DEFAULT_WEB_ACCESS_LOG_FORMAT from ...common.utils import bytes_, text_, build_http_response, build_websocket_handshake_response -from ...common.utils import cached_property +from ...common.backports import cached_property from ...common.types import Readables, Writables from ...common.flag import flags @@ -235,9 +235,7 @@ def on_request_complete(self) -> Union[socket.socket, bool]: self.client.queue(self.DEFAULT_404_RESPONSE) return True - def get_descriptors( - self, - ) -> Tuple[List[socket.socket], List[socket.socket]]: + def get_descriptors(self) -> Tuple[List[int], List[int]]: r, w = [], [] for plugin in self.plugins.values(): r1, w1 = plugin.get_descriptors() @@ -245,14 +243,14 @@ def get_descriptors( w.extend(w1) return r, w - def write_to_descriptors(self, w: Writables) -> bool: + async def write_to_descriptors(self, w: Writables) -> bool: for plugin in self.plugins.values(): teardown = plugin.write_to_descriptors(w) if teardown: return True return False - def read_from_descriptors(self, r: Readables) -> bool: + async def read_from_descriptors(self, r: Readables) -> bool: for plugin in
self.plugins.values(): teardown = plugin.read_from_descriptors(r) if teardown: @@ -266,7 +264,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: remaining = raw.tobytes() frame = WebsocketFrame() while remaining != b'': - # TODO: Teardown if invalid protocol exception + # TODO: Tear down if invalid protocol exception remaining = frame.parse(remaining) if frame.opcode == websocketOpcodes.CONNECTION_CLOSE: logger.warning( @@ -294,7 +292,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: self.route.handle_request(self.pipeline_request) if not self.pipeline_request.is_http_1_1_keep_alive(): logger.error( - 'Pipelined request is not keep-alive, will teardown request...', + 'Pipelined request is not keep-alive, will tear down request...', ) raise HttpProtocolException() self.pipeline_request = None diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py index da29a4a954..f4e25573bb 100644 --- a/proxy/http/websocket/client.py +++ b/proxy/http/websocket/client.py @@ -113,8 +113,7 @@ def run_once(self) -> bool: def run(self) -> None: try: while not self.closed: - teardown = self.run_once() - if teardown: + if self.run_once(): break except KeyboardInterrupt: pass diff --git a/proxy/proxy.py b/proxy/proxy.py index c05c6cf0ac..0f2346f935 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -114,15 +114,20 @@ class Proxy: - """Context manager to control AcceptorPool, ExecutorPool & EventingCore lifecycle. + """Proxy is a context manager that controls the proxy.py library core. - By default, AcceptorPool is started with + By default, :class:`~proxy.core.pool.AcceptorPool` is started with :class:`~proxy.http.handler.HttpProtocolHandler` work class. By definition, it expects HTTP traffic to flow between clients and server. - Optionally, it also initializes the eventing core, a multi-process safe - pubsub system queue which can be used to build various patterns - for message sharing and/or signaling. + In ``--threadless`` mode and without ``--local-executor``, + a :class:`~proxy.core.executors.ThreadlessPool` is also started. + The executor pool receives newly accepted work from :class:`~proxy.core.acceptor.Acceptor` + and creates an instance of the work class to process the received work. + + Optionally, the Proxy class also initializes the EventManager, + a multi-process safe pubsub system which can be used to build various + patterns for message sharing and/or signaling. """ def __init__(self, input_args: Optional[List[str]] = None, **opts: Any) -> None: @@ -162,6 +167,7 @@ def setup(self) -> None: # we are listening upon. This is necessary to preserve # the server port when `--port=0` is used.
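# ----------------------------------------------------------------------
# [Editor's note] A minimal usage sketch for the lifecycle described in
# the Proxy docstring above; flag values are illustrative only, and the
# top-level `from proxy import Proxy` re-export is assumed:

from proxy import Proxy

with Proxy(['--num-workers', '1', '--port', '0']) as p:
    # Listener and acceptors (and, unless --local-executor is used, the
    # remote executor pool) are now running; the assignment right below
    # this note preserves the ephemeral port in p.flags.port.
    print(p.flags.port)
# Leaving the block invokes shutdown(), tearing the pools down.
# ----------------------------------------------------------------------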
self.flags.port = self.listener._port + # Setup EventManager if self.flags.enable_events: logger.info('Core Event enabled') self.event_manager = EventManager() @@ -169,17 +175,20 @@ def setup(self) -> None: event_queue = self.event_manager.queue \ if self.event_manager is not None \ else None - self.executors = ThreadlessPool( - flags=self.flags, - event_queue=event_queue, - ) - self.executors.setup() + # Setup remote executors + if not self.flags.local_executor: + self.executors = ThreadlessPool( + flags=self.flags, + event_queue=event_queue, + ) + self.executors.setup() + # Setup acceptors self.acceptors = AcceptorPool( flags=self.flags, listener=self.listener, - executor_queues=self.executors.work_queues, - executor_pids=self.executors.work_pids, - executor_locks=self.executors.work_locks, + executor_queues=self.executors.work_queues if self.executors else [], + executor_pids=self.executors.work_pids if self.executors else [], + executor_locks=self.executors.work_locks if self.executors else [], event_queue=event_queue, ) self.acceptors.setup() @@ -188,8 +197,9 @@ def setup(self) -> None: def shutdown(self) -> None: assert self.acceptors self.acceptors.shutdown() - assert self.executors - self.executors.shutdown() + if not self.flags.local_executor: + assert self.executors + self.executors.shutdown() if self.flags.enable_events: assert self.event_manager is not None self.event_manager.shutdown() diff --git a/requirements-testing.txt b/requirements-testing.txt index 6beae2c4a6..6eae0fb128 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -5,6 +5,8 @@ flake8==4.0.1 pytest==6.2.5 pytest-cov==3.0.0 pytest-xdist == 2.4.0 +pytest-mock==3.6.1 +pytest-asyncio==0.16.0 autopep8==1.6.0 mypy==0.910 py-spy==0.3.10 diff --git a/tests/http/exceptions/test_http_proxy_auth_failed.py b/tests/http/exceptions/test_http_proxy_auth_failed.py index a65fc045ff..9b4feb1bc5 100644 --- a/tests/http/exceptions/test_http_proxy_auth_failed.py +++ b/tests/http/exceptions/test_http_proxy_auth_failed.py @@ -8,9 +8,10 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" +import pytest import selectors -import unittest -from unittest import mock + +from pytest_mock import MockerFixture from proxy.common.flag import FlagParser from proxy.http.exception.proxy_auth_failed import ProxyAuthenticationFailed @@ -18,33 +19,33 @@ from proxy.core.connection import TcpClientConnection from proxy.common.utils import build_http_request +from ...test_assertions import Assertions + -class TestHttpProxyAuthFailed(unittest.TestCase): +class TestHttpProxyAuthFailed(Assertions): - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def setUp( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: - self.mock_fromfd = mock_fromfd - self.mock_selector = mock_selector + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.mock_server_conn = mocker.patch( + 'proxy.http.proxy.server.TcpServerConnection', + ) self.fileno = 10 self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize( ["--basic-auth", "user:pass"], threaded=True, ) - self._conn = mock_fromfd.return_value + self._conn = self.mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), flags=self.flags, ) self.protocol_handler.initialize() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_auth_fails_without_cred(self) -> None: self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ @@ -54,25 +55,24 @@ def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> Non self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - - self.protocol_handler._run_once() - mock_server_conn.assert_not_called() + await self.protocol_handler._run_once() + self.mock_server_conn.assert_not_called() self.assertEqual(self.protocol_handler.work.has_buffer(), True) self.assertEqual( self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, ) self._conn.send.assert_not_called() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_auth_fails_with_invalid_cred(self) -> None: self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ @@ -83,25 +83,24 @@ def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) - self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - - self.protocol_handler._run_once() - mock_server_conn.assert_not_called() + await self.protocol_handler._run_once() + self.mock_server_conn.assert_not_called() self.assertEqual(self.protocol_handler.work.has_buffer(), True) self.assertEqual( self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, ) 
self._conn.send.assert_not_called() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_auth_works_with_valid_cred(self) -> None: self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ @@ -112,21 +111,20 @@ def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) -> self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - - self.protocol_handler._run_once() - mock_server_conn.assert_called_once() + await self.protocol_handler._run_once() + self.mock_server_conn.assert_called_once() self.assertEqual(self.protocol_handler.work.has_buffer(), False) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_auth_works_with_mixed_case_basic_string(self) -> None: self._conn.recv.return_value = build_http_request( b'GET', b'http://upstream.host/not-found.html', headers={ @@ -137,15 +135,14 @@ def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: m self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - - self.protocol_handler._run_once() - mock_server_conn.assert_called_once() + await self.protocol_handler._run_once() + self.mock_server_conn.assert_called_once() self.assertEqual(self.protocol_handler.work.has_buffer(), False) diff --git a/tests/http/test_http_proxy.py b/tests/http/test_http_proxy.py index fde5da11b0..b1bc964059 100644 --- a/tests/http/test_http_proxy.py +++ b/tests/http/test_http_proxy.py @@ -8,9 +8,10 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" -import unittest +import pytest import selectors -from unittest import mock + +from pytest_mock import MockerFixture from proxy.common.constants import DEFAULT_HTTP_PORT from proxy.common.flag import FlagParser @@ -21,27 +22,25 @@ from proxy.common.utils import build_http_request -class TestHttpProxyPlugin(unittest.TestCase): +class TestHttpProxyPlugin: - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def setUp( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: - self.mock_fromfd = mock_fromfd - self.mock_selector = mock_selector + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_server_conn = mocker.patch( + 'proxy.http.proxy.server.TcpServerConnection', + ) + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.mock_fromfd = mocker.patch('socket.fromfd') self.fileno = 10 self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize(threaded=True) - self.plugin = mock.MagicMock() + self.plugin = mocker.MagicMock() self.flags.plugins = { b'HttpProtocolHandlerPlugin': [HttpProxyPlugin], b'HttpProxyBasePlugin': [self.plugin], } - self._conn = mock_fromfd.return_value + self._conn = self.mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), flags=self.flags, @@ -51,11 +50,8 @@ def setUp( def test_proxy_plugin_initialized(self) -> None: self.plugin.assert_called() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proxy_plugin_on_and_before_upstream_connection( - self, - mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_plugin_on_and_before_upstream_connection(self) -> None: self.plugin.return_value.write_to_descriptors.return_value = False self.plugin.return_value.read_from_descriptors.return_value = False self.plugin.return_value.before_upstream_connection.side_effect = lambda r: r @@ -71,8 +67,8 @@ def test_proxy_plugin_on_and_before_upstream_connection( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -80,16 +76,16 @@ def test_proxy_plugin_on_and_before_upstream_connection( )], ] - self.protocol_handler._run_once() - mock_server_conn.assert_called_with('upstream.host', DEFAULT_HTTP_PORT) + await self.protocol_handler._run_once() + + self.mock_server_conn.assert_called_with( + 'upstream.host', DEFAULT_HTTP_PORT, + ) self.plugin.return_value.before_upstream_connection.assert_called() self.plugin.return_value.handle_client_request.assert_called() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proxy_plugin_before_upstream_connection_can_teardown( - self, - mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_plugin_before_upstream_connection_can_teardown(self) -> None: self.plugin.return_value.write_to_descriptors.return_value = False self.plugin.return_value.read_from_descriptors.return_value = False self.plugin.return_value.before_upstream_connection.side_effect = HttpProtocolException() @@ -103,8 +99,8 @@ def test_proxy_plugin_before_upstream_connection_can_teardown( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), 
events=selectors.EVENT_READ, data=None, ), @@ -112,8 +108,8 @@ def test_proxy_plugin_before_upstream_connection_can_teardown( )], ] - self.protocol_handler._run_once() - mock_server_conn.assert_not_called() + await self.protocol_handler._run_once() + self.mock_server_conn.assert_not_called() self.plugin.return_value.before_upstream_connection.assert_called() def test_proxy_plugin_plugins_can_teardown_from_write_to_descriptors(self) -> None: diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index dda98a925e..1fcfa71720 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -8,13 +8,14 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +import ssl import uuid -import unittest import socket -import ssl +import pytest import selectors from typing import Any +from pytest_mock import MockerFixture from unittest import mock from proxy.common.constants import DEFAULT_CA_FILE @@ -24,39 +25,28 @@ from proxy.common.utils import build_http_request, bytes_ from proxy.common.flag import FlagParser +from ..test_assertions import Assertions + + +class TestHttpProxyTlsInterception(Assertions): -class TestHttpProxyTlsInterception(unittest.TestCase): - - @mock.patch('ssl.wrap_socket') - @mock.patch('ssl.create_default_context') - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - @mock.patch('proxy.http.proxy.server.gen_public_key') - @mock.patch('proxy.http.proxy.server.gen_csr') - @mock.patch('proxy.http.proxy.server.sign_csr') - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_e2e( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_sign_csr: mock.Mock, - mock_gen_csr: mock.Mock, - mock_gen_public_key: mock.Mock, - mock_server_conn: mock.Mock, - mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + async def test_e2e(self, mocker: MockerFixture) -> None: host, port = uuid.uuid4().hex, 443 netloc = '{0}:{1}'.format(host, port) - self.mock_fromfd = mock_fromfd - self.mock_selector = mock_selector - self.mock_sign_csr = mock_sign_csr - self.mock_gen_csr = mock_gen_csr - self.mock_gen_public_key = mock_gen_public_key - self.mock_server_conn = mock_server_conn - self.mock_ssl_context = mock_ssl_context - self.mock_ssl_wrap = mock_ssl_wrap + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr') + self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr') + self.mock_gen_public_key = mocker.patch( + 'proxy.http.proxy.server.gen_public_key', + ) + self.mock_server_conn = mocker.patch( + 'proxy.http.proxy.server.TcpServerConnection', + ) + self.mock_ssl_context = mocker.patch('ssl.create_default_context') + self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') self.mock_sign_csr.return_value = True self.mock_gen_csr.return_value = True @@ -95,7 +85,7 @@ def mock_connection() -> Any: b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin], b'HttpProxyBasePlugin': [self.proxy_plugin], } - self._conn = mock_fromfd.return_value + self._conn = self.mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), flags=self.flags, @@ -121,9 +111,11 @@ def mock_connection() -> Any: self._conn.recv.return_value = connect_request 
# Prepare mocked HttpProtocolHandlerPlugin + async def asyncReturnBool(val: bool) -> bool: + return val self.plugin.return_value.get_descriptors.return_value = ([], []) - self.plugin.return_value.write_to_descriptors.return_value = False - self.plugin.return_value.read_from_descriptors.return_value = False + self.plugin.return_value.write_to_descriptors.return_value = asyncReturnBool(False) + self.plugin.return_value.read_from_descriptors.return_value = asyncReturnBool(False) self.plugin.return_value.on_client_data.side_effect = lambda raw: raw self.plugin.return_value.on_request_complete.return_value = False self.plugin.return_value.on_response_chunk.side_effect = lambda chunk: chunk @@ -139,8 +131,8 @@ def mock_connection() -> Any: self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -148,7 +140,7 @@ def mock_connection() -> Any: )], ] - self.protocol_handler._run_once() + await self.protocol_handler._run_once() # Assert our mocked plugins invocations self.plugin.return_value.get_descriptors.assert_called() @@ -158,7 +150,7 @@ def mock_connection() -> Any: ) self.plugin.return_value.on_request_complete.assert_called() self.plugin.return_value.read_from_descriptors.assert_called_with([ - self._conn, + self._conn.fileno(), ]) self.proxy_plugin.return_value.before_upstream_connection.assert_called() self.proxy_plugin.return_value.handle_client_request.assert_called() diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index c37740a4f3..ca6cce944e 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -8,12 +8,13 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" -import unittest -import selectors import base64 +import pytest +import selectors -from typing import cast from unittest import mock +from pytest_mock import MockerFixture +from typing import cast, Any from proxy.common.plugins import Plugins from proxy.common.flag import FlagParser @@ -27,19 +28,103 @@ from proxy.http.exception import ProxyAuthenticationFailed, ProxyConnectionFailed from proxy.http import HttpProtocolHandler +from ..test_assertions import Assertions -class TestHttpProtocolHandler(unittest.TestCase): - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def setUp( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: +def mock_selector_for_client_read(self: Any) -> None: + self.mock_selector.return_value.select.return_value = [ + ( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ] + + +class TestHttpProtocolHandlerWithoutServerMock(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self._conn = self.mock_fromfd.return_value + + self.http_server_port = 65535 + self.flags = FlagParser.initialize(threaded=True) + self.flags.plugins = Plugins.load([ + bytes_(PLUGIN_HTTP_PROXY), + bytes_(PLUGIN_WEB_SERVER), + ]) + + self.protocol_handler = HttpProtocolHandler( + TcpClientConnection(self._conn, self._addr), + flags=self.flags, + ) + self.protocol_handler.initialize() + + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_connection_failed(self) -> None: + mock_selector_for_client_read(self) + self._conn.recv.return_value = CRLF.join([ + b'GET http://unknown.domain HTTP/1.1', + b'Host: unknown.domain', + CRLF, + ]) + await self.protocol_handler._run_once() + self.assertEqual( + self.protocol_handler.work.buffer[0], + ProxyConnectionFailed.RESPONSE_PKT, + ) + + @pytest.mark.asyncio # type: ignore[misc] + async def test_proxy_authentication_failed(self) -> None: + self._conn = self.mock_fromfd.return_value + mock_selector_for_client_read(self) + flags = FlagParser.initialize( + auth_code=base64.b64encode(b'user:pass'), + threaded=True, + ) + flags.plugins = Plugins.load([ + bytes_(PLUGIN_HTTP_PROXY), + bytes_(PLUGIN_WEB_SERVER), + bytes_(PLUGIN_PROXY_AUTH), + ]) + self.protocol_handler = HttpProtocolHandler( + TcpClientConnection(self._conn, self._addr), flags=flags, + ) + self.protocol_handler.initialize() + self._conn.recv.return_value = CRLF.join([ + b'GET http://abhinavsingh.com HTTP/1.1', + b'Host: abhinavsingh.com', + CRLF, + ]) + await self.protocol_handler._run_once() + self.assertEqual( + self.protocol_handler.work.buffer[0], + ProxyAuthenticationFailed.RESPONSE_PKT, + ) + + +class TestHttpProtocolHandler(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.mock_server_connection = mocker.patch( + 'proxy.http.proxy.server.TcpServerConnection', + ) + self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = mock_fromfd.return_value + self._conn = self.mock_fromfd.return_value self.http_server_port = 65535 self.flags = FlagParser.initialize(threaded=True) @@ -48,20 
+133,19 @@ def setUp( bytes_(PLUGIN_WEB_SERVER), ]) - self.mock_selector = mock_selector self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=self.flags, + TcpClientConnection(self._conn, self._addr), + flags=self.flags, ) self.protocol_handler.initialize() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_http_get(self, mock_server_connection: mock.Mock) -> None: - server = mock_server_connection.return_value + @pytest.mark.asyncio # type: ignore[misc] + async def test_http_get(self) -> None: + server = self.mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 - self.mock_selector_for_client_read_read_server_write( - self.mock_selector, server, - ) + + self.mock_selector_for_client_read_and_server_write(server) # Send request line assert self.http_server_port is not None @@ -69,7 +153,9 @@ def test_http_get(self, mock_server_connection: mock.Mock) -> None: b'GET http://localhost:%d HTTP/1.1' % self.http_server_port ) + CRLF - self.protocol_handler._run_once() + + await self.protocol_handler._run_once() + self.assertEqual( self.protocol_handler.request.state, httpParserStates.LINE_RCVD, @@ -88,14 +174,15 @@ def test_http_get(self, mock_server_connection: mock.Mock) -> None: b'Proxy-Connection: Keep-Alive', CRLF, ]) - self.assert_data_queued(mock_server_connection, server) - self.protocol_handler._run_once() + await self.assert_data_queued(server) + await self.protocol_handler._run_once() server.flush.assert_called_once() - def assert_tunnel_response( - self, mock_server_connection: mock.Mock, server: mock.Mock, + async def assert_tunnel_response( + self, + server: mock.Mock, ) -> None: - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertTrue( cast( HttpProxyPlugin, @@ -106,7 +193,7 @@ def assert_tunnel_response( self.protocol_handler.work.buffer[0], HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, ) - mock_server_connection.assert_called_once() + self.mock_server_connection.assert_called_once() server.connect.assert_called_once() server.queue.assert_not_called() server.closed = False @@ -117,9 +204,9 @@ def assert_tunnel_response( assert parser.code is not None self.assertEqual(int(parser.code), 200) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_http_tunnel(self, mock_server_connection: mock.Mock) -> None: - server = mock_server_connection.return_value + @pytest.mark.asyncio # type: ignore[misc] + async def test_http_tunnel(self) -> None: + server = self.mock_server_connection.return_value server.connect.return_value = True def has_buffer() -> bool: @@ -130,8 +217,8 @@ def has_buffer() -> bool: [ ( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -141,8 +228,8 @@ def has_buffer() -> bool: [ ( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=0, data=None, ), @@ -152,8 +239,8 @@ def has_buffer() -> bool: [ ( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -163,8 +250,8 @@ def has_buffer() -> bool: [ ( selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, + fileobj=server.connection.fileno(), + fd=server.connection.fileno(), events=0, data=None, ), @@ 
-181,74 +268,22 @@ def has_buffer() -> bool: b'Proxy-Connection: Keep-Alive', CRLF, ]) - self.assert_tunnel_response(mock_server_connection, server) + await self.assert_tunnel_response(server) # Dispatch tunnel established response to client - self.protocol_handler._run_once() - self.assert_data_queued_to_server(server) + await self.protocol_handler._run_once() + await self.assert_data_queued_to_server(server) - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertEqual(server.queue.call_count, 1) server.flush.assert_called_once() - def test_proxy_connection_failed(self) -> None: - self.mock_selector_for_client_read(self.mock_selector) - self._conn.recv.return_value = CRLF.join([ - b'GET http://unknown.domain HTTP/1.1', - b'Host: unknown.domain', - CRLF, - ]) - self.protocol_handler._run_once() - self.assertEqual( - self.protocol_handler.work.buffer[0], - ProxyConnectionFailed.RESPONSE_PKT, - ) - - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_proxy_authentication_failed( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: - self._conn = mock_fromfd.return_value - self.mock_selector_for_client_read(mock_selector) - flags = FlagParser.initialize( - auth_code=base64.b64encode(b'user:pass'), - threaded=True, - ) - flags.plugins = Plugins.load([ - bytes_(PLUGIN_HTTP_PROXY), - bytes_(PLUGIN_WEB_SERVER), - bytes_(PLUGIN_PROXY_AUTH), - ]) - self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), flags=flags, - ) - self.protocol_handler.initialize() - self._conn.recv.return_value = CRLF.join([ - b'GET http://abhinavsingh.com HTTP/1.1', - b'Host: abhinavsingh.com', - CRLF, - ]) - self.protocol_handler._run_once() - self.assertEqual( - self.protocol_handler.work.buffer[0], - ProxyAuthenticationFailed.RESPONSE_PKT, - ) - - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_authenticated_proxy_http_get( - self, mock_server_connection: mock.Mock, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: - self._conn = mock_fromfd.return_value - self.mock_selector_for_client_read(mock_selector) + @pytest.mark.asyncio # type: ignore[misc] + async def test_authenticated_proxy_http_get(self) -> None: + self._conn = self.mock_fromfd.return_value + mock_selector_for_client_read(self) - server = mock_server_connection.return_value + server = self.mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 @@ -268,14 +303,14 @@ def test_authenticated_proxy_http_get( assert self.http_server_port is not None self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.request.state, httpParserStates.INITIALIZED, ) self._conn.recv.return_value = CRLF - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.request.state, httpParserStates.LINE_RCVD, @@ -290,23 +325,15 @@ def test_authenticated_proxy_http_get( b'Proxy-Authorization: Basic dXNlcjpwYXNz', CRLF, ]) - self.assert_data_queued(mock_server_connection, server) - - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_authenticated_proxy_http_tunnel( - self, mock_server_connection: 
mock.Mock, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: - server = mock_server_connection.return_value + await self.assert_data_queued(server) + + @pytest.mark.asyncio # type: ignore[misc] + async def test_authenticated_proxy_http_tunnel(self) -> None: + server = self.mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 - self._conn = mock_fromfd.return_value - self.mock_selector_for_client_read_read_server_write( - mock_selector, server, - ) + self._conn = self.mock_fromfd.return_value + self.mock_selector_for_client_read_and_server_write(server) flags = FlagParser.initialize( auth_code=base64.b64encode(b'user:pass'), @@ -331,22 +358,22 @@ def test_authenticated_proxy_http_tunnel( b'Proxy-Authorization: Basic dXNlcjpwYXNz', CRLF, ]) - self.assert_tunnel_response(mock_server_connection, server) + await self.assert_tunnel_response(server) self.protocol_handler.work.flush() - self.assert_data_queued_to_server(server) + await self.assert_data_queued_to_server(server) - self.protocol_handler._run_once() + await self.protocol_handler._run_once() server.flush.assert_called_once() - def mock_selector_for_client_read_read_server_write( - self, mock_selector: mock.Mock, server: mock.Mock, + def mock_selector_for_client_read_and_server_write( + self, server: mock.Mock, ) -> None: - mock_selector.return_value.select.side_effect = [ + self.mock_selector.return_value.select.side_effect = [ [ ( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -356,8 +383,8 @@ def mock_selector_for_client_read_read_server_write( [ ( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=0, data=None, ), @@ -367,8 +394,8 @@ def mock_selector_for_client_read_read_server_write( [ ( selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, + fileobj=server.connection.fileno(), + fd=server.connection.fileno(), events=0, data=None, ), @@ -377,15 +404,15 @@ def mock_selector_for_client_read_read_server_write( ], ] - def assert_data_queued( - self, mock_server_connection: mock.Mock, server: mock.Mock, + async def assert_data_queued( + self, server: mock.Mock, ) -> None: - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.request.state, httpParserStates.COMPLETE, ) - mock_server_connection.assert_called_once() + self.mock_server_connection.assert_called_once() server.connect.assert_called_once() server.closed = False assert self.http_server_port is not None @@ -401,7 +428,7 @@ def assert_data_queued( self.assertEqual(server.queue.call_args_list[0][0][0].tobytes(), pkt) server.buffer_size.return_value = len(pkt) - def assert_data_queued_to_server(self, server: mock.Mock) -> None: + async def assert_data_queued_to_server(self, server: mock.Mock) -> None: assert self.http_server_port is not None self.assertEqual( self._conn.send.call_args[0][0], @@ -416,21 +443,8 @@ def assert_data_queued_to_server(self, server: mock.Mock) -> None: ]) self._conn.recv.return_value = pkt - self.protocol_handler._run_once() + await self.protocol_handler._run_once() server.queue.assert_called_once_with(pkt) server.buffer_size.return_value = len(pkt) server.flush.assert_not_called() - - def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None: - 
mock_selector.return_value.select.return_value = [ - ( - selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None, - ), - selectors.EVENT_READ, - ), - ] diff --git a/tests/http/test_web_server.py b/tests/http/test_web_server.py index be0098128a..54a1d49060 100644 --- a/tests/http/test_web_server.py +++ b/tests/http/test_web_server.py @@ -8,81 +8,173 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import gzip import os +import gzip +import pytest import tempfile -import unittest import selectors -from unittest import mock + +from typing import Any +from pytest_mock import MockerFixture +# from unittest import mock from proxy.common.plugins import Plugins from proxy.common.flag import FlagParser from proxy.core.connection import TcpClientConnection from proxy.http import HttpProtocolHandler from proxy.http.parser import HttpParser, httpParserStates, httpParserTypes -from proxy.common.utils import build_http_response, build_http_request, bytes_, text_ +from proxy.common.utils import build_http_response, build_http_request, bytes_ from proxy.common.constants import CRLF, PLUGIN_HTTP_PROXY, PLUGIN_PAC_FILE, PLUGIN_WEB_SERVER, PROXY_PY_DIR from proxy.http.server import HttpWebServerPlugin +from ..test_assertions import Assertions + + +PAC_FILE_PATH = os.path.join( + os.path.dirname(PROXY_PY_DIR), + 'helper', + 'proxy.pac', +) -class TestWebServerPlugin(unittest.TestCase): +PAC_FILE_CONTENT = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + +def test_on_client_connection_called_on_teardown(mocker: MockerFixture) -> None: + plugin = mocker.MagicMock() + mock_fromfd = mocker.patch('socket.fromfd') + flags = FlagParser.initialize(threaded=True) + flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]} + _conn = mock_fromfd.return_value + _addr = ('127.0.0.1', 54382) + protocol_handler = HttpProtocolHandler( + TcpClientConnection(_conn, _addr), + flags=flags, + ) + protocol_handler.initialize() + plugin.assert_called() + mock_run_once = mocker.patch.object(protocol_handler, '_run_once') + mock_run_once.return_value = True + protocol_handler.run() + assert _conn.closed + plugin.return_value.on_client_connection_close.assert_called() + + +def mock_selector_for_client_read(self: Any) -> None: + self.mock_selector.return_value.select.return_value = [ + ( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ] + + # @mock.patch('socket.fromfd') + # def test_on_client_connection_called_on_teardown( + # self, mock_fromfd: mock.Mock, + # ) -> None: + # flags = FlagParser.initialize(threaded=True) + # plugin = mock.MagicMock() + # flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]} + # self._conn = mock_fromfd.return_value + # self.protocol_handler = HttpProtocolHandler( + # TcpClientConnection(self._conn, self._addr), + # flags=flags, + # ) + # self.protocol_handler.initialize() + # plugin.assert_called() + # with mock.patch.object(self.protocol_handler, '_run_once') as mock_run_once: + # mock_run_once.return_value = True + # self.protocol_handler.run() + # self.assertTrue(self._conn.closed) + # plugin.return_value.on_client_connection_close.assert_called() + + +class TestWebServerPluginWithPacFilePlugin(Assertions): + + @pytest.fixture( + autouse=True, params=[ + PAC_FILE_PATH, + PAC_FILE_CONTENT, + ], + ) # type: ignore[misc] + def _setUp(self, request: Any, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = mock_fromfd.return_value - self.mock_selector = mock_selector - self.flags = FlagParser.initialize(threaded=True) + self._conn = self.mock_fromfd.return_value + self.pac_file = request.param + if isinstance(self.pac_file, str): + with open(self.pac_file, 'rb') as f: + self.expected_response = f.read() + else: + self.expected_response = PAC_FILE_CONTENT + self.flags = FlagParser.initialize( + pac_file=self.pac_file, threaded=True, + ) self.flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), + bytes_(PLUGIN_PAC_FILE), ]) self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), flags=self.flags, ) self.protocol_handler.initialize() + self._conn.recv.return_value = CRLF.join([ b'GET / HTTP/1.1', CRLF, ]) + mock_selector_for_client_read(self) - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_pac_file_served_from_disk( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - ) -> None: - pac_file = os.path.join( - os.path.dirname(PROXY_PY_DIR), - 'helper', - 'proxy.pac', - ) - self._conn = mock_fromfd.return_value - self.mock_selector_for_client_read(mock_selector) - self.init_and_make_pac_file_request(pac_file) - self.protocol_handler._run_once() - self.assertEqual( - self.protocol_handler.request.state, - httpParserStates.COMPLETE, - ) - with open(pac_file, 'rb') as f: - self._conn.send.called_once_with( - build_http_response( - 200, reason=b'OK', headers={ - b'Content-Type': b'application/x-ns-proxy-autoconfig', - b'Connection': b'close', - }, body=f.read(), - ), - ) -
@mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_pac_file_served_from_buffer( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - ) -> None: - self._conn = mock_fromfd.return_value - self.mock_selector_for_client_read(mock_selector) - pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' - self.init_and_make_pac_file_request(text_(pac_file_content)) - self.protocol_handler._run_once() + @pytest.mark.asyncio # type: ignore[misc] + async def test_pac_file_served_from_disk(self) -> None: + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.request.state, httpParserStates.COMPLETE, @@ -92,28 +184,35 @@ def test_pac_file_served_from_buffer( 200, reason=b'OK', headers={ b'Content-Type': b'application/x-ns-proxy-autoconfig', b'Connection': b'close', - }, body=pac_file_content, + }, body=self.expected_response, ), ) - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_default_web_server_returns_404( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - ) -> None: - self._conn = mock_fromfd.return_value - mock_selector.return_value.select.return_value = [ - ( - selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, - events=selectors.EVENT_READ, - data=None, - ), - selectors.EVENT_READ, - ), - ] - flags = FlagParser.initialize(threaded=True) + +class TestStaticWebServerPlugin(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self._conn = self.mock_fromfd.return_value + # Setup a static directory + self.static_server_dir = os.path.join(tempfile.gettempdir(), 'static') + self.index_file_path = os.path.join( + self.static_server_dir, 'index.html', + ) + self.html_file_content = b'''

<html>
<head></head>
<body><h1>Proxy.py Testing</h1></body>
</html>
''' + os.makedirs(self.static_server_dir, exist_ok=True) + with open(self.index_file_path, 'wb') as f: + f.write(self.html_file_content) + # + flags = FlagParser.initialize( + enable_static_server=True, + static_server_dir=self.static_server_dir, + threaded=True, + ) flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), @@ -123,43 +222,17 @@ def test_default_web_server_returns_404( flags=flags, ) self.protocol_handler.initialize() - self._conn.recv.return_value = CRLF.join([ - b'GET /hello HTTP/1.1', - CRLF, - ]) - self.protocol_handler._run_once() - self.assertEqual( - self.protocol_handler.request.state, - httpParserStates.COMPLETE, - ) - self.assertEqual( - self.protocol_handler.work.buffer[0], - HttpWebServerPlugin.DEFAULT_404_RESPONSE, - ) - - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_static_web_server_serves( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - ) -> None: - # Setup a static directory - static_server_dir = os.path.join(tempfile.gettempdir(), 'static') - index_file_path = os.path.join(static_server_dir, 'index.html') - html_file_content = b'''

<html>
<head></head>
<body><h1>Proxy.py Testing</h1></body>
</html>
''' - os.makedirs(static_server_dir, exist_ok=True) - with open(index_file_path, 'wb') as f: - f.write(html_file_content) - self._conn = mock_fromfd.return_value + @pytest.mark.asyncio # type: ignore[misc] + async def test_static_web_server_serves(self) -> None: self._conn.recv.return_value = build_http_request( b'GET', b'/index.html', ) - - mock_selector.return_value.select.side_effect = [ + self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -167,37 +240,20 @@ def test_static_web_server_serves( )], [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_WRITE, data=None, ), selectors.EVENT_WRITE, )], ] + await self.protocol_handler._run_once() + await self.protocol_handler._run_once() - flags = FlagParser.initialize( - enable_static_server=True, - static_server_dir=static_server_dir, - threaded=True, - ) - flags.plugins = Plugins.load([ - bytes_(PLUGIN_HTTP_PROXY), - bytes_(PLUGIN_WEB_SERVER), - ]) - - self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), - flags=flags, - ) - self.protocol_handler.initialize() - - self.protocol_handler._run_once() - self.protocol_handler._run_once() - - self.assertEqual(mock_selector.return_value.select.call_count, 2) + self.assertEqual(self.mock_selector.return_value.select.call_count, 2) self.assertEqual(self._conn.send.call_count, 1) - encoded_html_file_content = gzip.compress(html_file_content) + encoded_html_file_content = gzip.compress(self.html_file_content) # parse response and verify response = HttpParser(httpParserTypes.RESPONSE_PARSER) @@ -212,25 +268,21 @@ def test_static_web_server_serves( bytes_(len(encoded_html_file_content)), ) assert response.body - self.assertEqual(gzip.decompress(response.body), html_file_content) + self.assertEqual( + gzip.decompress(response.body), + self.html_file_content, + ) - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def test_static_web_server_serves_404( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: - self._conn = mock_fromfd.return_value + @pytest.mark.asyncio # type: ignore[misc] + async def test_static_web_server_serves_404(self) -> None: self._conn.recv.return_value = build_http_request( b'GET', b'/not-found.html', ) - - mock_selector.return_value.select.side_effect = [ + self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -238,8 +290,8 @@ def test_static_web_server_serves_404( )], [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_WRITE, data=None, ), @@ -247,74 +299,71 @@ def test_static_web_server_serves_404( )], ] - flags = FlagParser.initialize(enable_static_server=True, threaded=True) - flags.plugins = Plugins.load([ - bytes_(PLUGIN_HTTP_PROXY), - bytes_(PLUGIN_WEB_SERVER), - ]) - - self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), - flags=flags, - ) - self.protocol_handler.initialize() - - self.protocol_handler._run_once() - self.protocol_handler._run_once() + await self.protocol_handler._run_once() + await 
self.protocol_handler._run_once() - self.assertEqual(mock_selector.return_value.select.call_count, 2) + self.assertEqual(self.mock_selector.return_value.select.call_count, 2) self.assertEqual(self._conn.send.call_count, 1) self.assertEqual( self._conn.send.call_args[0][0], HttpWebServerPlugin.DEFAULT_404_RESPONSE, ) - @mock.patch('socket.fromfd') - def test_on_client_connection_called_on_teardown( - self, mock_fromfd: mock.Mock, - ) -> None: - flags = FlagParser.initialize(threaded=True) - plugin = mock.MagicMock() - flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]} - self._conn = mock_fromfd.return_value - self.protocol_handler = HttpProtocolHandler( - TcpClientConnection(self._conn, self._addr), - flags=flags, - ) - self.protocol_handler.initialize() - plugin.assert_called() - with mock.patch.object(self.protocol_handler, '_run_once') as mock_run_once: - mock_run_once.return_value = True - self.protocol_handler.run() - self.assertTrue(self._conn.closed) - plugin.return_value.on_client_connection_close.assert_called() - def init_and_make_pac_file_request(self, pac_file: str) -> None: - flags = FlagParser.initialize(pac_file=pac_file, threaded=True) - flags.plugins = Plugins.load([ +class TestWebServerPlugin(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self._conn = self.mock_fromfd.return_value + self.flags = FlagParser.initialize(threaded=True) + self.flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER), - bytes_(PLUGIN_PAC_FILE), ]) self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), - flags=flags, + flags=self.flags, ) self.protocol_handler.initialize() - self._conn.recv.return_value = CRLF.join([ - b'GET / HTTP/1.1', - CRLF, - ]) - def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None: - mock_selector.return_value.select.return_value = [ + @pytest.mark.asyncio # type: ignore[misc] + async def test_default_web_server_returns_404(self) -> None: + self._conn = self.mock_fromfd.return_value + self.mock_selector.return_value.select.return_value = [ ( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, ), ] + flags = FlagParser.initialize(threaded=True) + flags.plugins = Plugins.load([ + bytes_(PLUGIN_HTTP_PROXY), + bytes_(PLUGIN_WEB_SERVER), + ]) + self.protocol_handler = HttpProtocolHandler( + TcpClientConnection(self._conn, self._addr), + flags=flags, + ) + self.protocol_handler.initialize() + self._conn.recv.return_value = CRLF.join([ + b'GET /hello HTTP/1.1', + CRLF, + ]) + await self.protocol_handler._run_once() + self.assertEqual( + self.protocol_handler.request.state, + httpParserStates.COMPLETE, + ) + self.assertEqual( + self.protocol_handler.work.buffer[0], + HttpWebServerPlugin.DEFAULT_404_RESPONSE, + ) diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index c1195b9871..9df24fa8d1 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1,7 +1,18 @@ -"""Test the simplest proxy use scenario for smoke.""" +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server 
focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + Test the simplest proxy use scenario for smoke. +""" from pathlib import Path from subprocess import check_output, Popen -from typing import Generator +from typing import Generator, Any import pytest @@ -9,32 +20,33 @@ from proxy.common._compat import IS_WINDOWS # noqa: WPS436 -PROXY_PY_PORT = get_available_port() - - # FIXME: Ignore is necessary for as long as pytest hasn't figured out # FIXME: typing for their fixtures. # Refs: # * https://github.com/pytest-dev/pytest/issues/7469#issuecomment-918345196 # * https://github.com/pytest-dev/pytest/issues/3342 @pytest.fixture # type: ignore[misc] -def _proxy_py_instance() -> Generator[None, None, None]: +def proxy_py_subprocess(request: Any) -> Generator[int, None, None]: """Instantiate proxy.py in a subprocess for testing. + NOTE: Doesn't wait for the proxy to start up. + Check for a live instance in your tests. + After the testing is over, tear it down. """ + port = get_available_port() proxy_cmd = ( 'python', '-m', 'proxy', '--hostname', '127.0.0.1', - '--port', str(PROXY_PY_PORT), + '--port', str(port), '--enable-web-server', - ) + ) + tuple(request.param.split()) proxy_proc = Popen(proxy_cmd) try: - yield + yield port finally: proxy_proc.terminate() - proxy_proc.wait(1) + proxy_proc.wait() # FIXME: Ignore is necessary for as long as pytest hasn't figured out @@ -43,15 +55,22 @@ def _proxy_py_instance() -> Generator[None, None, None]: # * https://github.com/pytest-dev/pytest/issues/7469#issuecomment-918345196 # * https://github.com/pytest-dev/pytest/issues/3342 @pytest.mark.smoke # type: ignore[misc] -@pytest.mark.usefixtures('_proxy_py_instance') # type: ignore[misc] +@pytest.mark.parametrize( + 'proxy_py_subprocess', + ( + ('--threadless'), + ('--threadless --local-executor'), + ('--threaded'), + ), + indirect=True, +) # type: ignore[misc] @pytest.mark.xfail( IS_WINDOWS, reason='OSError: [WinError 193] %1 is not a valid Win32 application', raises=OSError, ) # type: ignore[misc] -def test_curl() -> None: +def test_curl(proxy_py_subprocess: int) -> None: """An acceptance test with using ``curl`` through proxy.py.""" this_test_module = Path(__file__) shell_script_test = this_test_module.with_suffix('.sh') - - check_output([str(shell_script_test), str(PROXY_PY_PORT)]) + check_output([str(shell_script_test), str(proxy_py_subprocess)]) diff --git a/tests/plugin/test_http_proxy_plugins.py b/tests/plugin/test_http_proxy_plugins.py index 864f48597a..7102c3ef7a 100644 --- a/tests/plugin/test_http_proxy_plugins.py +++ b/tests/plugin/test_http_proxy_plugins.py @@ -8,14 +8,15 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" -import unittest -import selectors import json +import pytest +import selectors -from urllib import parse as urlparse -from unittest import mock -from typing import cast from pathlib import Path +from unittest import mock +from typing import cast, Any +from urllib import parse as urlparse +from pytest_mock import MockerFixture from proxy.common.flag import FlagParser from proxy.core.connection import TcpClientConnection @@ -24,21 +25,23 @@ from proxy.http.proxy import HttpProxyPlugin from proxy.common.utils import build_http_request, bytes_, build_http_response from proxy.common.constants import PROXY_AGENT_HEADER_VALUE, DEFAULT_HTTP_PORT - from proxy.plugin import ProposedRestApiPlugin, RedirectToCustomServerPlugin from .utils import get_plugin_by_test_name +from ..test_assertions import Assertions -class TestHttpProxyPluginExamples(unittest.TestCase): - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def setUp( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - ) -> None: +class TestHttpProxyPluginExamples(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, request: Any, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.mock_server_conn = mocker.patch( + 'proxy.http.proxy.server.TcpServerConnection', + ) + self.fileno = 10 self._addr = ('127.0.0.1', 54382) adblock_json_path = Path( @@ -53,26 +56,28 @@ def setUp( ) self.plugin = mock.MagicMock() - self.mock_fromfd = mock_fromfd - self.mock_selector = mock_selector - - plugin = get_plugin_by_test_name(self._testMethodName) + plugin = get_plugin_by_test_name(request.param) self.flags.plugins = { b'HttpProtocolHandlerPlugin': [HttpProxyPlugin], b'HttpProxyBasePlugin': [plugin], } - self._conn = mock_fromfd.return_value + self._conn = self.mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), flags=self.flags, ) self.protocol_handler.initialize() - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_modify_post_data_plugin( - self, mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_modify_post_data_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_modify_post_data_plugin(self) -> None: original = b'{"key": "value"}' modified = b'{"key": "modified"}' @@ -88,8 +93,8 @@ def test_modify_post_data_plugin( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -97,9 +102,11 @@ def test_modify_post_data_plugin( )], ] - self.protocol_handler._run_once() - mock_server_conn.assert_called_with('httpbin.org', DEFAULT_HTTP_PORT) - mock_server_conn.return_value.queue.assert_called_with( + await self.protocol_handler._run_once() + self.mock_server_conn.assert_called_with( + 'httpbin.org', DEFAULT_HTTP_PORT, + ) + self.mock_server_conn.return_value.queue.assert_called_with( build_http_request( b'POST', b'/post', headers={ @@ -112,10 +119,15 @@ def test_modify_post_data_plugin( ), ) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_proposed_rest_api_plugin( - self, mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + 
('test_proposed_rest_api_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_proposed_rest_api_plugin(self) -> None: path = b'/v1/users/' self._conn.recv.return_value = build_http_request( b'GET', b'http://%s%s' % ( @@ -128,17 +140,17 @@ def test_proposed_rest_api_plugin( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - self.protocol_handler._run_once() + await self.protocol_handler._run_once() - mock_server_conn.assert_not_called() + self.mock_server_conn.assert_not_called() self.assertEqual( self.protocol_handler.work.buffer[0].tobytes(), build_http_response( @@ -152,10 +164,15 @@ def test_proposed_rest_api_plugin( ), ) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_redirect_to_custom_server_plugin( - self, mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_redirect_to_custom_server_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_redirect_to_custom_server_plugin(self) -> None: request = build_http_request( b'GET', b'http://example.org/get', headers={ @@ -166,21 +183,21 @@ def test_redirect_to_custom_server_plugin( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - self.protocol_handler._run_once() + await self.protocol_handler._run_once() upstream = urlparse.urlsplit( RedirectToCustomServerPlugin.UPSTREAM_SERVER, ) - mock_server_conn.assert_called_with('localhost', 8899) - mock_server_conn.return_value.queue.assert_called_with( + self.mock_server_conn.assert_called_with('localhost', 8899) + self.mock_server_conn.return_value.queue.assert_called_with( build_http_request( b'GET', upstream.path, headers={ @@ -190,10 +207,15 @@ def test_redirect_to_custom_server_plugin( ), ) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_filter_by_upstream_host_plugin( - self, mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_filter_by_upstream_host_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_filter_by_upstream_host_plugin(self) -> None: request = build_http_request( b'GET', b'http://facebook.com/', headers={ @@ -204,17 +226,17 @@ def test_filter_by_upstream_host_plugin( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - self.protocol_handler._run_once() + await self.protocol_handler._run_once() - mock_server_conn.assert_not_called() + self.mock_server_conn.assert_not_called() self.assertEqual( self.protocol_handler.work.buffer[0].tobytes(), build_http_response( @@ -226,10 +248,15 @@ def test_filter_by_upstream_host_plugin( ), ) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_man_in_the_middle_plugin( - self, mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_man_in_the_middle_plugin'), + ), 
+ indirect=True, + ) # type: ignore[misc] + async def test_man_in_the_middle_plugin(self) -> None: request = build_http_request( b'GET', b'http://super.secure/', headers={ @@ -238,7 +265,7 @@ def test_man_in_the_middle_plugin( ) self._conn.recv.return_value = request - server = mock_server_conn.return_value + server = self.mock_server_conn.return_value server.connect.return_value = True def has_buffer() -> bool: @@ -253,8 +280,8 @@ def closed() -> bool: self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -262,8 +289,8 @@ def closed() -> bool: )], [( selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, + fileobj=server.connection.fileno(), + fd=server.connection.fileno(), events=selectors.EVENT_WRITE, data=None, ), @@ -271,8 +298,8 @@ def closed() -> bool: )], [( selectors.SelectorKey( - fileobj=server.connection, - fd=server.connection.fileno, + fileobj=server.connection.fileno(), + fd=server.connection.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -281,8 +308,10 @@ def closed() -> bool: ] # Client read - self.protocol_handler._run_once() - mock_server_conn.assert_called_with('super.secure', DEFAULT_HTTP_PORT) + await self.protocol_handler._run_once() + self.mock_server_conn.assert_called_with( + 'super.secure', DEFAULT_HTTP_PORT, + ) server.connect.assert_called_once() queued_request = \ build_http_request( @@ -295,7 +324,7 @@ def closed() -> bool: server.queue.assert_called_once_with(queued_request) # Server write - self.protocol_handler._run_once() + await self.protocol_handler._run_once() server.flush.assert_called_once() # Server read @@ -304,7 +333,7 @@ def closed() -> bool: httpStatusCodes.OK, reason=b'OK', body=b'Original Response From Upstream', ) - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.work.buffer[0].tobytes(), build_http_response( @@ -313,10 +342,15 @@ def closed() -> bool: ), ) - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - def test_filter_by_url_regex_plugin( - self, mock_server_conn: mock.Mock, - ) -> None: + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_filter_by_url_regex_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_filter_by_url_regex_plugin(self) -> None: request = build_http_request( b'GET', b'http://www.facebook.com/tr/', headers={ @@ -327,15 +361,15 @@ def test_filter_by_url_regex_plugin( self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), selectors.EVENT_READ, )], ] - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.work.buffer[0].tobytes(), diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index d076f1558a..232b0dd954 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -8,12 +8,12 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" -import unittest +import ssl import socket +import pytest import selectors -import ssl -from unittest import mock +from pytest_mock import MockerFixture from typing import Any, cast from proxy.common.flag import FlagParser @@ -22,39 +22,29 @@ from proxy.http import httpMethods, httpStatusCodes, HttpProtocolHandler from proxy.http.proxy import HttpProxyPlugin +from proxy.http.parser import HttpParser from .utils import get_plugin_by_test_name +from ..test_assertions import Assertions -class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): - - @mock.patch('ssl.wrap_socket') - @mock.patch('ssl.create_default_context') - @mock.patch('proxy.http.proxy.server.TcpServerConnection') - @mock.patch('proxy.http.proxy.server.gen_public_key') - @mock.patch('proxy.http.proxy.server.gen_csr') - @mock.patch('proxy.http.proxy.server.sign_csr') - @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') - def setUp( - self, - mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_sign_csr: mock.Mock, - mock_gen_csr: mock.Mock, - mock_gen_public_key: mock.Mock, - mock_server_conn: mock.Mock, - mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock, - ) -> None: - self.mock_fromfd = mock_fromfd - self.mock_selector = mock_selector - self.mock_sign_csr = mock_sign_csr - self.mock_gen_csr = mock_gen_csr - self.mock_gen_public_key = mock_gen_public_key - self.mock_server_conn = mock_server_conn - self.mock_ssl_context = mock_ssl_context - self.mock_ssl_wrap = mock_ssl_wrap + +class TestHttpProxyPluginExamplesWithTlsInterception(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, request: Any, mocker: MockerFixture) -> None: + self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr') + self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr') + self.mock_gen_public_key = mocker.patch( + 'proxy.http.proxy.server.gen_public_key', + ) + self.mock_server_conn = mocker.patch( + 'proxy.http.proxy.server.TcpServerConnection', + ) + self.mock_ssl_context = mocker.patch('ssl.create_default_context') + self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') self.mock_sign_csr.return_value = True self.mock_gen_csr.return_value = True @@ -68,16 +58,16 @@ def setUp( ca_signing_key_file='ca-signing-key.pem', threaded=True, ) - self.plugin = mock.MagicMock() + self.plugin = mocker.MagicMock() - plugin = get_plugin_by_test_name(self._testMethodName) + plugin = get_plugin_by_test_name(request.param) self.flags.plugins = { b'HttpProtocolHandlerPlugin': [HttpProxyPlugin], b'HttpProxyBasePlugin': [plugin], } - self._conn = mock.MagicMock(spec=socket.socket) - mock_fromfd.return_value = self._conn + self._conn = mocker.MagicMock(spec=socket.socket) + self.mock_fromfd.return_value = self._conn self.protocol_handler = HttpProtocolHandler( TcpClientConnection(self._conn, self._addr), flags=self.flags, ) @@ -85,9 +75,9 @@ def setUp( self.server = self.mock_server_conn.return_value - self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket) + self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket) self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection - self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket) + self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket) self.mock_ssl_wrap.return_value = self.client_ssl_connection def has_buffer() -> bool: @@ -106,18 
+96,18 @@ def mock_connection() -> Any: lambda x, y: TcpServerConnection.wrap(self.server, x, y) self.server.has_buffer.side_effect = has_buffer - type(self.server).closed = mock.PropertyMock(side_effect=closed) + type(self.server).closed = mocker.PropertyMock(side_effect=closed) type( self.server, - ).connection = mock.PropertyMock( + ).connection = mocker.PropertyMock( side_effect=mock_connection, ) self.mock_selector.return_value.select.side_effect = [ [( selectors.SelectorKey( - fileobj=self._conn, - fd=self._conn.fileno, + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -125,8 +115,8 @@ def mock_connection() -> Any: )], [( selectors.SelectorKey( - fileobj=self.client_ssl_connection, - fd=self.client_ssl_connection.fileno, + fileobj=self.client_ssl_connection.fileno(), + fd=self.client_ssl_connection.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -134,8 +124,8 @@ def mock_connection() -> Any: )], [( selectors.SelectorKey( - fileobj=self.server_ssl_connection, - fd=self.server_ssl_connection.fileno, + fileobj=self.server_ssl_connection.fileno(), + fd=self.server_ssl_connection.fileno(), events=selectors.EVENT_WRITE, data=None, ), @@ -143,8 +133,8 @@ def mock_connection() -> Any: )], [( selectors.SelectorKey( - fileobj=self.server_ssl_connection, - fd=self.server_ssl_connection.fileno, + fileobj=self.server_ssl_connection.fileno(), + fd=self.server_ssl_connection.fileno(), events=selectors.EVENT_READ, data=None, ), @@ -160,7 +150,17 @@ def send(raw: bytes) -> int: self._conn.recv.return_value = build_http_request( httpMethods.CONNECT, b'uni.corn:443', ) - self.protocol_handler._run_once() + + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + '_setUp', + ( + ('test_modify_post_data_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_modify_post_data_plugin(self) -> None: + await self.protocol_handler._run_once() self.assertEqual(self.mock_sign_csr.call_count, 1) self.assertEqual(self.mock_gen_csr.call_count, 1) @@ -178,32 +178,61 @@ def send(raw: bytes) -> int: ) self.assertFalse(self.protocol_handler.work.has_buffer()) - def test_modify_post_data_plugin(self) -> None: + # original = b'{"key": "value"}' modified = b'{"key": "modified"}' self.client_ssl_connection.recv.return_value = build_http_request( b'POST', b'/', headers={ b'Host': b'uni.corn', - b'Content-Type': b'application/x-www-form-urlencoded', b'Content-Length': bytes_(len(original)), + b'Content-Type': b'application/x-www-form-urlencoded', }, body=original, ) - self.protocol_handler._run_once() - self.server.queue.assert_called_with( - build_http_request( - b'POST', b'/', - headers={ - b'Host': b'uni.corn', - b'Content-Length': bytes_(len(modified)), - b'Content-Type': b'application/json', - }, - body=modified, - ), + await self.protocol_handler._run_once() + self.server.queue.assert_called_once() + # pkt = build_http_request( + # b'POST', b'/', + # headers={ + # b'Host': b'uni.corn', + # b'Content-Length': bytes_(len(modified)), + # b'Content-Type': b'application/json', + # }, + # body=modified, + # ) + response = HttpParser.response( + self.server.queue.call_args_list[0][0][0].tobytes(), ) + self.assertEqual(response.body, modified) + + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + '_setUp', + ( + ('test_man_in_the_middle_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_man_in_the_middle_plugin(self) -> None: + await self.protocol_handler._run_once() + + 
self.assertEqual(self.mock_sign_csr.call_count, 1) + self.assertEqual(self.mock_gen_csr.call_count, 1) + self.assertEqual(self.mock_gen_public_key.call_count, 1) - def test_man_in_the_middle_plugin(self) -> None: + self.mock_server_conn.assert_called_once_with('uni.corn', 443) + self.server.connect.assert_called() + self.assertEqual( + self.protocol_handler.work.connection, + self.client_ssl_connection, + ) + self.assertEqual(self.server.connection, self.server_ssl_connection) + self._conn.send.assert_called_with( + HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) + self.assertFalse(self.protocol_handler.work.has_buffer()) + # request = build_http_request( b'GET', b'/', headers={ @@ -213,20 +242,21 @@ def test_man_in_the_middle_plugin(self) -> None: self.client_ssl_connection.recv.return_value = request # Client read - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.server.queue.assert_called_once_with(request) # Server write - self.protocol_handler._run_once() + await self.protocol_handler._run_once() self.server.flush.assert_called_once() # Server read - self.server.recv.return_value = \ + self.server.recv.return_value = memoryview( build_http_response( httpStatusCodes.OK, reason=b'OK', body=b'Original Response From Upstream', - ) - self.protocol_handler._run_once() + ), + ) + await self.protocol_handler._run_once() self.assertEqual( self.protocol_handler.work.buffer[0].tobytes(), build_http_response( diff --git a/tests/test_assertions.py b/tests/test_assertions.py new file mode 100644 index 0000000000..85d6ca6f3f --- /dev/null +++ b/tests/test_assertions.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +from typing import Any + + +class Assertions: + + def assertTrue(self, obj: Any) -> None: + assert obj + + def assertFalse(self, obj: Any) -> None: + assert not obj + + def assertEqual(self, obj1: Any, obj2: Any) -> None: + assert obj1 == obj2 + + def assertNotEqual(self, obj1: Any, obj2: Any) -> None: + assert obj1 != obj2 diff --git a/tests/test_main.py b/tests/test_main.py index 1f83153763..28874e822f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -21,7 +21,7 @@ from proxy.proxy import main, entry_point from proxy.common.utils import bytes_ -from proxy.common.constants import DEFAULT_ENABLE_DASHBOARD, DEFAULT_LOG_LEVEL, DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT +from proxy.common.constants import DEFAULT_ENABLE_DASHBOARD, DEFAULT_LOCAL_EXECUTOR, DEFAULT_LOG_LEVEL, DEFAULT_LOG_FILE from proxy.common.constants import DEFAULT_TIMEOUT, DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HTTP_PROXY from proxy.common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_EVENTS, DEFAULT_ENABLE_DEVTOOLS from proxy.common.constants import DEFAULT_ENABLE_WEB_SERVER, DEFAULT_THREADLESS, DEFAULT_CERT_FILE, DEFAULT_KEY_FILE @@ -30,7 +30,7 @@ from proxy.common.constants import DEFAULT_NUM_WORKERS, DEFAULT_OPEN_FILE_LIMIT, DEFAULT_IPV6_HOSTNAME from proxy.common.constants import DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_WORK_KLASS from proxy.common.constants import PLUGIN_INSPECT_TRAFFIC, PLUGIN_DASHBOARD, PLUGIN_DEVTOOLS_PROTOCOL, PLUGIN_WEB_SERVER -from proxy.common.constants import PLUGIN_HTTP_PROXY, DEFAULT_NUM_ACCEPTORS, PLUGIN_PROXY_AUTH +from proxy.common.constants import PLUGIN_HTTP_PROXY, DEFAULT_NUM_ACCEPTORS, PLUGIN_PROXY_AUTH, DEFAULT_LOG_FORMAT class TestMain(unittest.TestCase): @@ -71,6 +71,7 @@ def mock_default_args(mock_args: mock.Mock) -> None: mock_args.enable_events = DEFAULT_ENABLE_EVENTS mock_args.enable_dashboard = DEFAULT_ENABLE_DASHBOARD mock_args.work_klass = DEFAULT_WORK_KLASS + mock_args.local_executor = DEFAULT_LOCAL_EXECUTOR @mock.patch('os.remove') @mock.patch('os.path.exists') @@ -95,6 +96,7 @@ def test_entry_point( ) -> None: pid_file = os.path.join(tempfile.gettempdir(), 'pid') mock_sleep.side_effect = KeyboardInterrupt() + mock_initialize.return_value.local_executor = False mock_initialize.return_value.enable_events = False mock_initialize.return_value.pid_file = pid_file entry_point() @@ -142,6 +144,7 @@ def test_main_with_no_flags( mock_sleep: mock.Mock, ) -> None: mock_sleep.side_effect = KeyboardInterrupt() + mock_initialize.return_value.local_executor = False mock_initialize.return_value.enable_events = False main() mock_event_manager.assert_not_called() @@ -181,6 +184,7 @@ def test_enable_events( mock_sleep: mock.Mock, ) -> None: mock_sleep.side_effect = KeyboardInterrupt() + mock_initialize.return_value.local_executor = False mock_initialize.return_value.enable_events = True main() mock_event_manager.assert_called_once() diff --git a/tox.ini b/tox.ini index 5e315c8241..daa25f9bb3 100644 --- a/tox.ini +++ b/tox.ini @@ -228,6 +228,7 @@ deps = pre-commit pylint >= 2.5.3 pylint-pytest < 1.1.0 + pytest-mock >= 3.6.1 -r docs/requirements.in -r requirements-tunnel.txt isolated_build = true