Skip to content

Commit

Permalink
Merge pull request #36 from kmaork/work
Browse files Browse the repository at this point in the history
* Make the debugger more singletonic
* Fix small bug in piping and add option for many to many
  • Loading branch information
kmaork authored Feb 24, 2022
2 parents f811a1f + db8714e commit dee3137
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 26 deletions.
2 changes: 1 addition & 1 deletion madbg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ def connect_to_debugger(ip=DEFAULT_IP, port=DEFAULT_PORT, timeout=DEFAULT_CONNEC
send_message(socket, term_data)
with prepare_terminal():
socket_fd = socket.fileno()
Piping({in_fd: socket_fd, socket_fd: out_fd}).run()
Piping({in_fd: {socket_fd}, socket_fd: {out_fd}}).run()
tcdrain(out_fd)
28 changes: 16 additions & 12 deletions madbg/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from asyncio import new_event_loop
from io import BytesIO
from typing import Dict, Set

from .utils import opposite_dict

Expand All @@ -30,12 +31,13 @@ def blocking_read(fd, n):


class Piping:
def __init__(self, pipe_dict):
def __init__(self, pipe_dict: Dict[int, Set[int]]):
self.buffers = defaultdict(bytes)
self.loop = new_event_loop()
for src_fd, dest_fd in pipe_dict.items():
self.loop.add_reader(src_fd, partial(self._read, src_fd, dest_fd))
self.loop.add_writer(dest_fd, partial(self._write, dest_fd))
for src_fd, dest_fds in pipe_dict.items():
self.loop.add_reader(src_fd, partial(self._read, src_fd, dest_fds))
for dest_fd in dest_fds:
self.loop.add_writer(dest_fd, partial(self._write, dest_fd))
self.readers_to_writers = dict(pipe_dict)
self.writers_to_readers = opposite_dict(pipe_dict)

Expand All @@ -47,19 +49,21 @@ def _remove_writer(self, writer_fd):
def _remove_reader(self, reader_fd):
# remove all writers that im the last to write to, remove all that write to me, if nothing left stop loop
self.loop.remove_reader(reader_fd)
writer_fd = self.readers_to_writers.pop(reader_fd)
writer_readers = self.writers_to_readers[writer_fd]
writer_readers.remove(reader_fd)
if not writer_fd:
self._remove_writer(writer_fd)

def _read(self, src_fd, dest_fd):
writer_fds = self.readers_to_writers.pop(reader_fd)
for writer_fd in writer_fds:
writer_readers = self.writers_to_readers[writer_fd]
writer_readers.remove(reader_fd)
if not writer_readers:
self._remove_writer(writer_fd)

def _read(self, src_fd, dest_fds):
try:
data = os.read(src_fd, 1024)
except OSError:
data = ''
if data:
self.buffers[dest_fd] += data
for dest_fd in dest_fds:
self.buffers[dest_fd] += data
else:
self._remove_reader(src_fd)
if src_fd in self.writers_to_readers:
Expand Down
41 changes: 31 additions & 10 deletions madbg/debugger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import runpy
import os
import socket
Expand All @@ -6,6 +7,7 @@
from bdb import BdbQuit
from contextlib import contextmanager, nullcontext
from termios import tcdrain
from typing import Optional, ContextManager

from IPython.terminal.debugger import TerminalPdb
from IPython.terminal.interactiveshell import TerminalInteractiveShell
Expand All @@ -26,7 +28,15 @@ class RemoteIPythonDebugger(TerminalPdb):
Because we need to provide the stdin and stdout params to the __init__, and they require a connection to the client,
"""
_DEBUGGING_GLOBAL = 'DEBUGGING_WITH_MADBG'
# TODO: should this be a per-thread singleton? Because sys.settrace is singletonic
_CURRENT_INSTANCE = None

@classmethod
def _get_current_instance(cls) -> Optional[RemoteIPythonDebugger]:
return cls._CURRENT_INSTANCE

@classmethod
def _set_current_instance(cls, new: Optional[RemoteIPythonDebugger]) -> None:
cls._CURRENT_INSTANCE = new

def __init__(self, stdin, stdout, term_type):
# A patch until https://github.com/ipython/ipython/issues/11745 is solved
Expand Down Expand Up @@ -54,8 +64,10 @@ def trace_dispatch(self, frame, event, arg, check_debugging_global=False, done_c
except BdbQuit:
bdb_quit = True
finally:
if (done_callback is not None) and (self.quitting or bdb_quit):
done_callback()
if self.quitting or bdb_quit:
# To debugger finalization
if done_callback is not None:
done_callback()

def set_trace(self, frame=None, done_callback=None):
""" Overriding super to add the done_callback argument, allowing cleanup after a debug session """
Expand Down Expand Up @@ -96,7 +108,7 @@ def run_py(self, python_file, run_as_module, argv, set_trace=False):
runpy.run_path(python_file, run_name=run_name, init_globals=globals)

@contextmanager
def debug(self, check_debugging_global=False):
def debug(self, check_debugging_global=False) -> ContextManager:
self.reset()
sys.settrace(lambda *args: self.trace_dispatch(*args, check_debugging_global=check_debugging_global))
try:
Expand All @@ -109,30 +121,35 @@ def debug(self, check_debugging_global=False):

@classmethod
@contextmanager
def start(cls, sock_fd):
def start(cls, sock_fd: int) -> ContextManager[RemoteIPythonDebugger]:
# TODO: just add to pipe list
assert cls._get_current_instance() is None
term_data = receive_message(sock_fd)
term_attrs, term_type, term_size = term_data['term_attrs'], term_data['term_type'], term_data['term_size']
with PTY.open() as pty:
pty.resize(term_size[0], term_size[1])
pty.set_tty_attrs(term_attrs)
pty.make_ctty()
piping = Piping({sock_fd: pty.master_fd, pty.master_fd: sock_fd})
piping = Piping({sock_fd: {pty.master_fd}, pty.master_fd: {sock_fd}})
with run_thread(piping.run):
slave_reader = os.fdopen(pty.slave_fd, 'r')
slave_writer = os.fdopen(pty.slave_fd, 'w')
try:
yield cls(slave_reader, slave_writer, term_type)
instance = cls(slave_reader, slave_writer, term_type)
cls._set_current_instance(instance)
yield instance
except Exception:
print(traceback.format_exc(), file=slave_writer)
raise
finally:
cls._set_current_instance(None)
print('Closing connection', file=slave_writer, flush=True)
tcdrain(pty.slave_fd)
slave_writer.close()

@classmethod
@contextmanager
def get_server_socket(cls, ip: str, port: int) -> socket.socket:
def get_server_socket(cls, ip: str, port: int) -> ContextManager[socket.socket]:
"""
Return a new server socket for client to connect to. The caller is responsible for closing it.
"""
Expand All @@ -146,7 +163,7 @@ def get_server_socket(cls, ip: str, port: int) -> socket.socket:

@classmethod
@contextmanager
def start_from_new_connection(cls, sock: socket.socket):
def start_from_new_connection(cls, sock: socket.socket) -> ContextManager[RemoteIPythonDebugger]:
print_to_ctty(f'Debugger client connected from {sock.getpeername()}')
try:
with cls.start(sock.fileno()) as debugger:
Expand All @@ -155,7 +172,11 @@ def start_from_new_connection(cls, sock: socket.socket):
sock.close()

@classmethod
def connect_and_start(cls, ip, port):
def connect_and_start(cls, ip: str, port: int) -> ContextManager[RemoteIPythonDebugger]:
# TODO: get rid of context managers at some level - nobody is going to use with start() anyway
current_instance = cls._get_current_instance()
if current_instance is not None:
return nullcontext(current_instance)
with cls.get_server_socket(ip, port) as server_socket:
server_socket.listen(1)
print_to_ctty(f'Waiting for debugger client on {ip}:{port}')
Expand Down
8 changes: 5 additions & 3 deletions madbg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import contextmanager, ExitStack
from typing import Dict, Any, Set


@contextmanager
Expand Down Expand Up @@ -44,8 +45,9 @@ def run_thread(func, *args, **kwargs):
future.result()


def opposite_dict(dict_):
def opposite_dict(dict_: Dict[Any, Set[Any]]) -> Dict[Any, Set[Any]]:
opposite = defaultdict(set)
for key, value in dict_.items():
opposite[value].add(key)
for key, values in dict_.items():
for value in values:
opposite[value].add(key)
return opposite
6 changes: 6 additions & 0 deletions tests/system/test_set_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def test_set_trace_and_connect_twice(port, start_debugger_with_ctty):
debugger_future.result(JOIN_TIMEOUT)


def test_set_trace_twice_and_continue(port, start_debugger_with_ctty):
debugger_future = run_script_in_process(set_trace_script, start_debugger_with_ctty, port, 2)
assert b'Closing connection' in run_in_process(run_client, port, b'c\nq\n').result(JOIN_TIMEOUT)
debugger_future.result(JOIN_TIMEOUT)


def test_set_trace_and_quit_debugger(port, start_debugger_with_ctty):
debugger_future = run_script_in_process(set_trace_script, start_debugger_with_ctty, port)
client_future = run_in_process(run_client, port, b'q\n')
Expand Down

0 comments on commit dee3137

Please sign in to comment.