diff --git a/pyroute2/iproute/linux.py b/pyroute2/iproute/linux.py index 1336672d2..3225825e7 100644 --- a/pyroute2/iproute/linux.py +++ b/pyroute2/iproute/linux.py @@ -10,7 +10,6 @@ from pyroute2 import config from pyroute2.common import basestring from pyroute2.config import AF_BRIDGE -from pyroute2.lab import LAB_API from pyroute2.netlink import ( NLM_F_ACK, NLM_F_ATOMIC, @@ -20,6 +19,7 @@ NLM_F_ROOT, NLMSG_ERROR, ) +from pyroute2.netlink.core import SyncMixin from pyroute2.netlink.exceptions import ( NetlinkDumpInterrupted, NetlinkError, @@ -106,6 +106,16 @@ log = logging.getLogger(__name__) +def compat_get_dump_filter(kwarg): + if 'match' in kwarg: + return kwarg.pop('match'), kwarg + else: + new_kwarg = {} + if 'family' in kwarg: + new_kwarg['family'] = kwarg.pop('family') + return kwarg, new_kwarg + + def get_default_request_filters(mode, command): filters = { 'link': [LinkFieldFilter(), LinkIPRouteFilter(command)], @@ -1355,12 +1365,12 @@ def neigh(self, command, **kwarg): dump_filter = None msg = ndmsg.ndmsg() if command == 'dump': - dump_filter, kwarg = get_dump_filter(kwarg) + dump_filter, kwarg = compat_get_dump_filter(kwarg) request = ( RequestProcessor(context=kwarg, prime=kwarg) - .apply_filter(NeighbourFieldFilter()) - .apply_filter(NeighbourIPRouteFilter(command)) + .add_filter(NeighbourFieldFilter()) + .add_filter(NeighbourIPRouteFilter(command)) .finalize() ) msg_type, msg_flags = self.make_request_type(command, command_map) @@ -2492,14 +2502,56 @@ class IPBatch(RTNL_API, IPBatchSocket): pass -class IPRoute(LAB_API, RTNL_API, IPRSocket): +class AsyncIPRoute(RTNL_API, IPRSocket): ''' - Regular ordinary utility class, see RTNL API for the list of methods. + Regular ordinary async utility class, provides RTNL API using + IPRSocket as the transport level. ''' pass +class IPRoute(SyncMixin, AsyncIPRoute): + ''' + A synchronous version of AsyncIPRoute. All the same API, but + sync. Provides a legacy API for the old code that is not using + asyncio. + ''' + + def dump(self, groups=None): + groups_map = { + RTMGRP_LINK: [partial(self.link, 'dump')], + RTMGRP_IPV4_IFADDR: [partial(self.addr, 'dump', family=AF_INET)], + RTMGRP_IPV4_ROUTE: [partial(self.route, 'dump', family=AF_INET)], + } + for group, methods in groups_map.items(): + if group & (groups if groups is not None else self.groups): + for method in methods: + for msg in method(): + yield msg + + def __getattribute__(self, name): + async_methods = ['addr', 'link', 'route'] + symbol = super().__getattribute__(name) + + def synchronize(*argv, **kwarg): + async def collect_dump(): + return [i async for i in await symbol(*argv, **kwarg)] + + async def collect_op(): + return await symbol(*argv, **kwarg) + + if argv[0] == 'dump': + task = collect_dump + else: + task = collect_op + return self.event_loop.run_until_complete(task()) + + if name in async_methods: + return synchronize + return symbol + + class NetNS(IPRoute): def __init__( @@ -2511,11 +2563,7 @@ def __init__( groups=RTMGRP_DEFAULTS, ): super().__init__( - target=netns if netns is not None else target, - netns=netns, - flags=flags, - libc=libc, - groups=groups, + target=target, netns=netns, flags=flags, libc=libc, groups=groups ) diff --git a/pyroute2/ndb/main.py b/pyroute2/ndb/main.py index 81b65238f..6617778f4 100644 --- a/pyroute2/ndb/main.py +++ b/pyroute2/ndb/main.py @@ -284,7 +284,6 @@ def add_mock_netns(self, netns): import ctypes.util import logging import logging.handlers -import sys import threading from pyroute2 import config @@ -511,10 +510,6 @@ def __init__( 'nlm_generator': 1, } ] - if sys.platform.startswith('linux'): - sources.append( - {'target': self.nsmanager, 'kind': 'nsmanager'} - ) elif not isinstance(sources, (list, tuple)): raise ValueError('sources format not supported') diff --git a/pyroute2/ndb/objects/__init__.py b/pyroute2/ndb/objects/__init__.py index 4d9c8e62e..aafc4a2e3 100644 --- a/pyroute2/ndb/objects/__init__.py +++ b/pyroute2/ndb/objects/__init__.py @@ -295,9 +295,8 @@ def __init__( self.load_event.set() self.load_debug = False self.lock = threading.Lock() - self.object_data = RequestProcessor( - self.field_filter(), context=weakref.proxy(self) - ) + self.object_data = RequestProcessor(context=weakref.proxy(self)) + self.object_data.add_filter(self.field_filter()) self.kspec = self.schema.compiled[self.table]['idx'] self.knorm = self.schema.compiled[self.table]['norm_idx'] self.spec = self.schema.compiled[self.table]['all_names'] @@ -351,7 +350,8 @@ def __init__( def new_spec(cls, spec, context=None, localhost=None): if isinstance(spec, Record): spec = spec._as_dict() - rp = RequestProcessor(cls.field_filter(), context=spec, prime=spec) + rp = RequestProcessor(context=spec, prime=spec) + rp.add_filter(cls.field_filter()) if isinstance(context, dict): rp.update(context) if 'target' not in rp and localhost is not None: diff --git a/pyroute2/ndb/source.py b/pyroute2/ndb/source.py index 0c420ac97..f4b937d10 100644 --- a/pyroute2/ndb/source.py +++ b/pyroute2/ndb/source.py @@ -395,6 +395,7 @@ def receiver(self): while self.state.get() not in ('stop', 'restart'): try: msg = tuple(self.nl.get()) + self.log.debug(f'received message {msg}') except Exception as e: self.errors_counter += 1 self.log.error('source error: %s %s' % (type(e), e)) diff --git a/pyroute2/netlink/core.py b/pyroute2/netlink/core.py index 538d26af0..20660e8cb 100644 --- a/pyroute2/netlink/core.py +++ b/pyroute2/netlink/core.py @@ -5,25 +5,35 @@ import multiprocessing import os import socket +import struct +import threading from urllib import parse from pyroute2 import config from pyroute2.common import AddrPool -from pyroute2.netlink import NLM_F_MULTI, NLMSG_DONE +from pyroute2.netlink import NLM_F_DUMP, NLM_F_MULTI, NLM_F_REQUEST, NLMSG_DONE from pyroute2.netns import setns from pyroute2.requests.main import RequestProcessor log = logging.getLogger(__name__) Stats = collections.namedtuple('Stats', ('qsize', 'delta', 'delay')) +CoreSocketResources = collections.namedtuple( + 'CoreSocketResources', + ('socket', 'msg_queue', 'event_loop', 'transport', 'protocol'), +) class CoreSocketSpec(dict): + defaults = {'closed': False, 'compiled': None, 'uname': config.uname} + status_filters = [] + def __init__(self, spec=None): super().__init__(spec) spec = {} if spec is None else spec - default = {'closed': False, 'compiled': None, 'uname': config.uname} self.status = RequestProcessor() - self.status.update(default) + for flt in self.status_filters: + self.status.add_filter(flt()) + self.status.update(self.defaults) self.status.update(self) def __setitem__(self, key, value): @@ -82,6 +92,9 @@ def connection_made(self, transport): def connection_lost(self, exc): self.on_con_lost.set_result(True) + self.enqueue( + struct.pack('IHHQIQQ', 28, 2, 0, 0, errno.ECONNRESET, 0, 0), None + ) class CoreStreamProtocol(CoreProtocol): @@ -98,26 +111,35 @@ def datagram_received(self, data, addr): self.enqueue(data, addr) -def netns_init(ctl, nsname, cls): +async def netns_main(ctl, nsname, cls): + # A simple child process + # + # 1. set network namespace setns(nsname) + # 2. start the socket object s = cls() - print(" <<< ", s) + await s.ensure_socket() + # 3. send back the file descriptor socket.send_fds(ctl, [b'test'], [s.socket.fileno()], 1) - print(" done ") + # 4. exit + + +def netns_init(ctl, nsname, cls): + asyncio.run(netns_main(ctl, nsname, cls)) -class CoreSocket: +class AsyncCoreSocket: '''Pyroute2 core socket class. This class implements the core socket concept for all the pyroute2 communications, both Netlink and internal RPC. + + The asynchronous version is the most basic. All the sync classes + are built on top of it. ''' libc = None - socket = None compiled = None - endpoint = None - event_loop = None __spec = None __marshal = None @@ -142,46 +164,19 @@ def __init__( 'groups': groups, } ) + self.status = self.spec.status + self.local = threading.local() if libc is not None: self.libc = libc - self.status = self.spec.status - url = parse.urlparse(self.status['target']) + url = parse.urlparse(self.spec['target']) self.scheme = url.scheme if url.scheme else url.path self.use_socket = use_socket - # 8<----------------------------------------- - # Setup netns - if self.spec['netns'] is not None: - # inspect self.__init__ argument names - ctrl = socket.socketpair() - nsproc = multiprocessing.Process( - target=netns_init, - args=(ctrl[0], self.spec['netns'], type(self)), - ) - nsproc.start() - (_, (self.spec['fileno'],), _, _) = socket.recv_fds( - ctrl[1], 1024, 1 - ) - nsproc.join() - # 8<----------------------------------------- self.callbacks = [] # [(predicate, callback, args), ...] self.addr_pool = AddrPool(minaddr=0x000000FF, maxaddr=0x0000FFFF) self.marshal = None self.buffer = [] self.msg_reschedule = [] - # 8<----------------------------------------- - # Setup the underlying socket - self.socket = self.setup_socket() - self.msg_queue = CoreMessageQueue() - self.event_loop = self.setup_event_loop() - self.connection_lost = self.event_loop.create_future() - if self.event_loop.is_running(): - self.endpoint_started = asyncio.ensure_future( - self.setup_endpoint() - ) - else: - self.event_loop.run_until_complete(self.setup_endpoint()) - self.endpoint_started = self.event_loop.create_future() - self.endpoint_started.set_result(True) + self.__all_open_resources = set() def get_loop(self): return self.event_loop @@ -204,6 +199,89 @@ def marshal(self, value): if self.__marshal is None: self.__marshal = value + # 8<-------------------------------------------------------------- + # Thread local section + @property + def msg_queue(self): + return self.local.msg_queue + + @property + def port(self): + if not hasattr(self.local, 'port'): + import random + + self.local.port = random.randint(20, 200) + return self.status['pid'] + (self.local.port << 22) + + @property + def connection_lost(self): + return self.local.connection_lost + + @property + def event_loop(self): + if not hasattr(self.local, 'event_loop'): + self.local.event_loop = self.setup_event_loop() + self.local.connection_lost = self.local.event_loop.create_future() + return self.local.event_loop + + async def ensure_socket(self): + if not hasattr(self.local, 'socket'): + self.local.socket = None + self.local.fileno = None + self.local.msg_queue = CoreMessageQueue() + # 8<----------------------------------------- + # Setup netns + if self.spec['netns'] is not None: + # inspect self.__init__ argument names + ctrl = socket.socketpair() + nsproc = multiprocessing.Process( + target=netns_init, + args=(ctrl[0], self.spec['netns'], type(self)), + ) + nsproc.start() + (_, (self.local.fileno,), _, _) = socket.recv_fds( + ctrl[1], 1024, 1 + ) + nsproc.join() + # 8<----------------------------------------- + self.local.socket = await self.setup_socket() + self.endpoint_started = await self.setup_endpoint() + self.__all_open_resources.add( + CoreSocketResources( + self.local.socket, + self.local.msg_queue, + self.local.event_loop, + self.local.endpoint[0], + self.local.endpoint[1], + ) + ) + + @property + def socket(self): + return self.local.socket + + @property + def endpoint_started(self): + if not hasattr(self.local, 'endpoint_started'): + self.local.endpoint_started = False + return self.local.endpoint_started + + @property + def endpoint(self): + if not hasattr(self.local, 'endpoint'): + self.local.endpoint = None + return self.local.endpoint + + @endpoint_started.setter + def endpoint_started(self, value): + self.local.endpoint_started = value + + @endpoint.setter + def endpoint(self, value): + self.local.endpoint = value + + # 8<-------------------------------------------------------------- + async def setup_endpoint(self, loop=None): # Setup asyncio if self.endpoint is not None: @@ -223,7 +301,7 @@ def setup_event_loop(self, event_loop=None): self.status['event_loop'] = 'new' return event_loop - def setup_socket(self, sock=None): + async def setup_socket(self, sock=None): if self.status['use_socket']: return self.use_socket sock = self.socket if sock is None else sock @@ -255,13 +333,27 @@ def __getattr__(self, attr): return getattr(self.socket, attr.lstrip("_")) raise AttributeError(attr) - def bind(self, addr): + async def bind(self, addr): '''Bind the socket to the address.''' + await self.ensure_socket() return self.socket.bind(addr) - def close(self, code=errno.ECONNRESET): - '''Correctly close the socket and free all the resources.''' - self.socket.close() + async def close(self, code=errno.ECONNRESET): + '''Terminate the object.''' + + def send_terminator(msg_queue): + msg_queue.put_nowait(0, b'') + + for ( + socket, + msg_queue, + event_loop, + transport, + protocol, + ) in self.__all_open_resources: + event_loop.call_soon_threadsafe(send_terminator, msg_queue) + transport.close() + socket.close() def clone(self): '''Return a copy of itself with a new underlying socket.''' @@ -293,41 +385,43 @@ def connect(self, address): def enqueue(self, data, addr): return self.msg_queue.put_nowait(0, data) - def get(self, msg_seq=0, terminate=None, callback=None, noraise=False): - '''Sync wrapper for async_get().''' - - async def collect_data(): - return [ - i - async for i in self.async_get( - msg_seq, terminate, callback, noraise - ) - ] - - return self.event_loop.run_until_complete(collect_data()) - - async def async_get( + async def get( self, msg_seq=0, terminate=None, callback=None, noraise=False ): '''Get a conversation answer from the socket.''' + await self.ensure_socket() log.debug( "get: %s / %s / %s / %s", msg_seq, terminate, callback, noraise ) enough = False started = False + error = None while not enough: + log.debug('await data on %s', self.msg_queue) data = await self.msg_queue.get(msg_seq) + # try: + # task = asyncio.wait_for( + # self.msg_queue.get(msg_seq), timeout=None + # ) + # data = await task + # except TimeoutError: + # continue messages = tuple(self.marshal.parse(data, msg_seq, callback)) if len(messages) == 0: break for msg in messages: - if started and msg['header']['type'] == NLMSG_DONE: - return + log.debug("message %s", msg) + if msg.get('header', {}).get('error') is not None: + error = msg['header']['error'] + enough = True + break + if msg['header']['type'] == NLMSG_DONE: + enough = True + break msg['header']['target'] = self.status['target'] msg['header']['stats'] = Stats(0, 0, 0) started = True log.debug("yield %s", msg['header']) - log.debug("message %s", msg) yield msg if started and ( @@ -335,7 +429,10 @@ async def async_get( or (not msg['header'].get('flags', 0) & NLM_F_MULTI) or (callable(terminate) and terminate(msg)) ): + log.debug("D") enough = True + if not noraise and error: + raise error def __enter__(self): return self @@ -467,3 +564,58 @@ def get_policy_map(self, policy=None): ret[key] = self.marshal.msg_map[key] return ret + + +class SyncMixin: + ''' + Synchronous API wrapper around asynchronous classes + ''' + + @property + def asyncore(self): + return super() + + def bind(self, *argv, **kwarg): + return self.event_loop.run_until_complete( + self.asyncore.bind(*argv, **kwarg) + ) + + def close(self, code=errno.ECONNRESET): + '''Correctly close the socket and free all the resources.''' + return self.event_loop.run_until_complete(self.asyncore.close(code)) + + def nlm_request( + self, + msg, + msg_type, + msg_flags=NLM_F_REQUEST | NLM_F_DUMP, + terminate=None, + callback=None, + parser=None, + ): + async def collect_data(): + return [ + x + async for x in self.asyncore.nlm_request( + msg, msg_type, msg_flags, terminate, callback, parser + ) + ] + + return self.event_loop.run_until_complete(collect_data()) + + def get(self, msg_seq=0, terminate=None, callback=None, noraise=False): + '''Sync wrapper for async_get().''' + + async def collect_data(): + return [ + i + async for i in self.asyncore.get( + msg_seq, terminate, callback, noraise + ) + ] + + return self.event_loop.run_until_complete(collect_data()) + + +class CoreSocket(AsyncCoreSocket, SyncMixin): + pass diff --git a/pyroute2/netlink/nlsocket.py b/pyroute2/netlink/nlsocket.py index 80ef012dc..bea060180 100644 --- a/pyroute2/netlink/nlsocket.py +++ b/pyroute2/netlink/nlsocket.py @@ -104,24 +104,19 @@ NLM_F_APPEND, NLM_F_CREATE, NLM_F_DUMP, - NLM_F_DUMP_INTR, NLM_F_ECHO, NLM_F_EXCL, NLM_F_REPLACE, NLM_F_REQUEST, SOL_NETLINK, - nlmsg, ) from pyroute2.netlink.core import ( + AsyncCoreSocket, CoreDatagramProtocol, - CoreSocket, CoreSocketSpec, + SyncMixin, ) -from pyroute2.netlink.exceptions import ( - ChaoticException, - NetlinkDumpInterrupted, - NetlinkError, -) +from pyroute2.netlink.exceptions import ChaoticException, NetlinkError from pyroute2.netlink.marshal import Marshal from pyroute2.plan9.client import Plan9ClientSocket @@ -163,15 +158,17 @@ def set_target(self, context, value): return {'target': value} def set_netns(self, context, value): + if 'target' in context: + return {'netns': 'value'} return {'target': value, 'netns': value} def set_pid(self, context, value): if value is None: return {'pid': os.getpid() & 0x3FFFFF, 'port': context['port']} elif value == 0: - return {'pid': os.getpid()} + return {'pid': os.getpid(), 'port': 0} else: - return {'pid': value} + return {'pid': value, 'port': 0} def set_port(self, context, value): if isinstance(value, int): @@ -179,15 +176,11 @@ def set_port(self, context, value): class NetlinkSocketSpec(CoreSocketSpec): - def __init__(self, spec=None): - super().__init__(spec) - default = {'pid': 0, 'epid': 0, 'port': 0, 'uname': config.uname} - self.status.add_filter(NetlinkSocketSpecFilter()) - self.status.update(default) - self.status.update(self) + defaults = {'pid': 0, 'epid': 0, 'port': 0, 'uname': config.uname} + status_filters = [NetlinkSocketSpecFilter] -class NetlinkSocket(CoreSocket): +class AsyncNetlinkSocket(AsyncCoreSocket): ''' Netlink socket ''' @@ -265,10 +258,6 @@ def uname(self): def groups(self): return self.status['groups'] - @property - def port(self): - return self.status['port'] - @property def pid(self): return self.status['pid'] @@ -277,7 +266,11 @@ def pid(self): def target(self): return self.status['target'] - def setup_socket(self, sock=None): + @property + def asyncore(self): + return self + + async def setup_socket(self, sock=None): """Re-init a netlink socket.""" if self.status['use_socket']: return self.use_socket @@ -285,7 +278,10 @@ def setup_socket(self, sock=None): if sock is not None: sock.close() sock = config.SocketBase( - AF_NETLINK, SOCK_DGRAM, self.spec['family'], self.spec['fileno'] + AF_NETLINK, + SOCK_DGRAM, + self.spec['family'], + self.spec['fileno'] or self.local.fileno, ) sock.setsockopt(SOL_SOCKET, SO_SNDBUF, self.status['sndbuf']) sock.setsockopt(SOL_SOCKET, SO_RCVBUF, self.status['rcvbuf']) @@ -295,25 +291,9 @@ def setup_socket(self, sock=None): sock.setsockopt(SOL_NETLINK, NETLINK_LISTEN_ALL_NSID, 1) if self.status['strict_check']: sock.setsockopt(SOL_NETLINK, NETLINK_GET_STRICT_CHK, 1) + return sock - class Bala: - def __init__(self, sock): - self._socket = sock - - def ignore(self, *argv, **kwarg): - print("ignore close") - import traceback - - traceback.print_stack() - - def __getattr__(self, attr): - if attr == 'close': - return self.ignore - return getattr(self._socket, attr) - - return Bala(sock) - - def bind(self, groups=0, pid=None, **kwarg): + async def bind(self, groups=0, pid=None, **kwarg): ''' Bind the socket to given multicast groups, using given pid. @@ -322,24 +302,19 @@ def bind(self, groups=0, pid=None, **kwarg): - If pid == 0, use process' pid - If pid == , use the value instead of pid ''' - + await self.ensure_socket() self.status['groups'] = groups # if we have pre-defined port, use it strictly - if self.status.get('port') is not None: - self.socket.bind((self.status['epid'], self.status['groups'])) - else: + self.status['pid'] = pid + if pid is None: for port in range(1024): try: - self.status['port'] = port - self.socket.bind( - (self.status['epid'], self.status['groups']) - ) + self.socket.bind((self.port, self.status['groups'])) break except Exception as e: # create a new underlying socket -- on kernel 4 # one failed bind() makes the socket useless log.error(e) - self.restart_base_socket() else: raise KeyError('no free address available') @@ -357,66 +332,51 @@ def enqueue(self, data, addr): def compile(self): return CompileContext(self) - def _send_batch(self, msgs, addr=(0, 0)): - with self.backlog_lock: - for msg in msgs: - self.backlog[msg['header']['sequence_number']] = [] - # We have locked the message locks in the caller already. - data = bytearray() - for msg in msgs: - if not isinstance(msg, nlmsg): - msg_class = self.marshal.msg_map[msg['header']['type']] - msg = msg_class(msg) - msg.reset() - msg.encode() - data += msg.data - if self.compiled is not None: - return self.compiled.append(data) - self._sock.sendto(data, addr) - - def nlm_request_batch(self, msgs, noraise=False): - """ - This function is for messages which are expected to have side effects. - Do not blindly retry in case of errors as this might duplicate them. - """ - expected_responses = [] - acquired = 0 - seqs = self.addr_pool.alloc_multi(len(msgs)) - try: - for seq in seqs: - self.lock[seq].acquire() - acquired += 1 - for seq, msg in zip(seqs, msgs): - msg['header']['sequence_number'] = seq - if 'pid' not in msg['header']: - msg['header']['pid'] = self.epid or os.getpid() - if (msg['header']['flags'] & NLM_F_ACK) or ( - msg['header']['flags'] & NLM_F_DUMP - ): - expected_responses.append(seq) - self._send_batch(msgs) - if self.compiled is not None: - for data in self.compiled: - yield data - else: - for seq in expected_responses: - for msg in self.get(msg_seq=seq, noraise=noraise): - if msg['header']['flags'] & NLM_F_DUMP_INTR: - # Leave error handling to the caller - raise NetlinkDumpInterrupted() - yield msg - finally: - # Release locks in reverse order. - for seq in seqs[acquired - 1 :: -1]: - self.lock[seq].release() - - with self.backlog_lock: - for seq in seqs: - # Clear the backlog. We may have raised an error - # causing the backlog to not be consumed entirely. - if seq in self.backlog: - del self.backlog[seq] - self.addr_pool.free(seq, ban=0xFF) + def make_request_type(self, command, command_map): + if isinstance(command, basestring): + return (lambda x: (x[0], self.make_request_flags(x[1])))( + command_map[command] + ) + elif isinstance(command, int): + return command, self.make_request_flags('create') + elif isinstance(command, (list, tuple)): + return command + else: + raise TypeError('allowed command types: int, str, list, tuple') + + def make_request_flags(self, mode): + flags = { + 'dump': NLM_F_REQUEST | NLM_F_DUMP, + 'get': NLM_F_REQUEST | NLM_F_ACK, + 'req': NLM_F_REQUEST | NLM_F_ACK, + } + flags['create'] = flags['req'] | NLM_F_CREATE | NLM_F_EXCL + flags['append'] = flags['req'] | NLM_F_CREATE | NLM_F_APPEND + flags['change'] = flags['req'] | NLM_F_REPLACE + flags['replace'] = flags['change'] | NLM_F_CREATE + + return flags[mode] | ( + NLM_F_ECHO + if (self.status['nlm_echo'] and mode not in ('get', 'dump')) + else 0 + ) + + async def nlm_request( + self, + msg, + msg_type, + msg_flags=NLM_F_REQUEST | NLM_F_DUMP, + terminate=None, + callback=None, + parser=None, + ): + request = NetlinkRequest( + self, msg, terminate=terminate, callback=callback + ) + request.msg['header']['type'] = msg_type + request.msg['header']['flags'] = msg_flags + await request.send() + return request.response() class NetlinkRequest: @@ -435,37 +395,42 @@ def __init__( self, sock, msg, - command, - command_map, - dump_filter, - request_filter, + command=None, + command_map=None, + dump_filter=None, + request_filter=None, terminate=None, callback=None, ): self.sock = sock self.addr_pool = sock.addr_pool self.status = sock.status + self.port = sock.port self.marshal = sock.marshal # if not isinstance(msg, nlmsg): # msg_class = self.marshal.msg_map[msg_type] # msg = msg_class(msg) - msg_type, msg_flags = self.calculate_request_type(command, command_map) + if command_map is not None: + msg_type, msg_flags = self.calculate_request_type( + command, command_map + ) + msg['header']['type'] = msg_type + msg['header']['flags'] = msg_flags self.msg_seq = self.addr_pool.alloc() - msg['header']['type'] = msg_type - msg['header']['flags'] = msg_flags msg['header']['sequence_number'] = self.msg_seq - msg['header']['pid'] = self.status['epid'] or os.getpid() + msg['header']['pid'] = self.port or os.getpid() msg.reset() # set fields - for field in msg.fields: - msg[field[0]] = request_filter.get_value( - field[0], default=0, mode='field' - ) - # attach NLAs - for key, value in request_filter.items(): - nla = type(msg).name2nla(key) - if msg.valid_nla(nla) and value is not None: - msg['attrs'].append([nla, value]) + if request_filter is not None: + for field in msg.fields: + msg[field[0]] = request_filter.get_value( + field[0], default=0, mode='field' + ) + # attach NLAs + for key, value in request_filter.items(): + nla = type(msg).name2nla(key) + if msg.valid_nla(nla) and value is not None: + msg['attrs'].append([nla, value]) self.msg = msg self.dump_filter = dump_filter self.terminate = terminate @@ -512,14 +477,14 @@ def match_one_message(self, msg): return all(matches) async def send(self): - print(" ??? ", self.msg) + await self.sock.ensure_socket() self.msg.encode() self.sock.msg_queue.ensure(self.msg_seq) count = 0 e = None for count in range(30): try: - return self.sock.send(self.msg.data) + return self.sock.asyncore.send(self.msg.data) except NetlinkError as e: if e.code != errno.EBUSY: break @@ -533,7 +498,7 @@ async def send(self): raise e async def response(self): - async for msg in self.sock.async_get( + async for msg in self.sock.asyncore.get( msg_seq=self.msg_seq, terminate=self.terminate, callback=self.callback, @@ -549,6 +514,10 @@ async def response(self): IPCSocketPair = collections.namedtuple('IPCSocketPair', ('server', 'client')) +class NetlinkSocket(AsyncCoreSocket, SyncMixin): + pass + + class IPCSocket(NetlinkSocket): def setup_socket(self): diff --git a/pyroute2/netlink/rtnl/iprsocket.py b/pyroute2/netlink/rtnl/iprsocket.py index a0c16ab7a..5de75baf7 100644 --- a/pyroute2/netlink/rtnl/iprsocket.py +++ b/pyroute2/netlink/rtnl/iprsocket.py @@ -3,9 +3,9 @@ from pyroute2.common import AddrPool, Namespace from pyroute2.netlink import NETLINK_ROUTE, rtnl from pyroute2.netlink.nlsocket import ( + AsyncNetlinkSocket, BatchSocket, ChaoticNetlinkSocket, - NetlinkSocket, ) from pyroute2.netlink.proxy import NetlinkProxy from pyroute2.netlink.rtnl.marshal import MarshalRtnl @@ -18,7 +18,7 @@ from pyroute2.netlink.rtnl.probe_msg import proxy_newprobe -class IPRSocket(NetlinkSocket): +class IPRSocket(AsyncNetlinkSocket): ''' The simplest class, that connects together the netlink parser and a generic Python socket implementation. Provides method get() to @@ -90,11 +90,11 @@ def __init__(self, *argv, **kwarg): rtnl.RTM_NEWPROBE: proxy_newprobe, } super().__init__(NETLINK_ROUTE, *argv[1:], **kwarg) - if self.status['groups'] == 0: + if self.spec['groups'] == 0: self.spec['groups'] = rtnl.RTMGRP_DEFAULTS - def bind(self, groups=None, **kwarg): - super().bind( + async def bind(self, groups=None, **kwarg): + return await super().bind( groups if groups is not None else self.status['groups'], **kwarg ) diff --git a/pyroute2/plan9/__init__.py b/pyroute2/plan9/__init__.py index 7dfc65af4..3f13fcb97 100644 --- a/pyroute2/plan9/__init__.py +++ b/pyroute2/plan9/__init__.py @@ -60,6 +60,8 @@ def decode_from(data, offset): @staticmethod def encode_into(data, offset, value): + if not isinstance(value, (tuple, list)): + value = [] data.extend([0] * struct.calcsize(header)) struct.pack_into(header, data, offset, len(value)) offset += struct.calcsize(header)