Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Piping refactor + single tracing session #36

Merged
merged 2 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion madbg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ def connect_to_debugger(ip=DEFAULT_IP, port=DEFAULT_PORT, timeout=DEFAULT_CONNEC
send_message(socket, term_data)
with prepare_terminal():
socket_fd = socket.fileno()
Piping({in_fd: socket_fd, socket_fd: out_fd}).run()
Piping({in_fd: {socket_fd}, socket_fd: {out_fd}}).run()
tcdrain(out_fd)
28 changes: 16 additions & 12 deletions madbg/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from asyncio import new_event_loop
from io import BytesIO
from typing import Dict, Set

from .utils import opposite_dict

Expand All @@ -30,12 +31,13 @@ def blocking_read(fd, n):


class Piping:
def __init__(self, pipe_dict):
def __init__(self, pipe_dict: Dict[int, Set[int]]):
self.buffers = defaultdict(bytes)
self.loop = new_event_loop()
for src_fd, dest_fd in pipe_dict.items():
self.loop.add_reader(src_fd, partial(self._read, src_fd, dest_fd))
self.loop.add_writer(dest_fd, partial(self._write, dest_fd))
for src_fd, dest_fds in pipe_dict.items():
self.loop.add_reader(src_fd, partial(self._read, src_fd, dest_fds))
for dest_fd in dest_fds:
self.loop.add_writer(dest_fd, partial(self._write, dest_fd))
self.readers_to_writers = dict(pipe_dict)
self.writers_to_readers = opposite_dict(pipe_dict)

Expand All @@ -47,19 +49,21 @@ def _remove_writer(self, writer_fd):
def _remove_reader(self, reader_fd):
# remove all writers that im the last to write to, remove all that write to me, if nothing left stop loop
self.loop.remove_reader(reader_fd)
writer_fd = self.readers_to_writers.pop(reader_fd)
writer_readers = self.writers_to_readers[writer_fd]
writer_readers.remove(reader_fd)
if not writer_fd:
self._remove_writer(writer_fd)

def _read(self, src_fd, dest_fd):
writer_fds = self.readers_to_writers.pop(reader_fd)
for writer_fd in writer_fds:
writer_readers = self.writers_to_readers[writer_fd]
writer_readers.remove(reader_fd)
if not writer_readers:
self._remove_writer(writer_fd)

def _read(self, src_fd, dest_fds):
try:
data = os.read(src_fd, 1024)
except OSError:
data = ''
if data:
self.buffers[dest_fd] += data
for dest_fd in dest_fds:
self.buffers[dest_fd] += data
else:
self._remove_reader(src_fd)
if src_fd in self.writers_to_readers:
Expand Down
41 changes: 31 additions & 10 deletions madbg/debugger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import runpy
import os
import socket
Expand All @@ -6,6 +7,7 @@
from bdb import BdbQuit
from contextlib import contextmanager, nullcontext
from termios import tcdrain
from typing import Optional, ContextManager

from IPython.terminal.debugger import TerminalPdb
from IPython.terminal.interactiveshell import TerminalInteractiveShell
Expand All @@ -26,7 +28,15 @@ class RemoteIPythonDebugger(TerminalPdb):
Because we need to provide the stdin and stdout params to the __init__, and they require a connection to the client,
"""
_DEBUGGING_GLOBAL = 'DEBUGGING_WITH_MADBG'
# TODO: should this be a per-thread singleton? Because sys.settrace is singletonic
_CURRENT_INSTANCE = None

@classmethod
def _get_current_instance(cls) -> Optional[RemoteIPythonDebugger]:
return cls._CURRENT_INSTANCE

@classmethod
def _set_current_instance(cls, new: Optional[RemoteIPythonDebugger]) -> None:
cls._CURRENT_INSTANCE = new

def __init__(self, stdin, stdout, term_type):
# A patch until https://github.com/ipython/ipython/issues/11745 is solved
Expand Down Expand Up @@ -54,8 +64,10 @@ def trace_dispatch(self, frame, event, arg, check_debugging_global=False, done_c
except BdbQuit:
bdb_quit = True
finally:
if (done_callback is not None) and (self.quitting or bdb_quit):
done_callback()
if self.quitting or bdb_quit:
# To debugger finalization
if done_callback is not None:
done_callback()

def set_trace(self, frame=None, done_callback=None):
""" Overriding super to add the done_callback argument, allowing cleanup after a debug session """
Expand Down Expand Up @@ -96,7 +108,7 @@ def run_py(self, python_file, run_as_module, argv, set_trace=False):
runpy.run_path(python_file, run_name=run_name, init_globals=globals)

@contextmanager
def debug(self, check_debugging_global=False):
def debug(self, check_debugging_global=False) -> ContextManager:
self.reset()
sys.settrace(lambda *args: self.trace_dispatch(*args, check_debugging_global=check_debugging_global))
try:
Expand All @@ -109,30 +121,35 @@ def debug(self, check_debugging_global=False):

@classmethod
@contextmanager
def start(cls, sock_fd):
def start(cls, sock_fd: int) -> ContextManager[RemoteIPythonDebugger]:
# TODO: just add to pipe list
assert cls._get_current_instance() is None
term_data = receive_message(sock_fd)
term_attrs, term_type, term_size = term_data['term_attrs'], term_data['term_type'], term_data['term_size']
with PTY.open() as pty:
pty.resize(term_size[0], term_size[1])
pty.set_tty_attrs(term_attrs)
pty.make_ctty()
piping = Piping({sock_fd: pty.master_fd, pty.master_fd: sock_fd})
piping = Piping({sock_fd: {pty.master_fd}, pty.master_fd: {sock_fd}})
with run_thread(piping.run):
slave_reader = os.fdopen(pty.slave_fd, 'r')
slave_writer = os.fdopen(pty.slave_fd, 'w')
try:
yield cls(slave_reader, slave_writer, term_type)
instance = cls(slave_reader, slave_writer, term_type)
cls._set_current_instance(instance)
yield instance
except Exception:
print(traceback.format_exc(), file=slave_writer)
raise
finally:
cls._set_current_instance(None)
print('Closing connection', file=slave_writer, flush=True)
tcdrain(pty.slave_fd)
slave_writer.close()

@classmethod
@contextmanager
def get_server_socket(cls, ip: str, port: int) -> socket.socket:
def get_server_socket(cls, ip: str, port: int) -> ContextManager[socket.socket]:
"""
Return a new server socket for client to connect to. The caller is responsible for closing it.
"""
Expand All @@ -146,7 +163,7 @@ def get_server_socket(cls, ip: str, port: int) -> socket.socket:

@classmethod
@contextmanager
def start_from_new_connection(cls, sock: socket.socket):
def start_from_new_connection(cls, sock: socket.socket) -> ContextManager[RemoteIPythonDebugger]:
print_to_ctty(f'Debugger client connected from {sock.getpeername()}')
try:
with cls.start(sock.fileno()) as debugger:
Expand All @@ -155,7 +172,11 @@ def start_from_new_connection(cls, sock: socket.socket):
sock.close()

@classmethod
def connect_and_start(cls, ip, port):
def connect_and_start(cls, ip: str, port: int) -> ContextManager[RemoteIPythonDebugger]:
# TODO: get rid of context managers at some level - nobody is going to use with start() anyway
current_instance = cls._get_current_instance()
if current_instance is not None:
return nullcontext(current_instance)
with cls.get_server_socket(ip, port) as server_socket:
server_socket.listen(1)
print_to_ctty(f'Waiting for debugger client on {ip}:{port}')
Expand Down
8 changes: 5 additions & 3 deletions madbg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import contextmanager, ExitStack
from typing import Dict, Any, Set


@contextmanager
Expand Down Expand Up @@ -44,8 +45,9 @@ def run_thread(func, *args, **kwargs):
future.result()


def opposite_dict(dict_):
def opposite_dict(dict_: Dict[Any, Set[Any]]) -> Dict[Any, Set[Any]]:
opposite = defaultdict(set)
for key, value in dict_.items():
opposite[value].add(key)
for key, values in dict_.items():
for value in values:
opposite[value].add(key)
return opposite
6 changes: 6 additions & 0 deletions tests/system/test_set_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def test_set_trace_and_connect_twice(port, start_debugger_with_ctty):
debugger_future.result(JOIN_TIMEOUT)


def test_set_trace_twice_and_continue(port, start_debugger_with_ctty):
debugger_future = run_script_in_process(set_trace_script, start_debugger_with_ctty, port, 2)
assert b'Closing connection' in run_in_process(run_client, port, b'c\nq\n').result(JOIN_TIMEOUT)
debugger_future.result(JOIN_TIMEOUT)


def test_set_trace_and_quit_debugger(port, start_debugger_with_ctty):
debugger_future = run_script_in_process(set_trace_script, start_debugger_with_ctty, port)
client_future = run_in_process(run_client, port, b'q\n')
Expand Down