diff --git a/package.json b/package.json index f23569e6f..a1b53f092 100644 --- a/package.json +++ b/package.json @@ -134,6 +134,7 @@ "devDependencies": { "@commitlint/cli": "^18.4.4", "@commitlint/config-conventional": "^18.4.4", + "@fastify/websocket": "^8.3.1", "@open-draft/test-server": "^0.4.2", "@ossjs/release": "^0.8.0", "@playwright/test": "^1.40.1", @@ -143,6 +144,7 @@ "@types/glob": "^8.1.0", "@types/json-bigint": "^1.0.4", "@types/node": "18.x", + "@types/ws": "^8.5.10", "@typescript-eslint/eslint-plugin": "^5.11.0", "@typescript-eslint/parser": "^5.11.0", "@web/dev-server": "^0.1.38", @@ -158,6 +160,7 @@ "eslint-config-prettier": "^9.1.0", "eslint-plugin-prettier": "^5.1.3", "express": "^4.18.2", + "fastify": "^4.26.0", "fs-extra": "^11.2.0", "fs-teardown": "^0.3.0", "glob": "^10.3.10", @@ -174,7 +177,7 @@ "typescript": "^5.0.2", "undici": "^5.20.0", "url-loader": "^4.1.1", - "vitest": "^0.34.6", + "vitest": "^1.2.2", "vitest-environment-miniflare": "^2.14.1", "webpack": "^5.89.0", "webpack-http-server": "^0.5.0" diff --git a/src/core/handlers/WebSocketHandler.ts b/src/core/handlers/WebSocketHandler.ts index ffb398692..be0d4e136 100644 --- a/src/core/handlers/WebSocketHandler.ts +++ b/src/core/handlers/WebSocketHandler.ts @@ -14,7 +14,7 @@ type WebSocketHandlerParsedResult = { match: Match } -type WebSocketHandlerEventMap = { +export type WebSocketHandlerEventMap = { connection: [ args: { client: WebSocketClientConnection @@ -29,33 +29,14 @@ type WebSocketHandlerIncomingEvent = MessageEvent<{ server: WebSocketServerConnection }> -export const kRun = Symbol('run') +export const kEmitter = Symbol('kEmitter') +export const kRun = Symbol('kRun') export class WebSocketHandler { - public on: ( - event: K, - listener: (...args: WebSocketHandlerEventMap[K]) => void, - ) => void - - public off: ( - event: K, - listener: (...args: WebSocketHandlerEventMap[K]) => void, - ) => void - - public removeAllListeners: ( - event?: K, - ) => void - - protected emitter: Emitter + protected [kEmitter]: Emitter constructor(private readonly url: Path) { - this.emitter = new Emitter() - - // Forward some of the emitter API to the public API - // of the event handler. - this.on = this.emitter.on.bind(this.emitter) - this.off = this.emitter.off.bind(this.emitter) - this.removeAllListeners = this.emitter.removeAllListeners.bind(this.emitter) + this[kEmitter] = new Emitter() } public parse(args: { @@ -95,7 +76,7 @@ export class WebSocketHandler { // Emit the connection event on the handler. // This is what the developer adds listeners for. - this.emitter.emit('connection', { + this[kEmitter].emit('connection', { client: connection.client, server: connection.server, params: parsedResult.match.params || {}, diff --git a/src/core/ws/ws.ts b/src/core/ws/ws.ts index c44a3e469..516c87d16 100644 --- a/src/core/ws/ws.ts +++ b/src/core/ws/ws.ts @@ -1,4 +1,8 @@ -import { WebSocketHandler } from '../handlers/WebSocketHandler' +import { + WebSocketHandler, + kEmitter, + type WebSocketHandlerEventMap, +} from '../handlers/WebSocketHandler' import type { Path } from '../utils/matching/matchRequestUrl' import { webSocketInterceptor } from './webSocketInterceptor' @@ -11,7 +15,23 @@ import { webSocketInterceptor } from './webSocketInterceptor' */ function createWebSocketLinkHandler(url: Path) { webSocketInterceptor.apply() - return new WebSocketHandler(url) + + return { + on( + event: K, + listener: (...args: WebSocketHandlerEventMap[K]) => void, + ): WebSocketHandler { + const handler = new WebSocketHandler(url) + + // The "handleWebSocketEvent" function will invoke + // the "run()" method on the WebSocketHandler. + // If the handler matches, it will emit the "connection" + // event. Attach the user-defined listener to that event. + handler[kEmitter].on(event, listener) + + return handler + }, + } } export const ws = { diff --git a/src/node/SetupServerApi.ts b/src/node/SetupServerApi.ts index 6dd18e1e0..c632c1428 100644 --- a/src/node/SetupServerApi.ts +++ b/src/node/SetupServerApi.ts @@ -7,13 +7,14 @@ import { import { invariant } from 'outvariant' import { SetupApi } from '~/core/SetupApi' import { RequestHandler } from '~/core/handlers/RequestHandler' -import { LifeCycleEventsMap, SharedOptions } from '~/core/sharedOptions' -import { RequiredDeep } from '~/core/typeUtils' +import type { LifeCycleEventsMap, SharedOptions } from '~/core/sharedOptions' +import type { RequiredDeep } from '~/core/typeUtils' import { handleRequest } from '~/core/utils/handleRequest' import { devUtils } from '~/core/utils/internal/devUtils' import { mergeRight } from '~/core/utils/internal/mergeRight' -import { SetupServer } from './glossary' import type { WebSocketHandler } from '~/core/handlers/WebSocketHandler' +import { handleWebSocketEvent } from '~/core/utils/handleWebSocketEvent' +import type { SetupServer } from './glossary' const DEFAULT_LISTEN_OPTIONS: RequiredDeep = { onUnhandledRequest: 'warn', @@ -79,6 +80,9 @@ export class SetupServerApi ) }, ) + + // Handle outgoing WebSocket connections. + handleWebSocketEvent(this.currentHandlers) } public listen(options: Partial = {}): void { diff --git a/test/node/vitest.config.ts b/test/node/vitest.config.ts index 801c1edeb..53f3e1525 100644 --- a/test/node/vitest.config.ts +++ b/test/node/vitest.config.ts @@ -11,6 +11,8 @@ export default defineConfig({ dir: './test/node', globals: true, alias: { + 'vitest-environment-node-websocket': + './test/support/environments/vitest-environment-node-websocket', 'msw/node': path.resolve(LIB_DIR, 'node/index.mjs'), 'msw/native': path.resolve(LIB_DIR, 'native/index.mjs'), 'msw/browser': path.resolve(LIB_DIR, 'browser/index.mjs'), diff --git a/test/node/ws-api/ws.intercept.test.ts b/test/node/ws-api/ws.intercept.test.ts new file mode 100644 index 000000000..70db69c7f --- /dev/null +++ b/test/node/ws-api/ws.intercept.test.ts @@ -0,0 +1,109 @@ +/** + * @vitest-environment node-websocket + */ +import { ws } from 'msw' +import { setupServer } from 'msw/node' +import { WebSocketServer } from '../../support/WebSocketServer' +import { waitFor } from '../../support/waitFor' + +const server = setupServer() +const wsServer = new WebSocketServer() + +const service = ws.link('ws://*') + +beforeAll(async () => { + server.listen() + await wsServer.listen() +}) + +afterEach(() => { + wsServer.closeAllClients() + wsServer.removeAllListeners() +}) + +afterAll(async () => { + server.close() + await wsServer.close() +}) + +it('intercepts outgoing client text message', async () => { + const mockMessageListener = vi.fn() + const realConnectionListener = vi.fn() + + server.use( + service.on('connection', ({ client }) => { + client.addEventListener('message', mockMessageListener) + }), + ) + wsServer.on('connection', realConnectionListener) + + const ws = new WebSocket(wsServer.url) + ws.onopen = () => ws.send('hello') + + await waitFor(() => { + // Must intercept the outgoing client message event. + expect(mockMessageListener).toHaveBeenCalledTimes(1) + + const messageEvent = mockMessageListener.mock.calls[0][0] as MessageEvent + expect(messageEvent.type).toBe('message') + expect(messageEvent.data).toBe('hello') + expect(messageEvent.target).toBe(ws) + + // Must not connect to the actual server by default. + expect(realConnectionListener).not.toHaveBeenCalled() + }) +}) + +it('intercepts outgoing client Blob message', async () => { + const mockMessageListener = vi.fn() + const realConnectionListener = vi.fn() + + server.use( + service.on('connection', ({ client }) => { + client.addEventListener('message', mockMessageListener) + }), + ) + wsServer.on('connection', realConnectionListener) + + const ws = new WebSocket(wsServer.url) + ws.onopen = () => ws.send(new Blob(['hello'])) + + await waitFor(() => { + expect(mockMessageListener).toHaveBeenCalledTimes(1) + + const messageEvent = mockMessageListener.mock.calls[0][0] as MessageEvent + expect(messageEvent.type).toBe('message') + expect(messageEvent.data.size).toBe(5) + expect(messageEvent.target).toEqual(ws) + + // Must not connect to the actual server by default. + expect(realConnectionListener).not.toHaveBeenCalled() + }) +}) + +it('intercepts outgoing client ArrayBuffer message', async () => { + const mockMessageListener = vi.fn() + const realConnectionListener = vi.fn() + + server.use( + service.on('connection', ({ client }) => { + client.addEventListener('message', mockMessageListener) + }), + ) + wsServer.on('connection', realConnectionListener) + + const ws = new WebSocket(wsServer.url) + ws.onopen = () => ws.send(new TextEncoder().encode('hello')) + + await waitFor(() => { + expect(mockMessageListener).toHaveBeenCalledTimes(1) + + const messageEvent = mockMessageListener.mock.calls[0][0] as MessageEvent + expect(messageEvent.type).toBe('message') + expect(messageEvent.data).toEqual(new TextEncoder().encode('hello')) + expect(messageEvent.target).toEqual(ws) + + // Must not connect to the actual server by default. + expect(realConnectionListener).not.toHaveBeenCalled() + }) +}) diff --git a/test/support/WebSocketServer.ts b/test/support/WebSocketServer.ts new file mode 100644 index 000000000..546c14168 --- /dev/null +++ b/test/support/WebSocketServer.ts @@ -0,0 +1,55 @@ +import { invariant } from 'outvariant' +import { Emitter } from 'strict-event-emitter' +import fastify, { FastifyInstance } from 'fastify' +import fastifyWebSocket, { SocketStream } from '@fastify/websocket' + +type FastifySocket = SocketStream['socket'] + +type WebSocketEventMap = { + connection: [client: FastifySocket] +} + +export class WebSocketServer extends Emitter { + private _url?: string + private app: FastifyInstance + private clients: Set + + constructor() { + super() + this.clients = new Set() + + this.app = fastify() + this.app.register(fastifyWebSocket) + this.app.register(async (fastify) => { + fastify.get('/', { websocket: true }, (connection) => { + this.clients.add(connection.socket) + this.emit('connection', connection.socket) + }) + }) + } + + get url(): string { + invariant( + this._url, + 'Failed to get "url" on WebSocketServer: server is not running. Did you forget to "await server.listen()"?', + ) + return this._url + } + + public async listen(): Promise { + const address = await this.app.listen({ port: 0 }) + const url = new URL(address) + url.protocol = url.protocol.replace(/^http/, 'ws') + this._url = url.href + } + + public closeAllClients(): void { + this.clients.forEach((client) => { + client.close() + }) + } + + public async close(): Promise { + return this.app.close() + } +} diff --git a/test/support/environments/vitest-environment-node-websocket.ts b/test/support/environments/vitest-environment-node-websocket.ts new file mode 100644 index 000000000..4fe1b93ad --- /dev/null +++ b/test/support/environments/vitest-environment-node-websocket.ts @@ -0,0 +1,20 @@ +/** + * Node.js environment superset that has a global WebSocket API. + */ +import type { Environment } from 'vitest' +import { builtinEnvironments } from 'vitest/environments' +import { WebSocket } from 'undici' + +export default { + name: 'node-with-websocket', + transformMode: 'ssr', + async setup(global, options) { + const { teardown } = await builtinEnvironments.jsdom.setup(global, options) + + Reflect.set(globalThis, 'WebSocket', WebSocket) + + return { + teardown, + } + }, +}