Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

permessage-deflate decompression support in websocket #3263

Merged
merged 20 commits into from
May 20, 2024
1 change: 1 addition & 0 deletions lib/web/fetch/data-url.js
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ module.exports = {
collectAnHTTPQuotedString,
serializeAMimeType,
removeChars,
removeHTTPWhitespace,
minimizeSupportedMimeType,
HTTP_TOKEN_CODEPOINTS,
isomorphicDecode
Expand Down
22 changes: 13 additions & 9 deletions lib/web/websocket/connection.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const {
kReceivedClose,
kResponse
} = require('./symbols')
const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished } = require('./util')
const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished, parseExtensions } = require('./util')
const { channels } = require('../../core/diagnostics')
const { CloseEvent } = require('./events')
const { makeRequest } = require('../fetch/request')
Expand All @@ -31,7 +31,7 @@ try {
* @param {URL} url
* @param {string|string[]} protocols
* @param {import('./websocket').WebSocket} ws
* @param {(response: any) => void} onEstablish
* @param {(response: any, extensions: string[] | undefined) => void} onEstablish
* @param {Partial<import('../../types/websocket').WebSocketInit>} options
*/
function establishWebSocketConnection (url, protocols, client, ws, onEstablish, options) {
Expand Down Expand Up @@ -91,12 +91,11 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
// 9. Let permessageDeflate be a user-agent defined
// "permessage-deflate" extension header value.
// https://github.com/mozilla/gecko-dev/blob/ce78234f5e653a5d3916813ff990f053510227bc/netwerk/protocol/websocket/WebSocketChannel.cpp#L2673
// TODO: enable once permessage-deflate is supported
const permessageDeflate = '' // 'permessage-deflate; 15'
const permessageDeflate = options.node?.['client-extensions'] ?? ''

// 10. Append (`Sec-WebSocket-Extensions`, permessageDeflate) to
// request’s header list.
// request.headersList.append('sec-websocket-extensions', permessageDeflate)
request.headersList.append('sec-websocket-extensions', permessageDeflate)

// 11. Fetch request with useParallelQueue set to true, and
// processResponse given response being these steps:
Expand Down Expand Up @@ -167,10 +166,15 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
// header field to determine which extensions are requested is
// discussed in Section 9.1.)
const secExtension = response.headersList.get('Sec-WebSocket-Extensions')
let extensions

if (secExtension !== null && secExtension !== permessageDeflate) {
failWebsocketConnection(ws, 'Received different permessage-deflate than the one set.')
return
if (secExtension !== null) {
extensions = parseExtensions(secExtension)

if (!extensions.has(permessageDeflate)) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably wrong

failWebsocketConnection(ws, 'Sec-WebSocket-Extensions header does not match.')
return
}
}

// 6. If the response includes a |Sec-WebSocket-Protocol| header field
Expand Down Expand Up @@ -206,7 +210,7 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
})
}

onEstablish(response)
onEstablish(response, extensions)
}
})

Expand Down
71 changes: 71 additions & 0 deletions lib/web/websocket/permessage-deflate.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
'use strict'

const { createInflateRaw, Z_DEFAULT_WINDOWBITS } = require('node:zlib')
const { isValidClientWindowBits } = require('./util')

const tail = Buffer.from([0x00, 0x00, 0xff, 0xff])
const kBuffer = Symbol('kBuffer')
const kLength = Symbol('kLength')

class PerMessageDeflate {
/** @type {import('node:zlib').InflateRaw} */
#inflate

#options = {}

constructor (extensions) {
this.#options.clientNoContextTakeover = extensions.has('client_no_context_takeover')
this.#options.clientMaxWindowBits = extensions.get('client_max_window_bits')
}

decompress (chunk, fin, callback) {
// An endpoint uses the following algorithm to decompress a message.
// 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
// payload of the message.
// 2. Decompress the resulting data using DEFLATE.

if (!this.#inflate) {
let windowBits = Z_DEFAULT_WINDOWBITS

if (this.#options.clientMaxWindowBits) { // empty values default to Z_DEFAULT_WINDOWBITS
if (!isValidClientWindowBits(this.#options.clientMaxWindowBits)) {
callback(new Error('Invalid client_max_window_bits'))
return
}

windowBits = Number.parseInt(this.#options.clientMaxWindowBits)
KhafraDev marked this conversation as resolved.
Show resolved Hide resolved
}

this.#inflate = createInflateRaw({ windowBits })
this.#inflate[kBuffer] = []
this.#inflate[kLength] = 0

this.#inflate.on('data', (data) => {
this.#inflate[kBuffer].push(data)
this.#inflate[kLength] += data.length
})

this.#inflate.on('error', (err) => callback(err))
}

this.#inflate.write(chunk)
if (fin) {
this.#inflate.write(tail)
}

this.#inflate.flush(() => {
const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength])

this.#inflate[kBuffer].length = 0
this.#inflate[kLength] = 0

callback(null, full)

if (fin && this.#options.clientNoContextTakeover) {
this.#inflate.reset()
}
})
}
}

module.exports = { PerMessageDeflate }
62 changes: 48 additions & 14 deletions lib/web/websocket/receiver.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const {
} = require('./util')
const { WebsocketFrameSend } = require('./frame')
const { closeWebSocketConnection } = require('./connection')
const { PerMessageDeflate } = require('./permessage-deflate')

// This code was influenced by ws released under the MIT license.
// Copyright (c) 2011 Einar Otto Stangvik <einaros@gmail.com>
Expand All @@ -33,10 +34,18 @@ class ByteParser extends Writable {
#info = {}
#fragments = []

constructor (ws) {
/** @type {Map<string, PerMessageDeflate>} */
#extensions

constructor (ws, extensions) {
super()

this.ws = ws
this.#extensions = extensions == null ? new Map() : extensions

if (this.#extensions.has('permessage-deflate')) {
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions))
}
}

/**
Expand Down Expand Up @@ -91,7 +100,16 @@ class ByteParser extends Writable {
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
if (rsv1 !== 0 || rsv2 !== 0 || rsv3 !== 0) {
// This document allocates the RSV1 bit of the WebSocket header for
// PMCEs and calls the bit the "Per-Message Compressed" bit. On a
// WebSocket connection where a PMCE is in use, this bit indicates
// whether a message is compressed or not.
if (rsv1 !== 0 && !this.#extensions.has('permessage-deflate')) {
failWebsocketConnection(this.ws, 'Expected RSV1 to be clear.')
return
}

if (rsv2 !== 0 || rsv3 !== 0) {
failWebsocketConnection(this.ws, 'RSV1, RSV2, RSV3 must be clear')
return
}
Expand Down Expand Up @@ -122,7 +140,7 @@ class ByteParser extends Writable {
return
}

if (isContinuationFrame(opcode) && this.#fragments.length === 0) {
if (isContinuationFrame(opcode) && this.#fragments.length === 0 && !this.#info.compressed) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably wrong (?)

failWebsocketConnection(this.ws, 'Unexpected continuation frame')
return
}
Expand All @@ -138,6 +156,7 @@ class ByteParser extends Writable {

if (isTextBinaryFrame(opcode)) {
this.#info.binaryType = opcode
this.#info.compressed = rsv1 !== 0
}

this.#info.opcode = opcode
Expand Down Expand Up @@ -186,16 +205,32 @@ class ByteParser extends Writable {
if (isControlFrame(this.#info.opcode)) {
this.#loop = this.parseControlFrame(body)
} else {
this.#fragments.push(body)

// If the frame is not fragmented, a message has been received.
// If the frame is fragmented, it will terminate with a fin bit set
// and an opcode of 0 (continuation), therefore we handle that when
// parsing continuation frames, not here.
if (!this.#info.fragmented && this.#info.fin) {
const fullMessage = Buffer.concat(this.#fragments)
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
this.#fragments.length = 0
if (!this.#info.compressed) {
this.#fragments.push(body)

// If the frame is not fragmented, a message has been received.
// If the frame is fragmented, it will terminate with a fin bit set
// and an opcode of 0 (continuation), therefore we handle that when
// parsing continuation frames, not here.
if (!this.#info.fragmented && this.#info.fin) {
const fullMessage = Buffer.concat(this.#fragments)
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
this.#fragments.length = 0
}
} else {
this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => {
if (error) {
closeWebSocketConnection(this.ws, 1007, error.message, error.message.length)
return
}

websocketMessageReceived(this.ws, this.#info.binaryType, data)

this.#loop = true
this.run(callback)
})

this.#loop = false
}
}

Expand Down Expand Up @@ -333,7 +368,6 @@ class ByteParser extends Writable {
this.ws[kReadyState] = states.CLOSING
this.ws[kReceivedClose] = true

this.end()
return false
} else if (opcode === opcodes.PING) {
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
Expand Down
46 changes: 45 additions & 1 deletion lib/web/websocket/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const { kReadyState, kController, kResponse, kBinaryType, kWebSocketURL } = requ
const { states, opcodes } = require('./constants')
const { ErrorEvent, createFastMessageEvent } = require('./events')
const { isUtf8 } = require('node:buffer')
const { collectASequenceOfCodePointsFast, removeHTTPWhitespace } = require('../fetch/data-url')

/* globals Blob */

Expand Down Expand Up @@ -234,6 +235,47 @@ function isValidOpcode (opcode) {
return isTextBinaryFrame(opcode) || isContinuationFrame(opcode) || isControlFrame(opcode)
}

/**
* Parses a Sec-WebSocket-Extensions header value.
* @param {string} extensions
* @returns {Map<string, string>}
*/
function parseExtensions (extensions) {
KhafraDev marked this conversation as resolved.
Show resolved Hide resolved
const position = { position: 0 }
const extensionList = new Map()

while (position.position < extensions.length) {
const pair = collectASequenceOfCodePointsFast(';', extensions, position)
const [name, value = ''] = pair.split('=')

extensionList.set(
removeHTTPWhitespace(name, true, false),
removeHTTPWhitespace(value, false, true)
)

position.position++
}

return extensionList
}

/**
* @see https://www.rfc-editor.org/rfc/rfc7692#section-7.1.2.2
* @description "client-max-window-bits = 1*DIGIT"
* @param {string} value
*/
function isValidClientWindowBits (value) {
for (let i = 0; i < value.length; i++) {
const byte = value.charCodeAt(i)

if (byte < 0x30 || byte > 0x39) {
return false
}
}

return true
}

// https://nodejs.org/api/intl.html#detecting-internationalization-support
const hasIntl = typeof process.versions.icu === 'string'
const fatalDecoder = hasIntl ? new TextDecoder('utf-8', { fatal: true }) : undefined
Expand Down Expand Up @@ -265,5 +307,7 @@ module.exports = {
isControlFrame,
isContinuationFrame,
isTextBinaryFrame,
isValidOpcode
isValidOpcode,
parseExtensions,
isValidClientWindowBits
}
21 changes: 17 additions & 4 deletions lib/web/websocket/websocket.js
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class WebSocket extends EventTarget {
protocols,
client,
this,
(response) => this.#onConnectionEstablished(response),
(response, extensions) => this.#onConnectionEstablished(response, extensions),
options
)

Expand Down Expand Up @@ -458,12 +458,12 @@ class WebSocket extends EventTarget {
/**
* @see https://websockets.spec.whatwg.org/#feedback-from-the-protocol
*/
#onConnectionEstablished (response) {
#onConnectionEstablished (response, parsedExtensions) {
// processResponse is called when the "response’s header list has been received and initialized."
// once this happens, the connection is open
this[kResponse] = response

const parser = new ByteParser(this)
const parser = new ByteParser(this, parsedExtensions)
parser.on('drain', onParserDrain)
parser.on('error', onParserError.bind(this))

Expand Down Expand Up @@ -549,6 +549,14 @@ webidl.converters['DOMString or sequence<DOMString>'] = function (V, prefix, arg
return webidl.converters.DOMString(V, prefix, argument)
}

webidl.converters.WebSocketInitNodeOptions = webidl.dictionaryConverter([
KhafraDev marked this conversation as resolved.
Show resolved Hide resolved
{
key: 'client-extensions',
converter: webidl.converters.DOMString,
defaultValue: () => ''
}
])

// This implements the proposal made in https://github.com/whatwg/websockets/issues/42
webidl.converters.WebSocketInit = webidl.dictionaryConverter([
{
Expand All @@ -558,12 +566,17 @@ webidl.converters.WebSocketInit = webidl.dictionaryConverter([
},
{
key: 'dispatcher',
converter: (V) => V,
converter: webidl.converters.any,
defaultValue: () => getGlobalDispatcher()
},
{
key: 'headers',
converter: webidl.nullableConverter(webidl.converters.HeadersInit)
},
{
key: 'node',
converter: webidl.converters.WebSocketInitNodeOptions,
defaultValue: () => ({})
}
])

Expand Down
Loading
Loading