diff --git a/pkg/base1/test-dbus-common.js b/pkg/base1/test-dbus-common.js index cd795fb1dc6b..61038eb9ef67 100644 --- a/pkg/base1/test-dbus-common.js +++ b/pkg/base1/test-dbus-common.js @@ -232,7 +232,7 @@ export function common_dbus_tests(channel_options, bus_name) { // eslint-disable assert.ok(false, "should not be reached"); } catch (ex) { assert.equal(ex.name, "org.freedesktop.DBus.Error.UnknownMethod", "error name"); - assert.equal(ex.message, "Method UnimplementedMethod is not implemented on interface com.redhat.Cockpit.DBusTests.Frobber", "error message"); + assert.equal(ex.message, "Unknown method UnimplementedMethod or interface com.redhat.Cockpit.DBusTests.Frobber.", "error message"); } }); diff --git a/pkg/base1/test-http.js b/pkg/base1/test-http.js index db508143560e..20fd4188f7fa 100644 --- a/pkg/base1/test-http.js +++ b/pkg/base1/test-http.js @@ -210,6 +210,10 @@ QUnit.test("headers", assert => { .get("/mock/headers", null, { Header1: "booo", Header2: "yay value" }) .response((status, headers) => { assert.equal(status, 201, "status code"); + + delete headers['Content-Type']; + delete headers.Date; + delete headers.Server; assert.deepEqual(headers, { Header1: "booo", Header2: "yay value", @@ -248,6 +252,10 @@ QUnit.test("connection headers", assert => { .get("/mock/headers", null, { Header2: "yay value", Header0: "extra" }) .response((status, headers) => { assert.equal(status, 201, "status code"); + + delete headers['Content-Type']; + delete headers.Date; + delete headers.Server; assert.deepEqual(headers, { Header0: "extra", Header1: "booo", diff --git a/src/cockpit/jsonutil.py b/src/cockpit/jsonutil.py index 7df905c4e6c2..f4e2f1f21b54 100644 --- a/src/cockpit/jsonutil.py +++ b/src/cockpit/jsonutil.py @@ -83,6 +83,12 @@ def get_str(obj: JsonObject, key: str, default: Union[DT, _Empty] = _empty) -> U return _get(obj, lambda v: typechecked(v, str), key, default) +def get_str_map(obj: JsonObject, key: str, default: DT | _Empty = _empty) -> DT | Mapping[str, str]: + def as_str_map(value: JsonValue) -> Mapping[str, str]: + return {key: typechecked(value, str) for key, value in typechecked(value, dict).items()} + return _get(obj, as_str_map, key, default) + + def get_str_or_none(obj: JsonObject, key: str, default: Optional[str]) -> Optional[str]: return _get(obj, lambda v: None if v is None else typechecked(v, str), key, default) diff --git a/test/common/tap-cdp b/test/common/tap-cdp index 9efe34bfb682..dd0e512b07e6 100755 --- a/test/common/tap-cdp +++ b/test/common/tap-cdp @@ -18,9 +18,7 @@ # import argparse -import os import re -import subprocess import sys from cdp import CDP @@ -28,31 +26,8 @@ from cdp import CDP tap_line_re = re.compile(r'^(ok [0-9]+|not ok [0-9]+|bail out!|[0-9]+\.\.[0-9]+|# )', re.IGNORECASE) parser = argparse.ArgumentParser(description="A CDP driver for QUnit which outputs TAP") -parser.add_argument("server", help="path to the test-server and the test page to run", nargs=argparse.REMAINDER) - -# Strip prefix from url -# We need this to compensate for automake test generation behavior: -# The tests are called with the path (relative to the build directory) of the testfile, -# but from the build directory. Some tests make assumptions regarding the structure of the -# filename. In order to make sure that they receive the same name, regardless of actual -# build directory location, we need to strip that prefix (path from build to source directory) -# from the filename -parser.add_argument("--strip", dest="strip", help="strip prefix from test file paths") - -opts = parser.parse_args() - -# argparse sometimes forgets to remove this on argparse.REMAINDER args -if opts.server[0] == '--': - opts.server = opts.server[1:] - -# The test file is the last argument, but 'server' might contain arbitrary -# amount of options. We cannot express this with argparse, so take it apart -# manually. -opts.test = opts.server[-1] -opts.server = opts.server[:-1] - -if opts.strip and opts.test.startswith(opts.strip): - opts.test = opts.test[len(opts.strip):] +parser.add_argument("url", help="url to the test to run") +args = parser.parse_args() cdp = CDP("C.utf8") @@ -62,21 +37,7 @@ except SystemError: print('1..0 # skip web browser not found') sys.exit(0) -# pass the address through a separate fd, so that we can see g_debug() messages (which go to stdout) -(addr_r, addr_w) = os.pipe() -env = os.environ.copy() -env["TEST_SERVER_ADDRESS_FD"] = str(addr_w) - -server = subprocess.Popen(opts.server, - stdin=subprocess.DEVNULL, - pass_fds=(addr_w,), - close_fds=True, - env=env) -os.close(addr_w) -address = os.read(addr_r, 1000).decode() -os.close(addr_r) - -cdp.invoke("Page.navigate", url=address + '/' + opts.test) +cdp.invoke("Page.navigate", url=args.url) success = True ignore_resource_errors = False @@ -109,9 +70,6 @@ for t, message in cdp.read_log(): else: print(message, file=sys.stderr) - -server.terminate() -server.wait() cdp.kill() if not success: diff --git a/test/pytest/mockdbusservice.py b/test/pytest/mockdbusservice.py new file mode 100644 index 000000000000..86db89d9afa2 --- /dev/null +++ b/test/pytest/mockdbusservice.py @@ -0,0 +1,171 @@ +import asyncio +import contextlib +import logging +import math +from collections.abc import AsyncIterator +from typing import Iterator + +from cockpit._vendor import systemd_ctypes + +logger = logging.getLogger(__name__) + + +# No introspection, manual handling of method calls +class borkety_Bork(systemd_ctypes.bus.BaseObject): + def message_received(self, message: systemd_ctypes.bus.BusMessage) -> bool: + signature = message.get_signature(True) # noqa:FBT003 + body = message.get_body() + logger.debug('got Bork message: %s %r', signature, body) + + if message.get_member() == 'Echo': + message.reply_method_return(signature, *body) + return True + + return False + + +class com_redhat_Cockpit_DBusTests_Frobber(systemd_ctypes.bus.Object): + finally_normal_name = systemd_ctypes.bus.Interface.Property('s', 'There aint no place like home') + readonly_property = systemd_ctypes.bus.Interface.Property('s', 'blah') + aay = systemd_ctypes.bus.Interface.Property('aay', [], name='aay') + ag = systemd_ctypes.bus.Interface.Property('ag', [], name='ag') + ao = systemd_ctypes.bus.Interface.Property('ao', [], name='ao') + as_ = systemd_ctypes.bus.Interface.Property('as', [], name='as') + ay = systemd_ctypes.bus.Interface.Property('ay', b'ABCabc\0', name='ay') + b = systemd_ctypes.bus.Interface.Property('b', value=False, name='b') + d = systemd_ctypes.bus.Interface.Property('d', 43, name='d') + g = systemd_ctypes.bus.Interface.Property('g', '', name='g') + i = systemd_ctypes.bus.Interface.Property('i', 0, name='i') + n = systemd_ctypes.bus.Interface.Property('n', 0, name='n') + o = systemd_ctypes.bus.Interface.Property('o', '/', name='o') + q = systemd_ctypes.bus.Interface.Property('q', 0, name='q') + s = systemd_ctypes.bus.Interface.Property('s', '', name='s') + t = systemd_ctypes.bus.Interface.Property('t', 0, name='t') + u = systemd_ctypes.bus.Interface.Property('u', 0, name='u') + x = systemd_ctypes.bus.Interface.Property('x', 0, name='x') + y = systemd_ctypes.bus.Interface.Property('y', 42, name='y') + + test_signal = systemd_ctypes.bus.Interface.Signal('i', 'as', 'ao', 'a{s(ii)}') + + @systemd_ctypes.bus.Interface.Method('', 'i') + def request_signal_emission(self, which_one: int) -> None: + del which_one + + self.test_signal( + 43, + ['foo', 'frobber'], + ['/foo', '/foo/bar'], + {'first': (42, 42), 'second': (43, 43)} + ) + + @systemd_ctypes.bus.Interface.Method('s', 's') + def hello_world(self, greeting: str) -> str: + return f"Word! You said `{greeting}'. I'm Skeleton, btw!" + + @systemd_ctypes.bus.Interface.Method('', '') + async def never_return(self) -> None: + await asyncio.sleep(1000000) + + @systemd_ctypes.bus.Interface.Method( + ['y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'o', 'g', 'ay'], + ['y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'o', 'g', 'ay'] + ) + def test_primitive_types( + self, + val_byte, val_boolean, + val_int16, val_uint16, val_int32, val_uint32, val_int64, val_uint64, + val_double, + val_string, val_objpath, val_signature, + val_bytestring + ): + return [ + val_byte + 10, + not val_boolean, + 100 + val_int16, + 1000 + val_uint16, + 10000 + val_int32, + 100000 + val_uint32, + 1000000 + val_int64, + 10000000 + val_uint64, + val_double / math.pi, + f"Word! You said `{val_string}'. Rock'n'roll!", + f"/modified{val_objpath}", + f"assgit{val_signature}", + b"bytestring!\xff\0" + ] + + @systemd_ctypes.bus.Interface.Method( + ['s'], + ["a{ss}", "a{s(ii)}", "(iss)", "as", "ao", "ag", "aay"] + ) + def test_non_primitive_types( + self, + dict_s_to_s, + dict_s_to_pairs, + a_struct, + array_of_strings, + array_of_objpaths, + array_of_signatures, + array_of_bytestrings + ): + return ( + f'{dict_s_to_s}{dict_s_to_pairs}{a_struct}' + f'array_of_strings: [{", ".join(array_of_strings)}] ' + f'array_of_objpaths: [{", ".join(array_of_objpaths)}] ' + f'array_of_signatures: [signature {", ".join(f"'{sig}'" for sig in array_of_signatures)}] ' + f'array_of_bytestrings: [{", ".join(x[:-1].decode() for x in array_of_bytestrings)}] ' + ) + + +@contextlib.contextmanager +def mock_service_export(bus: systemd_ctypes.Bus) -> Iterator[None]: + slots = [ + bus.add_object('/otree/frobber', com_redhat_Cockpit_DBusTests_Frobber()), + bus.add_object('/otree/different', com_redhat_Cockpit_DBusTests_Frobber()), + bus.add_object('/bork', borkety_Bork()) + ] + + yield + + for slot in slots: + slot.cancel() + + +@contextlib.asynccontextmanager +async def well_known_name(bus: systemd_ctypes.Bus, name: str, flags: int = 0) -> AsyncIterator[None]: + result, = await bus.call_method_async( + 'org.freedesktop.DBus', '/org/freedesktop/DBus', 'org.freedesktop.DBus', 'RequestName', 'su', name, flags + ) + if result != 1: + raise RuntimeError(f'Cannot register name {name}: {result}') + + try: + yield + + finally: + result, = await bus.call_method_async( + 'org.freedesktop.DBus', '/org/freedesktop/DBus', 'org.freedesktop.DBus', 'ReleaseName', 's', name + ) + if result != 1: + raise RuntimeError(f'Cannot release name {name}: {result}') + + +@contextlib.asynccontextmanager +async def mock_dbus_service_on_user_bus() -> AsyncIterator[None]: + user = systemd_ctypes.Bus.default_user() + async with ( + well_known_name(user, 'com.redhat.Cockpit.DBusTests.Test'), + well_known_name(user, 'com.redhat.Cockpit.DBusTests.Second'), + ): + with mock_service_export(user): + yield + + +async def main(): + async with mock_dbus_service_on_user_bus(): + print('Mock service running. Ctrl+C to exit.') + await asyncio.sleep(2 << 30) # "a long time." + + +if __name__ == '__main__': + systemd_ctypes.run_async(main()) diff --git a/test/pytest/mockwebserver.py b/test/pytest/mockwebserver.py new file mode 100644 index 000000000000..7f55cb3e77f6 --- /dev/null +++ b/test/pytest/mockwebserver.py @@ -0,0 +1,466 @@ +# This file is part of Cockpit. +# +# Copyright (C) 2024 Red Hat, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# https://github.com/astral-sh/ruff/issues/10980#issuecomment-2219615329 +# ruff: noqa: RUF029 + +import argparse +import asyncio +import binascii +import contextlib +import json +import logging +import os +import socket +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping +from pathlib import Path +from typing import ClassVar, NamedTuple, Self + +import aiohttp +from aiohttp import web + +from cockpit._vendor import systemd_ctypes +from cockpit.bridge import Bridge +from cockpit.jsonutil import JsonObject, JsonValue, create_object, get_enum, get_int, get_str, get_str_map +from cockpit.protocol import CockpitProblem, CockpitProtocolError + +from .mockdbusservice import mock_dbus_service_on_user_bus + + +class TextChannelOrigin(NamedTuple): + enqueue: Callable[[str | None], None] + + +class BinaryChannelOrigin(NamedTuple): + enqueue: Callable[[str | bytes | None], None] + + +class ExternalChannelOrigin(NamedTuple): + enqueue: Callable[[JsonObject | bytes], None] + + +ChannelOrigin = TextChannelOrigin | BinaryChannelOrigin | ExternalChannelOrigin + + +class MultiplexTransport(asyncio.Transport): + transports: ClassVar[dict[str, Self]] = {} + + def __init__(self, protocol: asyncio.Protocol, origin: ChannelOrigin): + self.csrf_token = 'hunter2' + self.origins: dict[str | None, ChannelOrigin] = {None: origin} + self.channel_sequence = 0 + self.protocol = protocol + self.protocol.connection_made(self) + + MultiplexTransport.transports[self.csrf_token] = self + + def write(self, data: bytes) -> None: + # We know that cockpit.protocol always writes complete frames + header, _, frame = data.partition(b'\n') + assert int(header) == len(frame) + + channel_id, _, body = frame.partition(b'\n') + if channel_id: + # data message on the named channel + origin = self.origins.get(channel_id.decode()) + match origin: + case BinaryChannelOrigin(enqueue): + enqueue(frame) + case TextChannelOrigin(enqueue): + enqueue(frame.decode()) + case ExternalChannelOrigin(enqueue): + enqueue(body) + + else: + # control message (channel=None for transport control) + message = json.loads(body) + channel = get_str(message, 'channel', None) + origin = self.origins.get(channel) + + match origin: + case BinaryChannelOrigin(enqueue) | TextChannelOrigin(enqueue): + enqueue(frame.decode()) + case ExternalChannelOrigin(enqueue): + enqueue(message) + + print(message) + if origin is not None and get_str(message, 'command') == 'close': + del self.origins[channel] + + def register_origin(self, origin: ChannelOrigin, channel: str | None = None) -> str: + # normal channels get their IDs allocated in cockpit.js + + if channel is None: + # external channels get their IDs allocated by us + channel = f'external{self.channel_sequence}' + self.channel_sequence += 1 + + self.origins[channel] = origin + + return channel + + def data_received(self, data: bytes) -> None: + # cockpit.protocol expects a frame length header + header = f'{len(data)}\n'.encode() + self.protocol.data_received(header + data) + + def control_received(self, message: JsonObject) -> None: + self.data_received(b'\n' + json.dumps(message).encode()) + + def close(self) -> None: + pass + + +class CockpitWebSocket(web.WebSocketResponse): + def __init__(self): + self.outgoing_queue = asyncio.Queue[str | bytes | None]() + super().__init__(protocols=['cockpit1']) + + async def send_control(self, _msg: JsonObject | None = None, **kwargs: JsonValue) -> None: + await self.send_str('\n' + json.dumps(create_object(_msg, kwargs))) + + async def process_outgoing_queue(self, queue: asyncio.Queue[str | bytes | None]) -> None: + while True: + item = await queue.get() + if isinstance(item, str): + await self.send_str(item) + elif isinstance(item, bytes): + await self.send_bytes(item) + else: + break + + async def communicate(self, request: web.Request) -> None: + text_origin = TextChannelOrigin(self.outgoing_queue.put_nowait) + binary_origin = BinaryChannelOrigin(self.outgoing_queue.put_nowait) + + try: + bridge = Bridge(argparse.Namespace(privileged=False, beipack=False)) + transport = MultiplexTransport(bridge, text_origin) + + # wait for the bridge to send its "init" + bridge_init = await self.outgoing_queue.get() + del bridge_init + + # send our "init" to the websocket + await self.prepare(request) + await self.send_control( + command='init', version=1, host='localhost', + channel_seed='test-server', csrf_token=transport.csrf_token, + capabilities=['multi', 'credentials', 'binary'], + system={'version': '0'} + ) + + # receive "init" from the websocket + try: + assert await self.receive_json() == {'command': 'init', 'version': 1} + except (TypeError, json.JSONDecodeError, AssertionError) as exc: + raise CockpitProtocolError('expected init message') from exc + + # send "init" to the bridge + # TODO: explicit-superuser handling + transport.data_received(b'\n' + json.dumps({ + "command": "init", + "version": 1, + "host": "localhost" + }).encode()) + + write_task = asyncio.create_task(self.process_outgoing_queue(self.outgoing_queue)) + + try: + async for msg in self: + if msg.type == aiohttp.WSMsgType.TEXT: + frame = msg.data + if frame.startswith('\n'): + control = json.loads(frame) + command = get_str(control, 'command') + channel = get_str(control, 'channel', None) + if command == 'open': + if channel is None: + raise CockpitProtocolError('open message without channel') + binary = get_enum(control, 'binary', ['raw'], None) == 'raw' + transport.register_origin(binary_origin if binary else text_origin, channel) + transport.data_received(frame.encode()) + elif msg.type == aiohttp.WSMsgType.BINARY: + transport.data_received(msg.data) + else: + raise CockpitProtocolError(f'strange websocket message {msg!s}') + finally: + self.outgoing_queue.put_nowait(None) + await write_task + + except CockpitProblem as exc: + if not self.closed: + await self.send_control(exc.get_attrs(), command='close') + + +routes = web.RouteTableDef() + + +@routes.get(r'/favicon.ico') +async def favicon_ico(request: web.Request) -> web.FileResponse: + del request + return web.FileResponse('src/branding/default/favicon.ico') + + +SPLIT_UTF8_FRAMES = [ + b"initial", + # split an é in the middle + b"first half \xc3", + b"\xa9 second half", + b"final" +] + + +@routes.get(r'/mock/expect-warnings') +@routes.get(r'/mock/dont-expect-warnings') +async def mock_expect_warnings(_request: web.Request) -> web.Response: + # no op — only for compatibility with C test-server + return web.Response(status=200, text='OK') + + +@routes.get(r'/mock/info') +async def mock_info(_request: web.Request) -> web.Response: + return web.json_response({ + 'pybridge': True, + 'skip_slow_tests': 'COCKPIT_SKIP_SLOW_TESTS' in os.environ + }) + + +@routes.get(r'/mock/stream') +async def mock_stream(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + + for i in range(10): + await response.write(f'{i} '.encode()) + + return response + + +@routes.get(r'/mock/split-utf8') +async def mock_split_utf8(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + + for chunk in SPLIT_UTF8_FRAMES: + await response.write(chunk) + + return response + + +@routes.get(r'/mock/truncated-utf8') +async def mock_truncated_utf8(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + + for chunk in SPLIT_UTF8_FRAMES[0:2]: + await response.write(chunk) + + return response + + +@routes.get(r'/mock/headers') +async def mock_headers(request: web.Request) -> web.Response: + headers = {k: v for k, v in request.headers.items() if k.startswith('Header')} + headers['Header3'] = 'three' + headers['Header4'] = 'marmalade' + + return web.Response(status=201, text='Yoo Hoo', headers=headers) + + +@routes.get(r'/mock/host') +async def mock_host(request: web.Request) -> web.Response: + return web.Response(status=201, text='Yoo Hoo', headers={'Host': request.headers['Host']}) + + +@routes.get(r'/mock/headonly') +async def mock_headonly(request: web.Request) -> web.Response: + if request.method != 'HEAD': + return web.Response(status=400, reason="Only HEAD allowed on this path") + + input_data = request.headers.get('InputData') + if not input_data: + return web.Response(status=400, reason="Requires InputData header") + + return web.Response(status=200, text='OK', headers={'InputDataLength': str(len(input_data))}) + + +@routes.get(r'/mock/qs') +async def mock_qs(request: web.Request) -> web.Response: + return web.Response(text=request.query_string.replace(' ', '+')) + + +@routes.get(r'/cockpit/channel/{csrf_token}') +async def cockpit_channel(request: web.Request) -> web.StreamResponse: + try: + transport = MultiplexTransport.transports[request.match_info['csrf_token']] + except KeyError: + return web.Response(status=404) + + # Decode the request + try: + options = json.loads(binascii.a2b_base64(request.query_string)) + except (json.JSONDecodeError, binascii.Error) as exc: + return web.Response(status=400, reason=f'Invalid query string {exc!s}') + + binary = get_enum(options, 'binary', ['raw'], None) == 'raw' + websocket = request.headers.get('Upgrade', '').lower() == 'websocket' + + # Open the channel, requesting data send to our queue + queue = asyncio.Queue[JsonObject | bytes]() + channel = transport.register_origin(ExternalChannelOrigin(queue.put_nowait)) + transport.control_received({**options, 'command': 'open', 'channel': channel, 'flow-control': True}) + + # The first thing the channel sends back will be 'ready' or 'close' + open_result = await queue.get() + assert isinstance(open_result, Mapping) + if get_str(open_result, 'command') != 'ready': + return web.json_response(open_result, status=400, reason='Failed to open channel') + + # Start streaming the result. + if websocket: + response = web.WebSocketResponse() + await response.prepare(request) + + else: + # Send the 'external' field back as the HTTP headers... + headers = {**get_str_map(options, 'external', {})} + + if 'Content-Type' not in headers: + headers['Content-Type'] = 'application/octet-stream' if binary else 'text/plain' + + # ...plus this, if we have it. + if size_hint := get_int(open_result, 'size-hint', None): + headers['Content-Length'] = f'{size_hint}' + + response = web.StreamResponse(status=200, headers=headers) + await response.prepare(request) + + # Now, handle the data we receive + while item := await queue.get(): + match item: + case Mapping(): + match get_str(item, 'command'): + case 'ping': + transport.control_received({**item, 'command': 'pong'}) + case 'close' | 'done': + break + + case bytes(): + await response.write(item) + + return response + + +@routes.get(r'/cockpit/socket') +async def cockpit_socket(request: web.Request) -> web.WebSocketResponse: + ws = CockpitWebSocket() + await ws.communicate(request) + return ws + + +@routes.get('/') +async def index(_request: web.Request) -> web.Response: + cases = Path('qunit').rglob('test-*.html') + + result = ( + """ + + + Test cases + + + + + + """ + ) + + return web.Response(text=result, content_type='text/html') + + +@routes.get(r'/{name:(pkg|dist|qunit)/.+}') +async def serve_file(request: web.Request) -> web.FileResponse: + path = Path('.') / request.match_info['name'] + return web.FileResponse(path) + + +COMMON_HEADERS = { + "Cross-Origin-Resource-Policy": "same-origin", + "Referrer-Policy": "no-referrer", + "X-Content-Type-Options": "nosniff", + "X-DNS-Prefetch-Control": "off", + "X-Frame-Options": "sameorigin", +} + + +@web.middleware +async def cockpit_middleware( + request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] +) -> web.StreamResponse: + try: + response = await handler(request) + except web.HTTPException as ex: + response = web.Response( + status=ex.status, reason=ex.reason, text=f'

{ex.reason}

', content_type='text/html' + ) + + response.headers.update(COMMON_HEADERS) + return response + + +@contextlib.asynccontextmanager +async def mock_webserver(addr: str = '127.0.0.1', port: int = 0) -> AsyncIterator[str]: + async with mock_dbus_service_on_user_bus(): + app = web.Application(middlewares=[cockpit_middleware]) + app.add_routes(routes) + + runner = web.AppRunner(app) + await runner.setup() + + listener = socket.socket() + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listener.bind((addr, port)) + listener.listen() + site = web.SockSite(runner, listener) + await site.start() + + addr, port = listener.getsockname() + yield f'http://{addr}:{port}/' + + await runner.cleanup() + + +async def main() -> None: + parser = argparse.ArgumentParser(description="Serve a single git repository via HTTP") + parser.add_argument('--addr', '-a', default='127.0.0.1', help="Address to bind to") + parser.add_argument('--port', '-p', type=int, default=8080, help="Port number to bind to") + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG) + + async with mock_webserver(args.addr, args.port) as url: + print(f"\n {url}\n\nCtrl+C to exit.") + await asyncio.sleep(1000000) + + +if __name__ == '__main__': + systemd_ctypes.run_async(main()) diff --git a/test/pytest/test_browser.py b/test/pytest/test_browser.py index 38dc0e904d4c..53ef9309cdd6 100644 --- a/test/pytest/test_browser.py +++ b/test/pytest/test_browser.py @@ -1,10 +1,12 @@ import glob import os -import subprocess -import sys -from typing import Iterable +from typing import AsyncIterator, Iterable import pytest +import pytest_asyncio + +from .mockwebserver import mock_webserver +from .webdriver_bidi import BrowsingContext, WebdriverDriver, WebdriverSession SRCDIR = os.path.realpath(f'{__file__}/../../..') BUILDDIR = os.environ.get('abs_builddir', SRCDIR) @@ -28,40 +30,79 @@ def glob_py310(fnmatch: str, *, root_dir: str, recursive: bool = False) -> Itera yield result[prefixlen:] +@pytest_asyncio.fixture +async def browsing_session() -> AsyncIterator[WebdriverSession]: + async with WebdriverDriver.connect() as driver: + print('driver', driver) + async with driver.start_session() as session: + print('session', session) + yield session + + +@pytest_asyncio.fixture +async def tab(browsing_session: WebdriverSession) -> AsyncIterator[BrowsingContext]: + async with browsing_session.create_context() as context: + print('context', context) + yield context + + +@pytest.mark.asyncio @pytest.mark.parametrize('html', glob_py310('**/test-*.html', root_dir=f'{SRCDIR}/qunit', recursive=True)) -def test_browser(html): - if not os.path.exists(f'{BUILDDIR}/test-server'): - pytest.skip('no test-server') +async def test_browser(tab: BrowsingContext, html: str) -> None: if html in SKIP: pytest.skip() elif html in XFAIL: pytest.xfail() - if 'COVERAGE_RCFILE' in os.environ: - coverage = ['coverage', 'run', '--parallel-mode', '--module'] - else: - coverage = [] + print('TAB', tab) + + async with mock_webserver() as url: + log = await tab.session.subscribe_console() + await tab.navigate(f'{url}qunit/{html}') + + ignore_resource_errors = False + error_message = None - # Merge 2>&1 so that pytest displays an interleaved log - subprocess.run(['test/common/tap-cdp', f'{BUILDDIR}/test-server', - sys.executable, '-m', *coverage, 'cockpit.bridge', '--debug', - f'./qunit/{html}'], check=True, stderr=subprocess.STDOUT) + async for message in log: + if message.type == 'console': + print('LOG', message.text) + + if message.text == 'cockpittest-tap-done': + break + elif message.text == 'cockpittest-tap-error': + error_message = message.text + break + elif message.text == 'cockpittest-tap-expect-resource-error': + ignore_resource_errors = True + continue + elif message.text.startswith('not ok'): + error_message = message.text + + elif message.type == 'warning': + print('WARNING', message.text) + + else: + print('OTHER', message.type, message.args, message.text) + + # fail on browser level errors + if ignore_resource_errors and "Failed to load resource" in message.text: + continue + + error_message = message.text + break + + if error_message is not None: + pytest.fail(f'Test failed: {error_message}') # run test-timeformat.ts in different time zones: west/UTC/east +@pytest.mark.asyncio @pytest.mark.parametrize('tz', ['America/Toronto', 'Europe/London', 'UTC', 'Europe/Berlin', 'Australia/Sydney']) -def test_timeformat_timezones(tz): - if not os.path.exists(f'{BUILDDIR}/test-server'): - pytest.skip('no test-server') +async def test_timeformat_timezones(tz: str, monkeypatch: pytest.MonkeyPatch) -> None: # this doesn't get built in rpm/deb package build environments, similar to test_browser() built_test = './qunit/base1/test-timeformat.html' if not os.path.exists(built_test): pytest.skip(f'{built_test} not found') - env = os.environ.copy() - env['TZ'] = tz - - # Merge 2>&1 so that pytest displays an interleaved log - subprocess.run(['test/common/tap-cdp', f'{BUILDDIR}/test-server', - sys.executable, '-m', 'cockpit.bridge', '--debug', - built_test], check=True, stderr=subprocess.STDOUT, env=env) + monkeypatch.setenv('TZ', tz) + await test_browser('base1/test-timeformat.html') diff --git a/test/pytest/webdriver_bidi.py b/test/pytest/webdriver_bidi.py new file mode 100644 index 000000000000..83f2b2990a7d --- /dev/null +++ b/test/pytest/webdriver_bidi.py @@ -0,0 +1,307 @@ +# This file is part of Cockpit. +# +# Copyright (C) 2024 Red Hat, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import socket +from typing import AsyncIterator, Callable, Generic, Mapping, TypeVar, override + +import aiohttp +from yarl import URL + +from cockpit.jsonutil import ( + JsonObject, + JsonValue, + create_object, + get_dict, + get_int, + get_object, + get_objv, + get_str, + typechecked, +) + +logger = logging.getLogger(__name__) + + +class WebdriverError(RuntimeError): + pass + + +class ConsoleMessage: + def __init__(self, value: JsonObject): + print('SSSS', value) + self.level = get_str(value, 'level') + self.type = get_str(value, 'type') + self.timestamp = get_int(value, 'timestamp') + self.args = get_objv(value, 'args', dict) + self.text = get_str(value, 'text') + + +# Return a port number that was free at the time of checking +# It might be in use again by the time the function returns... +def pick_a_port() -> int: + sock = socket.socket() + try: + sock.bind(('127.0.0.1', 0)) + _ip, port = sock.getsockname() + return port + finally: + sock.close() + + +T = TypeVar('T') + + +class EventHandler: + def handle_event(self, msg: JsonObject) -> None: + raise NotImplementedError + + def eof(self) -> None: + raise NotImplementedError + + +class EventLog(Generic[T], EventHandler): + def __init__(self, ctor: Callable[[JsonObject], T]): + self.queue = asyncio.Queue[T | None]() + self.ctor = ctor + + @override + def handle_event(self, msg: JsonObject) -> None: + self.queue.put_nowait(get_object(msg, 'params', self.ctor)) + + @override + def eof(self) -> None: + self.queue.put_nowait(None) + + async def __aiter__(self) -> AsyncIterator[T]: + while True: + entry = await self.queue.get() + if entry is None: + return + yield entry + + +class BrowsingContext: + def __init__(self, session: WebdriverSession, context: str): + self.session = session + self.context = context + + async def command(self, command: str, **kwargs: JsonValue) -> JsonObject: + return await self.session.command(f'browsingContext.{command}', context=self.context, **kwargs) + + async def navigate(self, url: str, **kwargs: JsonValue) -> JsonObject: + return await self.command('navigate', url=url, **kwargs) + + async def evaluate(self, expression: str, /, *, await_promise: bool = False, **kwargs: JsonValue) -> JsonObject: + return await self.command( + "script.evaluate", + expression=expression, + awaitPromise=await_promise, + target={'context': self.context}, + **kwargs + ) + + +class WebdriverSession: + def __init__(self, ws: aiohttp.ClientWebSocketResponse): + self.ws = ws + self.pending_commands = dict[int, asyncio.Future[JsonValue]]() + self.events = dict[str, EventHandler]() + self.last_tag = 1 + + def get_tag(self): + self.last_tag += 1 + return self.last_tag + + async def command(self, method, _params: JsonObject | None = None, /, **kwargs: JsonValue) -> JsonObject: + msg = {'id': self.get_tag(), 'method': method, 'params': create_object(_params, kwargs)} + logger.debug("ws ← %r", msg) + await self.ws.send_json(msg) + future = asyncio.get_running_loop().create_future() + self.pending_commands[msg['id']] = future + return await future + + async def subscribe_event(self, name: str, ctor: Callable[[JsonObject], T]) -> EventLog[T]: + await self.command('session.subscribe', events=[name]) + log = EventLog(ctor) + self.events[name] = log + return log + + async def subscribe_console(self) -> EventLog[ConsoleMessage]: + return await self.subscribe_event('log.entryAdded', ConsoleMessage) + + @contextlib.asynccontextmanager + async def create_context(self, context_type: str = 'tab', **kwargs: JsonValue) -> AsyncIterator[BrowsingContext]: + reply = await self.command("browsingContext.create", type=context_type, **kwargs) + context = get_str(reply, 'context') + + yield BrowsingContext(self, context) + + # TODO: tear down context + + def reader_task_done(self, task: asyncio.Task[None]) -> None: + exc = task.exception() or EOFError + + for future in self.pending_commands.values(): + future.set_exception(exc) + self.pending_commands.clear() + + for handler in self.events.values(): + handler.eof() + self.events.clear() + + async def reader_task(self) -> None: + logger.debug('reader_task(%r)', self) + + async for ws_msg in self.ws: + logger.debug(' reader_task(%r) got %r', self, ws_msg) + + if ws_msg.type == aiohttp.WSMsgType.TEXT: + logger.debug("ws TEXT → %r", ws_msg) + msg = typechecked(ws_msg.json(), Mapping) + logger.debug("ws TEXT → %r", msg) + + msg_type = get_str(msg, 'type') + msg_id = get_int(msg, 'id', None) + + if msg_id is not None: + try: + pending = self.pending_commands.pop(msg_id) + except KeyError: + logger.warning('Received non-pending command response %r', msg) + continue + + logger.debug("ws_reader: resolving pending command %i", msg_id) + if msg_type == 'success': + pending.set_result(msg.get('result')) + else: + pending.set_exception(WebdriverError(f"{msg_type}: {msg['message']}")) + + elif msg_type == 'event': + method = get_str(msg, 'method') + if method in self.events: + self.events[method].handle_event(msg) + else: + logger.warning("ws_reader: unhandled event %r", msg) + + else: + logger.warning("ws_reader: unhandled message %r", msg) + + elif ws_msg.type == aiohttp.WSMsgType.ERROR: + logger.error("BiDi failure: %s", ws_msg) + break + + +class WebdriverDriver: + status: JsonObject | None = None + + def __init__(self, url: URL): + self.url = url + + async def poll_status(self, stdout: asyncio.StreamReader) -> None: + async with aiohttp.ClientSession() as session: + while self.status is None: + logger.debug('polling for status from %r', self.url) + try: + status = await session.get(self.url / 'status') + self.status = await status.json() + logger.debug(' status is %r', self.status) + + except aiohttp.ClientError as exc: + # wait for output and try again + # we don't actually care about the output + # if we don't get any output for a long time, raise + logger.debug(' %s. waiting for more input.', exc) + await asyncio.wait_for(stdout.read(1024), 5.0) + + @contextlib.asynccontextmanager + async def start_session(self) -> AsyncIterator[WebdriverSession]: + async with aiohttp.ClientSession() as client_session: + session_args = {"capabilities": { + "alwaysMatch": { + "webSocketUrl": True, + "goog:chromeOptions": {"binary": "/usr/bin/chromium-browser"}, + } + }} + + logging.debug('requesting new session %s %s', self.url, session_args) + response = await client_session.post(self.url / 'session', json=session_args) + reply = await response.json() + logging.debug(' session created: %r', reply) + session_info = get_dict(reply, 'value') + session_id = get_str(session_info, 'sessionId') + capabilities = get_dict(session_info, 'capabilities') + url = get_str(capabilities, 'webSocketUrl') + + logging.debug('connecting to websocket %s', url) + async with client_session.ws_connect(url) as ws: + logging.debug(' connected %r', ws) + session = WebdriverSession(ws) + reader_task = asyncio.create_task(session.reader_task()) + reader_task.add_done_callback(session.reader_task_done) + try: + yield session + logging.debug('delete session %r', session_id) + await client_session.delete(self.url / 'session' / session_id) + finally: + await reader_task + + @classmethod + @contextlib.asynccontextmanager + async def connect(cls) -> AsyncIterator[WebdriverDriver]: + port = pick_a_port() + url = URL(f'http://127.0.0.1:{port}') + + logger.debug('Trying to spawn driver for port %r', port) + + process = await asyncio.create_subprocess_exec( + 'chromedriver', f'--port={port}', + stdout=asyncio.subprocess.PIPE) + assert process.stdout is not None + logger.debug('webdriver process %r started', process.pid) + + try: + webdriver = cls(url) + await webdriver.poll_status(process.stdout) + yield webdriver + + finally: + logger.debug('killing webdriver process %r', process.pid) + with contextlib.suppress(ProcessLookupError): + process.kill() + logger.debug('waiting for webdriver process %r', process.pid) + await process.wait() + logger.debug('webdriver process finished') + + +async def main(): + logging.basicConfig(level=logging.DEBUG) + + async with WebdriverDriver.connect() as driver: + async with driver.start_session() as session: + async with session.create_context() as context: + await context.navigate('http://127.0.0.1:8080/') + await asyncio.sleep(100) + + +if __name__ == '__main__': + asyncio.run(main())