From 276cd1078e0ae59631c10b720c6fb6e2934d43dd Mon Sep 17 00:00:00 2001 From: synchronizing <2829082+synchronizing@users.noreply.github.com> Date: Wed, 5 Jan 2022 00:27:50 -0500 Subject: [PATCH] Updated CoroutineClass pattern, and removed callback. --- mitm/mitm.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/mitm/mitm.py b/mitm/mitm.py index f8ddf45..8de2c5d 100644 --- a/mitm/mitm.py +++ b/mitm/mitm.py @@ -5,9 +5,9 @@ import asyncio import logging import ssl -from typing import List, Callable +from typing import List -import toolbox +from toolbox.asyncio.pattern import CoroutineClass from . import __data__, crypto, middleware, protocol from .core import Connection, Flow, Host @@ -16,7 +16,7 @@ logging.getLogger("asyncio").setLevel(logging.CRITICAL) -class MITM(toolbox.ClassTask): +class MITM(CoroutineClass): """ Man-in-the-middle server. """ @@ -30,7 +30,7 @@ def __init__( buffer_size: int = 8192, timeout: int = 5, ssl_context: ssl.SSLContext = crypto.mitm_ssl_default_context(), - start: bool = False, + run: bool = False, ): """ Initializes the MITM class. @@ -43,7 +43,7 @@ def __init__( buffer_size: Buffer size to use. Defaults to `8192`. timeout: Timeout to use. Defaults to `5`. ssl_context: SSL context to use. Defaults to `crypto.mitm_ssl_default_context()`. - start: Whether to start the server immediately. Defaults to `False`. + run: Whether to start the server immediately. Defaults to `False`. Example: @@ -52,7 +52,7 @@ def __init__( from mitm import MITM mitm = MITM() - mitm.start() + mitm.run() """ self.host = host self.port = port @@ -61,14 +61,9 @@ def __init__( self.buffer_size = buffer_size self.timeout = timeout self.ssl_context = ssl_context + super().__init__(run=run) - super().__init__( - func=lambda: self._run(callback=lambda: self._loop.stop()), - run_forever=True, - start=start, - ) - - async def _run(self, callback: Callable): + async def entry(self): """ Runs the MITM server. """ @@ -85,7 +80,7 @@ async def _run(self, callback: Callable): port=self.port, ) except OSError as e: - callback() + self._loop.stop() raise e for mw in self.middlewares: @@ -120,15 +115,16 @@ async def _relay(connection: Connection, event: asyncio.Event, flow: Flow): writer = connection.client.writer while not event.is_set() and not reader.at_eof(): + data = None try: data = await asyncio.wait_for( reader.read(self.buffer_size), self.timeout, ) except asyncio.exceptions.TimeoutError: - continue + pass - if data == b"": + if not data: event.set() break else: