diff --git a/madbg/client.py b/madbg/client.py index cb8006a..736504f 100644 --- a/madbg/client.py +++ b/madbg/client.py @@ -72,5 +72,5 @@ def connect_to_debugger(ip=DEFAULT_IP, port=DEFAULT_PORT, timeout=DEFAULT_CONNEC send_message(socket, term_data) with prepare_terminal(): socket_fd = socket.fileno() - Piping({in_fd: socket_fd, socket_fd: out_fd}).run() + Piping({in_fd: {socket_fd}, socket_fd: {out_fd}}).run() tcdrain(out_fd) diff --git a/madbg/communication.py b/madbg/communication.py index 79003c5..27973b3 100644 --- a/madbg/communication.py +++ b/madbg/communication.py @@ -6,6 +6,7 @@ from functools import partial from asyncio import new_event_loop from io import BytesIO +from typing import Dict, Set from .utils import opposite_dict @@ -30,12 +31,13 @@ def blocking_read(fd, n): class Piping: - def __init__(self, pipe_dict): + def __init__(self, pipe_dict: Dict[int, Set[int]]): self.buffers = defaultdict(bytes) self.loop = new_event_loop() - for src_fd, dest_fd in pipe_dict.items(): - self.loop.add_reader(src_fd, partial(self._read, src_fd, dest_fd)) - self.loop.add_writer(dest_fd, partial(self._write, dest_fd)) + for src_fd, dest_fds in pipe_dict.items(): + self.loop.add_reader(src_fd, partial(self._read, src_fd, dest_fds)) + for dest_fd in dest_fds: + self.loop.add_writer(dest_fd, partial(self._write, dest_fd)) self.readers_to_writers = dict(pipe_dict) self.writers_to_readers = opposite_dict(pipe_dict) @@ -47,19 +49,21 @@ def _remove_writer(self, writer_fd): def _remove_reader(self, reader_fd): # remove all writers that im the last to write to, remove all that write to me, if nothing left stop loop self.loop.remove_reader(reader_fd) - writer_fd = self.readers_to_writers.pop(reader_fd) - writer_readers = self.writers_to_readers[writer_fd] - writer_readers.remove(reader_fd) - if not writer_fd: - self._remove_writer(writer_fd) - - def _read(self, src_fd, dest_fd): + writer_fds = self.readers_to_writers.pop(reader_fd) + for writer_fd in writer_fds: + writer_readers = self.writers_to_readers[writer_fd] + writer_readers.remove(reader_fd) + if not writer_readers: + self._remove_writer(writer_fd) + + def _read(self, src_fd, dest_fds): try: data = os.read(src_fd, 1024) except OSError: data = '' if data: - self.buffers[dest_fd] += data + for dest_fd in dest_fds: + self.buffers[dest_fd] += data else: self._remove_reader(src_fd) if src_fd in self.writers_to_readers: diff --git a/madbg/debugger.py b/madbg/debugger.py index eba931f..c04ed39 100644 --- a/madbg/debugger.py +++ b/madbg/debugger.py @@ -1,3 +1,4 @@ +from __future__ import annotations import runpy import os import socket @@ -6,6 +7,7 @@ from bdb import BdbQuit from contextlib import contextmanager, nullcontext from termios import tcdrain +from typing import Optional, ContextManager from IPython.terminal.debugger import TerminalPdb from IPython.terminal.interactiveshell import TerminalInteractiveShell @@ -26,7 +28,15 @@ class RemoteIPythonDebugger(TerminalPdb): Because we need to provide the stdin and stdout params to the __init__, and they require a connection to the client, """ _DEBUGGING_GLOBAL = 'DEBUGGING_WITH_MADBG' - # TODO: should this be a per-thread singleton? Because sys.settrace is singletonic + _CURRENT_INSTANCE = None + + @classmethod + def _get_current_instance(cls) -> Optional[RemoteIPythonDebugger]: + return cls._CURRENT_INSTANCE + + @classmethod + def _set_current_instance(cls, new: Optional[RemoteIPythonDebugger]) -> None: + cls._CURRENT_INSTANCE = new def __init__(self, stdin, stdout, term_type): # A patch until https://github.com/ipython/ipython/issues/11745 is solved @@ -54,8 +64,10 @@ def trace_dispatch(self, frame, event, arg, check_debugging_global=False, done_c except BdbQuit: bdb_quit = True finally: - if (done_callback is not None) and (self.quitting or bdb_quit): - done_callback() + if self.quitting or bdb_quit: + # To debugger finalization + if done_callback is not None: + done_callback() def set_trace(self, frame=None, done_callback=None): """ Overriding super to add the done_callback argument, allowing cleanup after a debug session """ @@ -96,7 +108,7 @@ def run_py(self, python_file, run_as_module, argv, set_trace=False): runpy.run_path(python_file, run_name=run_name, init_globals=globals) @contextmanager - def debug(self, check_debugging_global=False): + def debug(self, check_debugging_global=False) -> ContextManager: self.reset() sys.settrace(lambda *args: self.trace_dispatch(*args, check_debugging_global=check_debugging_global)) try: @@ -109,30 +121,35 @@ def debug(self, check_debugging_global=False): @classmethod @contextmanager - def start(cls, sock_fd): + def start(cls, sock_fd: int) -> ContextManager[RemoteIPythonDebugger]: + # TODO: just add to pipe list + assert cls._get_current_instance() is None term_data = receive_message(sock_fd) term_attrs, term_type, term_size = term_data['term_attrs'], term_data['term_type'], term_data['term_size'] with PTY.open() as pty: pty.resize(term_size[0], term_size[1]) pty.set_tty_attrs(term_attrs) pty.make_ctty() - piping = Piping({sock_fd: pty.master_fd, pty.master_fd: sock_fd}) + piping = Piping({sock_fd: {pty.master_fd}, pty.master_fd: {sock_fd}}) with run_thread(piping.run): slave_reader = os.fdopen(pty.slave_fd, 'r') slave_writer = os.fdopen(pty.slave_fd, 'w') try: - yield cls(slave_reader, slave_writer, term_type) + instance = cls(slave_reader, slave_writer, term_type) + cls._set_current_instance(instance) + yield instance except Exception: print(traceback.format_exc(), file=slave_writer) raise finally: + cls._set_current_instance(None) print('Closing connection', file=slave_writer, flush=True) tcdrain(pty.slave_fd) slave_writer.close() @classmethod @contextmanager - def get_server_socket(cls, ip: str, port: int) -> socket.socket: + def get_server_socket(cls, ip: str, port: int) -> ContextManager[socket.socket]: """ Return a new server socket for client to connect to. The caller is responsible for closing it. """ @@ -146,7 +163,7 @@ def get_server_socket(cls, ip: str, port: int) -> socket.socket: @classmethod @contextmanager - def start_from_new_connection(cls, sock: socket.socket): + def start_from_new_connection(cls, sock: socket.socket) -> ContextManager[RemoteIPythonDebugger]: print_to_ctty(f'Debugger client connected from {sock.getpeername()}') try: with cls.start(sock.fileno()) as debugger: @@ -155,7 +172,11 @@ def start_from_new_connection(cls, sock: socket.socket): sock.close() @classmethod - def connect_and_start(cls, ip, port): + def connect_and_start(cls, ip: str, port: int) -> ContextManager[RemoteIPythonDebugger]: + # TODO: get rid of context managers at some level - nobody is going to use with start() anyway + current_instance = cls._get_current_instance() + if current_instance is not None: + return nullcontext(current_instance) with cls.get_server_socket(ip, port) as server_socket: server_socket.listen(1) print_to_ctty(f'Waiting for debugger client on {ip}:{port}') diff --git a/madbg/utils.py b/madbg/utils.py index f164811..2f0afd4 100644 --- a/madbg/utils.py +++ b/madbg/utils.py @@ -4,6 +4,7 @@ from collections import defaultdict from concurrent.futures.thread import ThreadPoolExecutor from contextlib import contextmanager, ExitStack +from typing import Dict, Any, Set @contextmanager @@ -44,8 +45,9 @@ def run_thread(func, *args, **kwargs): future.result() -def opposite_dict(dict_): +def opposite_dict(dict_: Dict[Any, Set[Any]]) -> Dict[Any, Set[Any]]: opposite = defaultdict(set) - for key, value in dict_.items(): - opposite[value].add(key) + for key, values in dict_.items(): + for value in values: + opposite[value].add(key) return opposite diff --git a/tests/system/test_set_trace.py b/tests/system/test_set_trace.py index a1a217e..f2d47c1 100644 --- a/tests/system/test_set_trace.py +++ b/tests/system/test_set_trace.py @@ -36,6 +36,12 @@ def test_set_trace_and_connect_twice(port, start_debugger_with_ctty): debugger_future.result(JOIN_TIMEOUT) +def test_set_trace_twice_and_continue(port, start_debugger_with_ctty): + debugger_future = run_script_in_process(set_trace_script, start_debugger_with_ctty, port, 2) + assert b'Closing connection' in run_in_process(run_client, port, b'c\nq\n').result(JOIN_TIMEOUT) + debugger_future.result(JOIN_TIMEOUT) + + def test_set_trace_and_quit_debugger(port, start_debugger_with_ctty): debugger_future = run_script_in_process(set_trace_script, start_debugger_with_ctty, port) client_future = run_in_process(run_client, port, b'q\n')