diff --git a/pkg/base1/test-dbus-common.js b/pkg/base1/test-dbus-common.js
index cd795fb1dc6b..61038eb9ef67 100644
--- a/pkg/base1/test-dbus-common.js
+++ b/pkg/base1/test-dbus-common.js
@@ -232,7 +232,7 @@ export function common_dbus_tests(channel_options, bus_name) { // eslint-disable
assert.ok(false, "should not be reached");
} catch (ex) {
assert.equal(ex.name, "org.freedesktop.DBus.Error.UnknownMethod", "error name");
- assert.equal(ex.message, "Method UnimplementedMethod is not implemented on interface com.redhat.Cockpit.DBusTests.Frobber", "error message");
+ assert.equal(ex.message, "Unknown method UnimplementedMethod or interface com.redhat.Cockpit.DBusTests.Frobber.", "error message");
}
});
diff --git a/pkg/base1/test-http.js b/pkg/base1/test-http.js
index db508143560e..20fd4188f7fa 100644
--- a/pkg/base1/test-http.js
+++ b/pkg/base1/test-http.js
@@ -210,6 +210,10 @@ QUnit.test("headers", assert => {
.get("/mock/headers", null, { Header1: "booo", Header2: "yay value" })
.response((status, headers) => {
assert.equal(status, 201, "status code");
+
+ delete headers['Content-Type'];
+ delete headers.Date;
+ delete headers.Server;
assert.deepEqual(headers, {
Header1: "booo",
Header2: "yay value",
@@ -248,6 +252,10 @@ QUnit.test("connection headers", assert => {
.get("/mock/headers", null, { Header2: "yay value", Header0: "extra" })
.response((status, headers) => {
assert.equal(status, 201, "status code");
+
+ delete headers['Content-Type'];
+ delete headers.Date;
+ delete headers.Server;
assert.deepEqual(headers, {
Header0: "extra",
Header1: "booo",
diff --git a/src/cockpit/jsonutil.py b/src/cockpit/jsonutil.py
index 7df905c4e6c2..f4e2f1f21b54 100644
--- a/src/cockpit/jsonutil.py
+++ b/src/cockpit/jsonutil.py
@@ -83,6 +83,12 @@ def get_str(obj: JsonObject, key: str, default: Union[DT, _Empty] = _empty) -> U
return _get(obj, lambda v: typechecked(v, str), key, default)
+def get_str_map(obj: JsonObject, key: str, default: DT | _Empty = _empty) -> DT | Mapping[str, str]:
+ def as_str_map(value: JsonValue) -> Mapping[str, str]:
+ return {key: typechecked(value, str) for key, value in typechecked(value, dict).items()}
+ return _get(obj, as_str_map, key, default)
+
+
def get_str_or_none(obj: JsonObject, key: str, default: Optional[str]) -> Optional[str]:
return _get(obj, lambda v: None if v is None else typechecked(v, str), key, default)
diff --git a/test/common/tap-cdp b/test/common/tap-cdp
index 9efe34bfb682..dd0e512b07e6 100755
--- a/test/common/tap-cdp
+++ b/test/common/tap-cdp
@@ -18,9 +18,7 @@
#
import argparse
-import os
import re
-import subprocess
import sys
from cdp import CDP
@@ -28,31 +26,8 @@ from cdp import CDP
tap_line_re = re.compile(r'^(ok [0-9]+|not ok [0-9]+|bail out!|[0-9]+\.\.[0-9]+|# )', re.IGNORECASE)
parser = argparse.ArgumentParser(description="A CDP driver for QUnit which outputs TAP")
-parser.add_argument("server", help="path to the test-server and the test page to run", nargs=argparse.REMAINDER)
-
-# Strip prefix from url
-# We need this to compensate for automake test generation behavior:
-# The tests are called with the path (relative to the build directory) of the testfile,
-# but from the build directory. Some tests make assumptions regarding the structure of the
-# filename. In order to make sure that they receive the same name, regardless of actual
-# build directory location, we need to strip that prefix (path from build to source directory)
-# from the filename
-parser.add_argument("--strip", dest="strip", help="strip prefix from test file paths")
-
-opts = parser.parse_args()
-
-# argparse sometimes forgets to remove this on argparse.REMAINDER args
-if opts.server[0] == '--':
- opts.server = opts.server[1:]
-
-# The test file is the last argument, but 'server' might contain arbitrary
-# amount of options. We cannot express this with argparse, so take it apart
-# manually.
-opts.test = opts.server[-1]
-opts.server = opts.server[:-1]
-
-if opts.strip and opts.test.startswith(opts.strip):
- opts.test = opts.test[len(opts.strip):]
+parser.add_argument("url", help="url to the test to run")
+args = parser.parse_args()
cdp = CDP("C.utf8")
@@ -62,21 +37,7 @@ except SystemError:
print('1..0 # skip web browser not found')
sys.exit(0)
-# pass the address through a separate fd, so that we can see g_debug() messages (which go to stdout)
-(addr_r, addr_w) = os.pipe()
-env = os.environ.copy()
-env["TEST_SERVER_ADDRESS_FD"] = str(addr_w)
-
-server = subprocess.Popen(opts.server,
- stdin=subprocess.DEVNULL,
- pass_fds=(addr_w,),
- close_fds=True,
- env=env)
-os.close(addr_w)
-address = os.read(addr_r, 1000).decode()
-os.close(addr_r)
-
-cdp.invoke("Page.navigate", url=address + '/' + opts.test)
+cdp.invoke("Page.navigate", url=args.url)
success = True
ignore_resource_errors = False
@@ -109,9 +70,6 @@ for t, message in cdp.read_log():
else:
print(message, file=sys.stderr)
-
-server.terminate()
-server.wait()
cdp.kill()
if not success:
diff --git a/test/pytest/mockdbusservice.py b/test/pytest/mockdbusservice.py
new file mode 100644
index 000000000000..86db89d9afa2
--- /dev/null
+++ b/test/pytest/mockdbusservice.py
@@ -0,0 +1,171 @@
+import asyncio
+import contextlib
+import logging
+import math
+from collections.abc import AsyncIterator
+from typing import Iterator
+
+from cockpit._vendor import systemd_ctypes
+
+logger = logging.getLogger(__name__)
+
+
+# No introspection, manual handling of method calls
+class borkety_Bork(systemd_ctypes.bus.BaseObject):
+ def message_received(self, message: systemd_ctypes.bus.BusMessage) -> bool:
+ signature = message.get_signature(True) # noqa:FBT003
+ body = message.get_body()
+ logger.debug('got Bork message: %s %r', signature, body)
+
+ if message.get_member() == 'Echo':
+ message.reply_method_return(signature, *body)
+ return True
+
+ return False
+
+
+class com_redhat_Cockpit_DBusTests_Frobber(systemd_ctypes.bus.Object):
+ finally_normal_name = systemd_ctypes.bus.Interface.Property('s', 'There aint no place like home')
+ readonly_property = systemd_ctypes.bus.Interface.Property('s', 'blah')
+ aay = systemd_ctypes.bus.Interface.Property('aay', [], name='aay')
+ ag = systemd_ctypes.bus.Interface.Property('ag', [], name='ag')
+ ao = systemd_ctypes.bus.Interface.Property('ao', [], name='ao')
+ as_ = systemd_ctypes.bus.Interface.Property('as', [], name='as')
+ ay = systemd_ctypes.bus.Interface.Property('ay', b'ABCabc\0', name='ay')
+ b = systemd_ctypes.bus.Interface.Property('b', value=False, name='b')
+ d = systemd_ctypes.bus.Interface.Property('d', 43, name='d')
+ g = systemd_ctypes.bus.Interface.Property('g', '', name='g')
+ i = systemd_ctypes.bus.Interface.Property('i', 0, name='i')
+ n = systemd_ctypes.bus.Interface.Property('n', 0, name='n')
+ o = systemd_ctypes.bus.Interface.Property('o', '/', name='o')
+ q = systemd_ctypes.bus.Interface.Property('q', 0, name='q')
+ s = systemd_ctypes.bus.Interface.Property('s', '', name='s')
+ t = systemd_ctypes.bus.Interface.Property('t', 0, name='t')
+ u = systemd_ctypes.bus.Interface.Property('u', 0, name='u')
+ x = systemd_ctypes.bus.Interface.Property('x', 0, name='x')
+ y = systemd_ctypes.bus.Interface.Property('y', 42, name='y')
+
+ test_signal = systemd_ctypes.bus.Interface.Signal('i', 'as', 'ao', 'a{s(ii)}')
+
+ @systemd_ctypes.bus.Interface.Method('', 'i')
+ def request_signal_emission(self, which_one: int) -> None:
+ del which_one
+
+ self.test_signal(
+ 43,
+ ['foo', 'frobber'],
+ ['/foo', '/foo/bar'],
+ {'first': (42, 42), 'second': (43, 43)}
+ )
+
+ @systemd_ctypes.bus.Interface.Method('s', 's')
+ def hello_world(self, greeting: str) -> str:
+ return f"Word! You said `{greeting}'. I'm Skeleton, btw!"
+
+ @systemd_ctypes.bus.Interface.Method('', '')
+ async def never_return(self) -> None:
+ await asyncio.sleep(1000000)
+
+ @systemd_ctypes.bus.Interface.Method(
+ ['y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'o', 'g', 'ay'],
+ ['y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'o', 'g', 'ay']
+ )
+ def test_primitive_types(
+ self,
+ val_byte, val_boolean,
+ val_int16, val_uint16, val_int32, val_uint32, val_int64, val_uint64,
+ val_double,
+ val_string, val_objpath, val_signature,
+ val_bytestring
+ ):
+ return [
+ val_byte + 10,
+ not val_boolean,
+ 100 + val_int16,
+ 1000 + val_uint16,
+ 10000 + val_int32,
+ 100000 + val_uint32,
+ 1000000 + val_int64,
+ 10000000 + val_uint64,
+ val_double / math.pi,
+ f"Word! You said `{val_string}'. Rock'n'roll!",
+ f"/modified{val_objpath}",
+ f"assgit{val_signature}",
+ b"bytestring!\xff\0"
+ ]
+
+ @systemd_ctypes.bus.Interface.Method(
+ ['s'],
+ ["a{ss}", "a{s(ii)}", "(iss)", "as", "ao", "ag", "aay"]
+ )
+ def test_non_primitive_types(
+ self,
+ dict_s_to_s,
+ dict_s_to_pairs,
+ a_struct,
+ array_of_strings,
+ array_of_objpaths,
+ array_of_signatures,
+ array_of_bytestrings
+ ):
+ return (
+ f'{dict_s_to_s}{dict_s_to_pairs}{a_struct}'
+ f'array_of_strings: [{", ".join(array_of_strings)}] '
+ f'array_of_objpaths: [{", ".join(array_of_objpaths)}] '
+ f'array_of_signatures: [signature {", ".join(f"'{sig}'" for sig in array_of_signatures)}] '
+ f'array_of_bytestrings: [{", ".join(x[:-1].decode() for x in array_of_bytestrings)}] '
+ )
+
+
+@contextlib.contextmanager
+def mock_service_export(bus: systemd_ctypes.Bus) -> Iterator[None]:
+ slots = [
+ bus.add_object('/otree/frobber', com_redhat_Cockpit_DBusTests_Frobber()),
+ bus.add_object('/otree/different', com_redhat_Cockpit_DBusTests_Frobber()),
+ bus.add_object('/bork', borkety_Bork())
+ ]
+
+ yield
+
+ for slot in slots:
+ slot.cancel()
+
+
+@contextlib.asynccontextmanager
+async def well_known_name(bus: systemd_ctypes.Bus, name: str, flags: int = 0) -> AsyncIterator[None]:
+ result, = await bus.call_method_async(
+ 'org.freedesktop.DBus', '/org/freedesktop/DBus', 'org.freedesktop.DBus', 'RequestName', 'su', name, flags
+ )
+ if result != 1:
+ raise RuntimeError(f'Cannot register name {name}: {result}')
+
+ try:
+ yield
+
+ finally:
+ result, = await bus.call_method_async(
+ 'org.freedesktop.DBus', '/org/freedesktop/DBus', 'org.freedesktop.DBus', 'ReleaseName', 's', name
+ )
+ if result != 1:
+ raise RuntimeError(f'Cannot release name {name}: {result}')
+
+
+@contextlib.asynccontextmanager
+async def mock_dbus_service_on_user_bus() -> AsyncIterator[None]:
+ user = systemd_ctypes.Bus.default_user()
+ async with (
+ well_known_name(user, 'com.redhat.Cockpit.DBusTests.Test'),
+ well_known_name(user, 'com.redhat.Cockpit.DBusTests.Second'),
+ ):
+ with mock_service_export(user):
+ yield
+
+
+async def main():
+ async with mock_dbus_service_on_user_bus():
+ print('Mock service running. Ctrl+C to exit.')
+ await asyncio.sleep(2 << 30) # "a long time."
+
+
+if __name__ == '__main__':
+ systemd_ctypes.run_async(main())
diff --git a/test/pytest/mockwebserver.py b/test/pytest/mockwebserver.py
new file mode 100644
index 000000000000..7f55cb3e77f6
--- /dev/null
+++ b/test/pytest/mockwebserver.py
@@ -0,0 +1,466 @@
+# This file is part of Cockpit.
+#
+# Copyright (C) 2024 Red Hat, Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+# https://github.com/astral-sh/ruff/issues/10980#issuecomment-2219615329
+# ruff: noqa: RUF029
+
+import argparse
+import asyncio
+import binascii
+import contextlib
+import json
+import logging
+import os
+import socket
+from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
+from pathlib import Path
+from typing import ClassVar, NamedTuple, Self
+
+import aiohttp
+from aiohttp import web
+
+from cockpit._vendor import systemd_ctypes
+from cockpit.bridge import Bridge
+from cockpit.jsonutil import JsonObject, JsonValue, create_object, get_enum, get_int, get_str, get_str_map
+from cockpit.protocol import CockpitProblem, CockpitProtocolError
+
+from .mockdbusservice import mock_dbus_service_on_user_bus
+
+
+class TextChannelOrigin(NamedTuple):
+ enqueue: Callable[[str | None], None]
+
+
+class BinaryChannelOrigin(NamedTuple):
+ enqueue: Callable[[str | bytes | None], None]
+
+
+class ExternalChannelOrigin(NamedTuple):
+ enqueue: Callable[[JsonObject | bytes], None]
+
+
+ChannelOrigin = TextChannelOrigin | BinaryChannelOrigin | ExternalChannelOrigin
+
+
+class MultiplexTransport(asyncio.Transport):
+ transports: ClassVar[dict[str, Self]] = {}
+
+ def __init__(self, protocol: asyncio.Protocol, origin: ChannelOrigin):
+ self.csrf_token = 'hunter2'
+ self.origins: dict[str | None, ChannelOrigin] = {None: origin}
+ self.channel_sequence = 0
+ self.protocol = protocol
+ self.protocol.connection_made(self)
+
+ MultiplexTransport.transports[self.csrf_token] = self
+
+ def write(self, data: bytes) -> None:
+ # We know that cockpit.protocol always writes complete frames
+ header, _, frame = data.partition(b'\n')
+ assert int(header) == len(frame)
+
+ channel_id, _, body = frame.partition(b'\n')
+ if channel_id:
+ # data message on the named channel
+ origin = self.origins.get(channel_id.decode())
+ match origin:
+ case BinaryChannelOrigin(enqueue):
+ enqueue(frame)
+ case TextChannelOrigin(enqueue):
+ enqueue(frame.decode())
+ case ExternalChannelOrigin(enqueue):
+ enqueue(body)
+
+ else:
+ # control message (channel=None for transport control)
+ message = json.loads(body)
+ channel = get_str(message, 'channel', None)
+ origin = self.origins.get(channel)
+
+ match origin:
+ case BinaryChannelOrigin(enqueue) | TextChannelOrigin(enqueue):
+ enqueue(frame.decode())
+ case ExternalChannelOrigin(enqueue):
+ enqueue(message)
+
+ print(message)
+ if origin is not None and get_str(message, 'command') == 'close':
+ del self.origins[channel]
+
+ def register_origin(self, origin: ChannelOrigin, channel: str | None = None) -> str:
+ # normal channels get their IDs allocated in cockpit.js
+
+ if channel is None:
+ # external channels get their IDs allocated by us
+ channel = f'external{self.channel_sequence}'
+ self.channel_sequence += 1
+
+ self.origins[channel] = origin
+
+ return channel
+
+ def data_received(self, data: bytes) -> None:
+ # cockpit.protocol expects a frame length header
+ header = f'{len(data)}\n'.encode()
+ self.protocol.data_received(header + data)
+
+ def control_received(self, message: JsonObject) -> None:
+ self.data_received(b'\n' + json.dumps(message).encode())
+
+ def close(self) -> None:
+ pass
+
+
+class CockpitWebSocket(web.WebSocketResponse):
+ def __init__(self):
+ self.outgoing_queue = asyncio.Queue[str | bytes | None]()
+ super().__init__(protocols=['cockpit1'])
+
+ async def send_control(self, _msg: JsonObject | None = None, **kwargs: JsonValue) -> None:
+ await self.send_str('\n' + json.dumps(create_object(_msg, kwargs)))
+
+ async def process_outgoing_queue(self, queue: asyncio.Queue[str | bytes | None]) -> None:
+ while True:
+ item = await queue.get()
+ if isinstance(item, str):
+ await self.send_str(item)
+ elif isinstance(item, bytes):
+ await self.send_bytes(item)
+ else:
+ break
+
+ async def communicate(self, request: web.Request) -> None:
+ text_origin = TextChannelOrigin(self.outgoing_queue.put_nowait)
+ binary_origin = BinaryChannelOrigin(self.outgoing_queue.put_nowait)
+
+ try:
+ bridge = Bridge(argparse.Namespace(privileged=False, beipack=False))
+ transport = MultiplexTransport(bridge, text_origin)
+
+ # wait for the bridge to send its "init"
+ bridge_init = await self.outgoing_queue.get()
+ del bridge_init
+
+ # send our "init" to the websocket
+ await self.prepare(request)
+ await self.send_control(
+ command='init', version=1, host='localhost',
+ channel_seed='test-server', csrf_token=transport.csrf_token,
+ capabilities=['multi', 'credentials', 'binary'],
+ system={'version': '0'}
+ )
+
+ # receive "init" from the websocket
+ try:
+ assert await self.receive_json() == {'command': 'init', 'version': 1}
+ except (TypeError, json.JSONDecodeError, AssertionError) as exc:
+ raise CockpitProtocolError('expected init message') from exc
+
+ # send "init" to the bridge
+ # TODO: explicit-superuser handling
+ transport.data_received(b'\n' + json.dumps({
+ "command": "init",
+ "version": 1,
+ "host": "localhost"
+ }).encode())
+
+ write_task = asyncio.create_task(self.process_outgoing_queue(self.outgoing_queue))
+
+ try:
+ async for msg in self:
+ if msg.type == aiohttp.WSMsgType.TEXT:
+ frame = msg.data
+ if frame.startswith('\n'):
+ control = json.loads(frame)
+ command = get_str(control, 'command')
+ channel = get_str(control, 'channel', None)
+ if command == 'open':
+ if channel is None:
+ raise CockpitProtocolError('open message without channel')
+ binary = get_enum(control, 'binary', ['raw'], None) == 'raw'
+ transport.register_origin(binary_origin if binary else text_origin, channel)
+ transport.data_received(frame.encode())
+ elif msg.type == aiohttp.WSMsgType.BINARY:
+ transport.data_received(msg.data)
+ else:
+ raise CockpitProtocolError(f'strange websocket message {msg!s}')
+ finally:
+ self.outgoing_queue.put_nowait(None)
+ await write_task
+
+ except CockpitProblem as exc:
+ if not self.closed:
+ await self.send_control(exc.get_attrs(), command='close')
+
+
+routes = web.RouteTableDef()
+
+
+@routes.get(r'/favicon.ico')
+async def favicon_ico(request: web.Request) -> web.FileResponse:
+ del request
+ return web.FileResponse('src/branding/default/favicon.ico')
+
+
+SPLIT_UTF8_FRAMES = [
+ b"initial",
+ # split an é in the middle
+ b"first half \xc3",
+ b"\xa9 second half",
+ b"final"
+]
+
+
+@routes.get(r'/mock/expect-warnings')
+@routes.get(r'/mock/dont-expect-warnings')
+async def mock_expect_warnings(_request: web.Request) -> web.Response:
+ # no op — only for compatibility with C test-server
+ return web.Response(status=200, text='OK')
+
+
+@routes.get(r'/mock/info')
+async def mock_info(_request: web.Request) -> web.Response:
+ return web.json_response({
+ 'pybridge': True,
+ 'skip_slow_tests': 'COCKPIT_SKIP_SLOW_TESTS' in os.environ
+ })
+
+
+@routes.get(r'/mock/stream')
+async def mock_stream(request: web.Request) -> web.StreamResponse:
+ response = web.StreamResponse()
+ await response.prepare(request)
+
+ for i in range(10):
+ await response.write(f'{i} '.encode())
+
+ return response
+
+
+@routes.get(r'/mock/split-utf8')
+async def mock_split_utf8(request: web.Request) -> web.StreamResponse:
+ response = web.StreamResponse()
+ await response.prepare(request)
+
+ for chunk in SPLIT_UTF8_FRAMES:
+ await response.write(chunk)
+
+ return response
+
+
+@routes.get(r'/mock/truncated-utf8')
+async def mock_truncated_utf8(request: web.Request) -> web.StreamResponse:
+ response = web.StreamResponse()
+ await response.prepare(request)
+
+ for chunk in SPLIT_UTF8_FRAMES[0:2]:
+ await response.write(chunk)
+
+ return response
+
+
+@routes.get(r'/mock/headers')
+async def mock_headers(request: web.Request) -> web.Response:
+ headers = {k: v for k, v in request.headers.items() if k.startswith('Header')}
+ headers['Header3'] = 'three'
+ headers['Header4'] = 'marmalade'
+
+ return web.Response(status=201, text='Yoo Hoo', headers=headers)
+
+
+@routes.get(r'/mock/host')
+async def mock_host(request: web.Request) -> web.Response:
+ return web.Response(status=201, text='Yoo Hoo', headers={'Host': request.headers['Host']})
+
+
+@routes.get(r'/mock/headonly')
+async def mock_headonly(request: web.Request) -> web.Response:
+ if request.method != 'HEAD':
+ return web.Response(status=400, reason="Only HEAD allowed on this path")
+
+ input_data = request.headers.get('InputData')
+ if not input_data:
+ return web.Response(status=400, reason="Requires InputData header")
+
+ return web.Response(status=200, text='OK', headers={'InputDataLength': str(len(input_data))})
+
+
+@routes.get(r'/mock/qs')
+async def mock_qs(request: web.Request) -> web.Response:
+ return web.Response(text=request.query_string.replace(' ', '+'))
+
+
+@routes.get(r'/cockpit/channel/{csrf_token}')
+async def cockpit_channel(request: web.Request) -> web.StreamResponse:
+ try:
+ transport = MultiplexTransport.transports[request.match_info['csrf_token']]
+ except KeyError:
+ return web.Response(status=404)
+
+ # Decode the request
+ try:
+ options = json.loads(binascii.a2b_base64(request.query_string))
+ except (json.JSONDecodeError, binascii.Error) as exc:
+ return web.Response(status=400, reason=f'Invalid query string {exc!s}')
+
+ binary = get_enum(options, 'binary', ['raw'], None) == 'raw'
+ websocket = request.headers.get('Upgrade', '').lower() == 'websocket'
+
+ # Open the channel, requesting data send to our queue
+ queue = asyncio.Queue[JsonObject | bytes]()
+ channel = transport.register_origin(ExternalChannelOrigin(queue.put_nowait))
+ transport.control_received({**options, 'command': 'open', 'channel': channel, 'flow-control': True})
+
+ # The first thing the channel sends back will be 'ready' or 'close'
+ open_result = await queue.get()
+ assert isinstance(open_result, Mapping)
+ if get_str(open_result, 'command') != 'ready':
+ return web.json_response(open_result, status=400, reason='Failed to open channel')
+
+ # Start streaming the result.
+ if websocket:
+ response = web.WebSocketResponse()
+ await response.prepare(request)
+
+ else:
+ # Send the 'external' field back as the HTTP headers...
+ headers = {**get_str_map(options, 'external', {})}
+
+ if 'Content-Type' not in headers:
+ headers['Content-Type'] = 'application/octet-stream' if binary else 'text/plain'
+
+ # ...plus this, if we have it.
+ if size_hint := get_int(open_result, 'size-hint', None):
+ headers['Content-Length'] = f'{size_hint}'
+
+ response = web.StreamResponse(status=200, headers=headers)
+ await response.prepare(request)
+
+ # Now, handle the data we receive
+ while item := await queue.get():
+ match item:
+ case Mapping():
+ match get_str(item, 'command'):
+ case 'ping':
+ transport.control_received({**item, 'command': 'pong'})
+ case 'close' | 'done':
+ break
+
+ case bytes():
+ await response.write(item)
+
+ return response
+
+
+@routes.get(r'/cockpit/socket')
+async def cockpit_socket(request: web.Request) -> web.WebSocketResponse:
+ ws = CockpitWebSocket()
+ await ws.communicate(request)
+ return ws
+
+
+@routes.get('/')
+async def index(_request: web.Request) -> web.Response:
+ cases = Path('qunit').rglob('test-*.html')
+
+ result = (
+ """
+
+
+ Test cases
+
+
+
+ """ + '\n'.join(
+ f'- {case}
' for case in cases
+ ) + """
+
+
+
+ """
+ )
+
+ return web.Response(text=result, content_type='text/html')
+
+
+@routes.get(r'/{name:(pkg|dist|qunit)/.+}')
+async def serve_file(request: web.Request) -> web.FileResponse:
+ path = Path('.') / request.match_info['name']
+ return web.FileResponse(path)
+
+
+COMMON_HEADERS = {
+ "Cross-Origin-Resource-Policy": "same-origin",
+ "Referrer-Policy": "no-referrer",
+ "X-Content-Type-Options": "nosniff",
+ "X-DNS-Prefetch-Control": "off",
+ "X-Frame-Options": "sameorigin",
+}
+
+
+@web.middleware
+async def cockpit_middleware(
+ request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
+) -> web.StreamResponse:
+ try:
+ response = await handler(request)
+ except web.HTTPException as ex:
+ response = web.Response(
+ status=ex.status, reason=ex.reason, text=f'{ex.reason}
', content_type='text/html'
+ )
+
+ response.headers.update(COMMON_HEADERS)
+ return response
+
+
+@contextlib.asynccontextmanager
+async def mock_webserver(addr: str = '127.0.0.1', port: int = 0) -> AsyncIterator[str]:
+ async with mock_dbus_service_on_user_bus():
+ app = web.Application(middlewares=[cockpit_middleware])
+ app.add_routes(routes)
+
+ runner = web.AppRunner(app)
+ await runner.setup()
+
+ listener = socket.socket()
+ listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ listener.bind((addr, port))
+ listener.listen()
+ site = web.SockSite(runner, listener)
+ await site.start()
+
+ addr, port = listener.getsockname()
+ yield f'http://{addr}:{port}/'
+
+ await runner.cleanup()
+
+
+async def main() -> None:
+ parser = argparse.ArgumentParser(description="Serve a single git repository via HTTP")
+ parser.add_argument('--addr', '-a', default='127.0.0.1', help="Address to bind to")
+ parser.add_argument('--port', '-p', type=int, default=8080, help="Port number to bind to")
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ async with mock_webserver(args.addr, args.port) as url:
+ print(f"\n {url}\n\nCtrl+C to exit.")
+ await asyncio.sleep(1000000)
+
+
+if __name__ == '__main__':
+ systemd_ctypes.run_async(main())
diff --git a/test/pytest/test_browser.py b/test/pytest/test_browser.py
index 38dc0e904d4c..53ef9309cdd6 100644
--- a/test/pytest/test_browser.py
+++ b/test/pytest/test_browser.py
@@ -1,10 +1,12 @@
import glob
import os
-import subprocess
-import sys
-from typing import Iterable
+from typing import AsyncIterator, Iterable
import pytest
+import pytest_asyncio
+
+from .mockwebserver import mock_webserver
+from .webdriver_bidi import BrowsingContext, WebdriverDriver, WebdriverSession
SRCDIR = os.path.realpath(f'{__file__}/../../..')
BUILDDIR = os.environ.get('abs_builddir', SRCDIR)
@@ -28,40 +30,79 @@ def glob_py310(fnmatch: str, *, root_dir: str, recursive: bool = False) -> Itera
yield result[prefixlen:]
+@pytest_asyncio.fixture
+async def browsing_session() -> AsyncIterator[WebdriverSession]:
+ async with WebdriverDriver.connect() as driver:
+ print('driver', driver)
+ async with driver.start_session() as session:
+ print('session', session)
+ yield session
+
+
+@pytest_asyncio.fixture
+async def tab(browsing_session: WebdriverSession) -> AsyncIterator[BrowsingContext]:
+ async with browsing_session.create_context() as context:
+ print('context', context)
+ yield context
+
+
+@pytest.mark.asyncio
@pytest.mark.parametrize('html', glob_py310('**/test-*.html', root_dir=f'{SRCDIR}/qunit', recursive=True))
-def test_browser(html):
- if not os.path.exists(f'{BUILDDIR}/test-server'):
- pytest.skip('no test-server')
+async def test_browser(tab: BrowsingContext, html: str) -> None:
if html in SKIP:
pytest.skip()
elif html in XFAIL:
pytest.xfail()
- if 'COVERAGE_RCFILE' in os.environ:
- coverage = ['coverage', 'run', '--parallel-mode', '--module']
- else:
- coverage = []
+ print('TAB', tab)
+
+ async with mock_webserver() as url:
+ log = await tab.session.subscribe_console()
+ await tab.navigate(f'{url}qunit/{html}')
+
+ ignore_resource_errors = False
+ error_message = None
- # Merge 2>&1 so that pytest displays an interleaved log
- subprocess.run(['test/common/tap-cdp', f'{BUILDDIR}/test-server',
- sys.executable, '-m', *coverage, 'cockpit.bridge', '--debug',
- f'./qunit/{html}'], check=True, stderr=subprocess.STDOUT)
+ async for message in log:
+ if message.type == 'console':
+ print('LOG', message.text)
+
+ if message.text == 'cockpittest-tap-done':
+ break
+ elif message.text == 'cockpittest-tap-error':
+ error_message = message.text
+ break
+ elif message.text == 'cockpittest-tap-expect-resource-error':
+ ignore_resource_errors = True
+ continue
+ elif message.text.startswith('not ok'):
+ error_message = message.text
+
+ elif message.type == 'warning':
+ print('WARNING', message.text)
+
+ else:
+ print('OTHER', message.type, message.args, message.text)
+
+ # fail on browser level errors
+ if ignore_resource_errors and "Failed to load resource" in message.text:
+ continue
+
+ error_message = message.text
+ break
+
+ if error_message is not None:
+ pytest.fail(f'Test failed: {error_message}')
# run test-timeformat.ts in different time zones: west/UTC/east
+@pytest.mark.asyncio
@pytest.mark.parametrize('tz', ['America/Toronto', 'Europe/London', 'UTC', 'Europe/Berlin', 'Australia/Sydney'])
-def test_timeformat_timezones(tz):
- if not os.path.exists(f'{BUILDDIR}/test-server'):
- pytest.skip('no test-server')
+async def test_timeformat_timezones(tz: str, monkeypatch: pytest.MonkeyPatch) -> None:
# this doesn't get built in rpm/deb package build environments, similar to test_browser()
built_test = './qunit/base1/test-timeformat.html'
if not os.path.exists(built_test):
pytest.skip(f'{built_test} not found')
- env = os.environ.copy()
- env['TZ'] = tz
-
- # Merge 2>&1 so that pytest displays an interleaved log
- subprocess.run(['test/common/tap-cdp', f'{BUILDDIR}/test-server',
- sys.executable, '-m', 'cockpit.bridge', '--debug',
- built_test], check=True, stderr=subprocess.STDOUT, env=env)
+ monkeypatch.setenv('TZ', tz)
+ await test_browser('base1/test-timeformat.html')
diff --git a/test/pytest/webdriver_bidi.py b/test/pytest/webdriver_bidi.py
new file mode 100644
index 000000000000..83f2b2990a7d
--- /dev/null
+++ b/test/pytest/webdriver_bidi.py
@@ -0,0 +1,307 @@
+# This file is part of Cockpit.
+#
+# Copyright (C) 2024 Red Hat, Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import logging
+import socket
+from typing import AsyncIterator, Callable, Generic, Mapping, TypeVar, override
+
+import aiohttp
+from yarl import URL
+
+from cockpit.jsonutil import (
+ JsonObject,
+ JsonValue,
+ create_object,
+ get_dict,
+ get_int,
+ get_object,
+ get_objv,
+ get_str,
+ typechecked,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class WebdriverError(RuntimeError):
+ pass
+
+
+class ConsoleMessage:
+ def __init__(self, value: JsonObject):
+ print('SSSS', value)
+ self.level = get_str(value, 'level')
+ self.type = get_str(value, 'type')
+ self.timestamp = get_int(value, 'timestamp')
+ self.args = get_objv(value, 'args', dict)
+ self.text = get_str(value, 'text')
+
+
+# Return a port number that was free at the time of checking
+# It might be in use again by the time the function returns...
+def pick_a_port() -> int:
+ sock = socket.socket()
+ try:
+ sock.bind(('127.0.0.1', 0))
+ _ip, port = sock.getsockname()
+ return port
+ finally:
+ sock.close()
+
+
+T = TypeVar('T')
+
+
+class EventHandler:
+ def handle_event(self, msg: JsonObject) -> None:
+ raise NotImplementedError
+
+ def eof(self) -> None:
+ raise NotImplementedError
+
+
+class EventLog(Generic[T], EventHandler):
+ def __init__(self, ctor: Callable[[JsonObject], T]):
+ self.queue = asyncio.Queue[T | None]()
+ self.ctor = ctor
+
+ @override
+ def handle_event(self, msg: JsonObject) -> None:
+ self.queue.put_nowait(get_object(msg, 'params', self.ctor))
+
+ @override
+ def eof(self) -> None:
+ self.queue.put_nowait(None)
+
+ async def __aiter__(self) -> AsyncIterator[T]:
+ while True:
+ entry = await self.queue.get()
+ if entry is None:
+ return
+ yield entry
+
+
+class BrowsingContext:
+ def __init__(self, session: WebdriverSession, context: str):
+ self.session = session
+ self.context = context
+
+ async def command(self, command: str, **kwargs: JsonValue) -> JsonObject:
+ return await self.session.command(f'browsingContext.{command}', context=self.context, **kwargs)
+
+ async def navigate(self, url: str, **kwargs: JsonValue) -> JsonObject:
+ return await self.command('navigate', url=url, **kwargs)
+
+ async def evaluate(self, expression: str, /, *, await_promise: bool = False, **kwargs: JsonValue) -> JsonObject:
+ return await self.command(
+ "script.evaluate",
+ expression=expression,
+ awaitPromise=await_promise,
+ target={'context': self.context},
+ **kwargs
+ )
+
+
+class WebdriverSession:
+ def __init__(self, ws: aiohttp.ClientWebSocketResponse):
+ self.ws = ws
+ self.pending_commands = dict[int, asyncio.Future[JsonValue]]()
+ self.events = dict[str, EventHandler]()
+ self.last_tag = 1
+
+ def get_tag(self):
+ self.last_tag += 1
+ return self.last_tag
+
+ async def command(self, method, _params: JsonObject | None = None, /, **kwargs: JsonValue) -> JsonObject:
+ msg = {'id': self.get_tag(), 'method': method, 'params': create_object(_params, kwargs)}
+ logger.debug("ws ← %r", msg)
+ await self.ws.send_json(msg)
+ future = asyncio.get_running_loop().create_future()
+ self.pending_commands[msg['id']] = future
+ return await future
+
+ async def subscribe_event(self, name: str, ctor: Callable[[JsonObject], T]) -> EventLog[T]:
+ await self.command('session.subscribe', events=[name])
+ log = EventLog(ctor)
+ self.events[name] = log
+ return log
+
+ async def subscribe_console(self) -> EventLog[ConsoleMessage]:
+ return await self.subscribe_event('log.entryAdded', ConsoleMessage)
+
+ @contextlib.asynccontextmanager
+ async def create_context(self, context_type: str = 'tab', **kwargs: JsonValue) -> AsyncIterator[BrowsingContext]:
+ reply = await self.command("browsingContext.create", type=context_type, **kwargs)
+ context = get_str(reply, 'context')
+
+ yield BrowsingContext(self, context)
+
+ # TODO: tear down context
+
+ def reader_task_done(self, task: asyncio.Task[None]) -> None:
+ exc = task.exception() or EOFError
+
+ for future in self.pending_commands.values():
+ future.set_exception(exc)
+ self.pending_commands.clear()
+
+ for handler in self.events.values():
+ handler.eof()
+ self.events.clear()
+
+ async def reader_task(self) -> None:
+ logger.debug('reader_task(%r)', self)
+
+ async for ws_msg in self.ws:
+ logger.debug(' reader_task(%r) got %r', self, ws_msg)
+
+ if ws_msg.type == aiohttp.WSMsgType.TEXT:
+ logger.debug("ws TEXT → %r", ws_msg)
+ msg = typechecked(ws_msg.json(), Mapping)
+ logger.debug("ws TEXT → %r", msg)
+
+ msg_type = get_str(msg, 'type')
+ msg_id = get_int(msg, 'id', None)
+
+ if msg_id is not None:
+ try:
+ pending = self.pending_commands.pop(msg_id)
+ except KeyError:
+ logger.warning('Received non-pending command response %r', msg)
+ continue
+
+ logger.debug("ws_reader: resolving pending command %i", msg_id)
+ if msg_type == 'success':
+ pending.set_result(msg.get('result'))
+ else:
+ pending.set_exception(WebdriverError(f"{msg_type}: {msg['message']}"))
+
+ elif msg_type == 'event':
+ method = get_str(msg, 'method')
+ if method in self.events:
+ self.events[method].handle_event(msg)
+ else:
+ logger.warning("ws_reader: unhandled event %r", msg)
+
+ else:
+ logger.warning("ws_reader: unhandled message %r", msg)
+
+ elif ws_msg.type == aiohttp.WSMsgType.ERROR:
+ logger.error("BiDi failure: %s", ws_msg)
+ break
+
+
+class WebdriverDriver:
+ status: JsonObject | None = None
+
+ def __init__(self, url: URL):
+ self.url = url
+
+ async def poll_status(self, stdout: asyncio.StreamReader) -> None:
+ async with aiohttp.ClientSession() as session:
+ while self.status is None:
+ logger.debug('polling for status from %r', self.url)
+ try:
+ status = await session.get(self.url / 'status')
+ self.status = await status.json()
+ logger.debug(' status is %r', self.status)
+
+ except aiohttp.ClientError as exc:
+ # wait for output and try again
+ # we don't actually care about the output
+ # if we don't get any output for a long time, raise
+ logger.debug(' %s. waiting for more input.', exc)
+ await asyncio.wait_for(stdout.read(1024), 5.0)
+
+ @contextlib.asynccontextmanager
+ async def start_session(self) -> AsyncIterator[WebdriverSession]:
+ async with aiohttp.ClientSession() as client_session:
+ session_args = {"capabilities": {
+ "alwaysMatch": {
+ "webSocketUrl": True,
+ "goog:chromeOptions": {"binary": "/usr/bin/chromium-browser"},
+ }
+ }}
+
+ logging.debug('requesting new session %s %s', self.url, session_args)
+ response = await client_session.post(self.url / 'session', json=session_args)
+ reply = await response.json()
+ logging.debug(' session created: %r', reply)
+ session_info = get_dict(reply, 'value')
+ session_id = get_str(session_info, 'sessionId')
+ capabilities = get_dict(session_info, 'capabilities')
+ url = get_str(capabilities, 'webSocketUrl')
+
+ logging.debug('connecting to websocket %s', url)
+ async with client_session.ws_connect(url) as ws:
+ logging.debug(' connected %r', ws)
+ session = WebdriverSession(ws)
+ reader_task = asyncio.create_task(session.reader_task())
+ reader_task.add_done_callback(session.reader_task_done)
+ try:
+ yield session
+ logging.debug('delete session %r', session_id)
+ await client_session.delete(self.url / 'session' / session_id)
+ finally:
+ await reader_task
+
+ @classmethod
+ @contextlib.asynccontextmanager
+ async def connect(cls) -> AsyncIterator[WebdriverDriver]:
+ port = pick_a_port()
+ url = URL(f'http://127.0.0.1:{port}')
+
+ logger.debug('Trying to spawn driver for port %r', port)
+
+ process = await asyncio.create_subprocess_exec(
+ 'chromedriver', f'--port={port}',
+ stdout=asyncio.subprocess.PIPE)
+ assert process.stdout is not None
+ logger.debug('webdriver process %r started', process.pid)
+
+ try:
+ webdriver = cls(url)
+ await webdriver.poll_status(process.stdout)
+ yield webdriver
+
+ finally:
+ logger.debug('killing webdriver process %r', process.pid)
+ with contextlib.suppress(ProcessLookupError):
+ process.kill()
+ logger.debug('waiting for webdriver process %r', process.pid)
+ await process.wait()
+ logger.debug('webdriver process finished')
+
+
+async def main():
+ logging.basicConfig(level=logging.DEBUG)
+
+ async with WebdriverDriver.connect() as driver:
+ async with driver.start_session() as session:
+ async with session.create_context() as context:
+ await context.navigate('http://127.0.0.1:8080/')
+ await asyncio.sleep(100)
+
+
+if __name__ == '__main__':
+ asyncio.run(main())