Skip to content

Commit

Permalink
Updated CoroutineClass pattern, and removed callback.
Browse files Browse the repository at this point in the history
  • Loading branch information
synchronizing committed Jan 5, 2022
1 parent 03b0df5 commit 276cd10
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions mitm/mitm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import asyncio
import logging
import ssl
from typing import List, Callable
from typing import List

import toolbox
from toolbox.asyncio.pattern import CoroutineClass

from . import __data__, crypto, middleware, protocol
from .core import Connection, Flow, Host
Expand All @@ -16,7 +16,7 @@
logging.getLogger("asyncio").setLevel(logging.CRITICAL)


class MITM(toolbox.ClassTask):
class MITM(CoroutineClass):
"""
Man-in-the-middle server.
"""
Expand All @@ -30,7 +30,7 @@ def __init__(
buffer_size: int = 8192,
timeout: int = 5,
ssl_context: ssl.SSLContext = crypto.mitm_ssl_default_context(),
start: bool = False,
run: bool = False,
):
"""
Initializes the MITM class.
Expand All @@ -43,7 +43,7 @@ def __init__(
buffer_size: Buffer size to use. Defaults to `8192`.
timeout: Timeout to use. Defaults to `5`.
ssl_context: SSL context to use. Defaults to `crypto.mitm_ssl_default_context()`.
start: Whether to start the server immediately. Defaults to `False`.
run: Whether to start the server immediately. Defaults to `False`.
Example:
Expand All @@ -52,7 +52,7 @@ def __init__(
from mitm import MITM
mitm = MITM()
mitm.start()
mitm.run()
"""
self.host = host
self.port = port
Expand All @@ -61,14 +61,9 @@ def __init__(
self.buffer_size = buffer_size
self.timeout = timeout
self.ssl_context = ssl_context
super().__init__(run=run)

super().__init__(
func=lambda: self._run(callback=lambda: self._loop.stop()),
run_forever=True,
start=start,
)

async def _run(self, callback: Callable):
async def entry(self):
"""
Runs the MITM server.
"""
Expand All @@ -85,7 +80,7 @@ async def _run(self, callback: Callable):
port=self.port,
)
except OSError as e:
callback()
self._loop.stop()
raise e

for mw in self.middlewares:
Expand Down Expand Up @@ -120,15 +115,16 @@ async def _relay(connection: Connection, event: asyncio.Event, flow: Flow):
writer = connection.client.writer

while not event.is_set() and not reader.at_eof():
data = None
try:
data = await asyncio.wait_for(
reader.read(self.buffer_size),
self.timeout,
)
except asyncio.exceptions.TimeoutError:
continue
pass

if data == b"":
if not data:
event.set()
break
else:
Expand Down

0 comments on commit 276cd10

Please sign in to comment.