diff --git a/Apollo.xcodeproj/project.pbxproj b/Apollo.xcodeproj/project.pbxproj index 37b6205972..8b94e03651 100644 --- a/Apollo.xcodeproj/project.pbxproj +++ b/Apollo.xcodeproj/project.pbxproj @@ -199,6 +199,7 @@ C35D43C622DDE28D00BCBABE /* a.txt in Resources */ = {isa = PBXBuildFile; fileRef = C304EBD322DDC7B200748F72 /* a.txt */; }; C377CCA922D798BD00572E03 /* GraphQLFile.swift in Sources */ = {isa = PBXBuildFile; fileRef = C377CCA822D798BD00572E03 /* GraphQLFile.swift */; }; C377CCAB22D7992E00572E03 /* MultipartFormData.swift in Sources */ = {isa = PBXBuildFile; fileRef = C377CCAA22D7992E00572E03 /* MultipartFormData.swift */; }; + D90F1AFB2479E57A007A1534 /* WebSocketTransportTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D90F1AF92479DEE5007A1534 /* WebSocketTransportTests.swift */; }; E86D8E05214B32FD0028EFE1 /* JSONTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E86D8E03214B32DA0028EFE1 /* JSONTests.swift */; }; F16D083C21EF6F7300C458B8 /* QueryFromJSONBuildingTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F16D083B21EF6F7300C458B8 /* QueryFromJSONBuildingTests.swift */; }; F82E62E122BCD223000C311B /* AutomaticPersistedQueriesTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F82E62E022BCD223000C311B /* AutomaticPersistedQueriesTests.swift */; }; @@ -615,6 +616,7 @@ C35D43BF22DDD3C100BCBABE /* c.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = c.txt; sourceTree = ""; }; C377CCA822D798BD00572E03 /* GraphQLFile.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GraphQLFile.swift; sourceTree = ""; }; C377CCAA22D7992E00572E03 /* MultipartFormData.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MultipartFormData.swift; sourceTree = ""; }; + D90F1AF92479DEE5007A1534 /* WebSocketTransportTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WebSocketTransportTests.swift; sourceTree = ""; }; E86D8E03214B32DA0028EFE1 /* JSONTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = JSONTests.swift; sourceTree = ""; }; F16D083B21EF6F7300C458B8 /* QueryFromJSONBuildingTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = QueryFromJSONBuildingTests.swift; sourceTree = ""; }; F82E62E022BCD223000C311B /* AutomaticPersistedQueriesTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AutomaticPersistedQueriesTests.swift; sourceTree = ""; }; @@ -852,6 +854,7 @@ 9B7BDA8A23FDE92900ACD198 /* SplitNetworkTransportTests.swift */, 9B7BDA8823FDE92900ACD198 /* StarWarsSubscriptionTests.swift */, 9B7BDA8C23FDE92900ACD198 /* StarWarsWebSocketTests.swift */, + D90F1AF92479DEE5007A1534 /* WebSocketTransportTests.swift */, 9B7BDA8B23FDE92900ACD198 /* Info.plist */, ); path = ApolloWebsocketTests; @@ -1974,6 +1977,7 @@ 9B7BDA9223FDE92A00ACD198 /* StarWarsWebSocketTests.swift in Sources */, 9B7BDA9023FDE92A00ACD198 /* SplitNetworkTransportTests.swift in Sources */, 9B7BDA8E23FDE92A00ACD198 /* StarWarsSubscriptionTests.swift in Sources */, + D90F1AFB2479E57A007A1534 /* WebSocketTransportTests.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/Sources/ApolloWebSocket/WebSocketTransport.swift b/Sources/ApolloWebSocket/WebSocketTransport.swift index 449fdb869c..d73049ccde 100644 --- a/Sources/ApolloWebSocket/WebSocketTransport.swift +++ b/Sources/ApolloWebSocket/WebSocketTransport.swift @@ -299,6 +299,29 @@ public class WebSocketTransport { self.subscriptions.removeValue(forKey: subscriptionId) } } + + public func updateHeaderValues(_ values: [String: String?]) { + for (key, value) in values { + self.websocket.request.setValue(value, forHTTPHeaderField: key) + } + + self.reconnectWebSocket() + } + + public func updateConnectingPayload(_ payload: GraphQLMap) { + self.connectingPayload = payload + self.reconnectWebSocket() + } + + private func reconnectWebSocket() { + let oldReconnectValue = reconnect.value + self.reconnect.value = false + + self.websocket.disconnect() + self.websocket.connect() + + reconnect.value = oldReconnectValue + } } // MARK: - HTTPNetworkTransport conformance diff --git a/Tests/ApolloWebsocketTests/WebSocketTransportTests.swift b/Tests/ApolloWebsocketTests/WebSocketTransportTests.swift new file mode 100644 index 0000000000..3234155443 --- /dev/null +++ b/Tests/ApolloWebsocketTests/WebSocketTransportTests.swift @@ -0,0 +1,66 @@ +import XCTest +import Apollo +import Starscream +@testable import ApolloWebSocket + +class WebSocketTransportTests: XCTestCase { + + private let mockSocketURL = URL(string: "http://localhost/dummy_url")! + private var webSocketTransport: WebSocketTransport! + + func testUpdateHeaderValues() { + var request = URLRequest(url: mockSocketURL) + request.addValue("OldToken", forHTTPHeaderField: "Authorization") + + self.webSocketTransport = WebSocketTransport(request: request) + + self.webSocketTransport.updateHeaderValues(["Authorization": "UpdatedToken"]) + + XCTAssertEqual(self.webSocketTransport.websocket.request.allHTTPHeaderFields?["Authorization"], "UpdatedToken") + } + + func testUpdateConnectingPayload() { + WebSocketTransport.provider = MockWebSocket.self + + self.webSocketTransport = WebSocketTransport(request: URLRequest(url: mockSocketURL), + connectingPayload: ["Authorization": "OldToken"]) + + let mockWebSocketDelegate = MockWebSocketDelegate() + + let mockWebSocket = self.webSocketTransport.websocket as? MockWebSocket + mockWebSocket?.isConnected = true + mockWebSocket?.delegate = mockWebSocketDelegate + + let exp = expectation(description: "Waiting for reconnect") + + mockWebSocketDelegate.didReceiveMessage = { message in + let json = try? JSONSerializationFormat.deserialize(data: message.data(using: .utf8)!) as? JSONObject + guard let payload = json?["payload"] as? JSONObject, (json?["type"] as? String) == "connection_init" else { + return + } + + XCTAssertEqual(payload["Authorization"] as? String, "UpdatedToken") + exp.fulfill() + } + + self.webSocketTransport.updateConnectingPayload(["Authorization": "UpdatedToken"]) + self.webSocketTransport.initServer() + + waitForExpectations(timeout: 3, handler: nil) + } +} + +private final class MockWebSocketDelegate: WebSocketDelegate { + + var didReceiveMessage: ((String) -> Void)? + + func websocketDidConnect(socket: WebSocketClient) { } + + func websocketDidDisconnect(socket: WebSocketClient, error: Error?) { } + + func websocketDidReceiveMessage(socket: WebSocketClient, text: String) { + didReceiveMessage?(text) + } + + func websocketDidReceiveData(socket: WebSocketClient, data: Data) { } +}