diff --git a/Apollo.xcodeproj/project.pbxproj b/Apollo.xcodeproj/project.pbxproj index 91ec166978..1f898c1667 100644 --- a/Apollo.xcodeproj/project.pbxproj +++ b/Apollo.xcodeproj/project.pbxproj @@ -7,6 +7,8 @@ objects = { /* Begin PBXBuildFile section */ + 19E9F6AC26D58A9A003AB80E /* OperationMessageIdCreatorTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19E9F6AA26D58A92003AB80E /* OperationMessageIdCreatorTests.swift */; }; + 19E9F6B526D6BF25003AB80E /* OperationMessageIdCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19E9F6A826D5867E003AB80E /* OperationMessageIdCreator.swift */; }; 54DDB0921EA045870009DD99 /* InMemoryNormalizedCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 54DDB0911EA045870009DD99 /* InMemoryNormalizedCache.swift */; }; 5AC6CA4322AAF7B200B7C94D /* GraphQLHTTPMethod.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5AC6CA4222AAF7B200B7C94D /* GraphQLHTTPMethod.swift */; }; 5BB2C0232380836100774170 /* VersionNumberTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5BB2C0222380836100774170 /* VersionNumberTests.swift */; }; @@ -485,6 +487,8 @@ /* End PBXCopyFilesBuildPhase section */ /* Begin PBXFileReference section */ + 19E9F6A826D5867E003AB80E /* OperationMessageIdCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OperationMessageIdCreator.swift; sourceTree = ""; }; + 19E9F6AA26D58A92003AB80E /* OperationMessageIdCreatorTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OperationMessageIdCreatorTests.swift; sourceTree = ""; }; 54DDB0911EA045870009DD99 /* InMemoryNormalizedCache.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = InMemoryNormalizedCache.swift; sourceTree = ""; }; 5AC6CA4222AAF7B200B7C94D /* GraphQLHTTPMethod.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = GraphQLHTTPMethod.swift; sourceTree = ""; }; 5BB2C0222380836100774170 /* VersionNumberTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VersionNumberTests.swift; sourceTree = ""; }; @@ -1175,6 +1179,7 @@ children = ( E676C11F26CB05F90091215A /* DefaultImplementation */, 9B7BDA9823FDE94C00ACD198 /* WebSocketClient.swift */, + 19E9F6A826D5867E003AB80E /* OperationMessageIdCreator.swift */, 9B7BDA9723FDE94C00ACD198 /* OperationMessage.swift */, 9B7BDA9623FDE94C00ACD198 /* SplitNetworkTransport.swift */, 9B7BDA9423FDE94C00ACD198 /* WebSocketError.swift */, @@ -1724,6 +1729,7 @@ 9B7BDA8A23FDE92900ACD198 /* SplitNetworkTransportTests.swift */, D90F1AF92479DEE5007A1534 /* WebSocketTransportTests.swift */, DE181A3326C5D8D4000C0B9C /* CompressionTests.swift */, + 19E9F6AA26D58A92003AB80E /* OperationMessageIdCreatorTests.swift */, ); path = WebSocket; sourceTree = ""; @@ -2464,6 +2470,7 @@ 9B7BDA9B23FDE94C00ACD198 /* WebSocketError.swift in Sources */, 9B7BDA9D23FDE94C00ACD198 /* SplitNetworkTransport.swift in Sources */, 9B7BDA9E23FDE94C00ACD198 /* OperationMessage.swift in Sources */, + 19E9F6B526D6BF25003AB80E /* OperationMessageIdCreator.swift in Sources */, DE181A3626C5DE4F000C0B9C /* WebSocketStream.swift in Sources */, DE181A3226C5C401000C0B9C /* Compression.swift in Sources */, ); @@ -2651,6 +2658,7 @@ DED45DE9261B96B70086EF63 /* LoadQueryFromStoreTests.swift in Sources */, 9BF6C94325194DE2000D5B93 /* MultipartFormData+Testing.swift in Sources */, DE181A3426C5D8D4000C0B9C /* CompressionTests.swift in Sources */, + 19E9F6AC26D58A9A003AB80E /* OperationMessageIdCreatorTests.swift in Sources */, 9F21735B2568F3E200566121 /* PossiblyDeferredTests.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; diff --git a/Sources/ApolloWebSocket/OperationMessage.swift b/Sources/ApolloWebSocket/OperationMessage.swift index 0b8b308008..d0719262b8 100644 --- a/Sources/ApolloWebSocket/OperationMessage.swift +++ b/Sources/ApolloWebSocket/OperationMessage.swift @@ -12,6 +12,7 @@ final class OperationMessage { case connectionAck = "connection_ack" // Server -> Client case connectionError = "connection_error" // Server -> Client + case startAck = "start_ack" // Server -> Client case connectionKeepAlive = "ka" // Server -> Client case data = "data" // Server -> Client case error = "error" // Server -> Client diff --git a/Sources/ApolloWebSocket/OperationMessageIdCreator.swift b/Sources/ApolloWebSocket/OperationMessageIdCreator.swift new file mode 100644 index 0000000000..a34a0ebaf7 --- /dev/null +++ b/Sources/ApolloWebSocket/OperationMessageIdCreator.swift @@ -0,0 +1,26 @@ +import Foundation +#if !COCOAPODS +import ApolloUtils +#endif + +public protocol OperationMessageIdCreator { + func requestId() -> String +} + +// MARK: - Default Implementation + +public struct ApolloSequencedOperationMessageIdCreator: OperationMessageIdCreator { + private var sequenceNumberCounter = Atomic(0) + + // Internal init methods cannot be used in public methods + public init(startAt sequenceNumber: Int = 1) { + sequenceNumberCounter = Atomic(sequenceNumber) + } + + public func requestId() -> String { + let id = sequenceNumberCounter.value + _ = sequenceNumberCounter.increment() + + return "\(id)" + } +} diff --git a/Sources/ApolloWebSocket/WebSocketTransport.swift b/Sources/ApolloWebSocket/WebSocketTransport.swift index 7f7d643dc1..37dcbebd68 100644 --- a/Sources/ApolloWebSocket/WebSocketTransport.swift +++ b/Sources/ApolloWebSocket/WebSocketTransport.swift @@ -33,7 +33,8 @@ public class WebSocketTransport { let error: Atomic = Atomic(nil) let serializationFormat = JSONSerializationFormat.self private let requestBodyCreator: RequestBodyCreator - + private let operationMessageIdCreator: OperationMessageIdCreator + /// non-private for testing - you should not use this directly enum SocketConnectionState { case disconnected @@ -59,7 +60,6 @@ public class WebSocketTransport { private let sendOperationIdentifiers: Bool private let reconnectionInterval: TimeInterval private let allowSendingDuplicates: Bool - fileprivate let sequenceNumberCounter = Atomic(0) fileprivate var reconnected = false /// - NOTE: Setting this won't override immediately if the socket is still connected, only on reconnection. @@ -100,7 +100,8 @@ public class WebSocketTransport { allowSendingDuplicates: Bool = true, connectOnInit: Bool = true, connectingPayload: GraphQLMap? = [:], - requestBodyCreator: RequestBodyCreator = ApolloRequestBodyCreator()) { + requestBodyCreator: RequestBodyCreator = ApolloRequestBodyCreator(), + operationMessageIdCreator: OperationMessageIdCreator = ApolloSequencedOperationMessageIdCreator()) { self.websocket = websocket self.store = store self.connectingPayload = connectingPayload @@ -109,6 +110,7 @@ public class WebSocketTransport { self.reconnectionInterval = reconnectionInterval self.allowSendingDuplicates = allowSendingDuplicates self.requestBodyCreator = requestBodyCreator + self.operationMessageIdCreator = operationMessageIdCreator self.clientName = clientName self.clientVersion = clientVersion self.connectOnInit = connectOnInit @@ -143,9 +145,7 @@ public class WebSocketTransport { switch messageType { case .data, .error: - if - let id = parseHandler.id, - let responseHandler = subscribers[id] { + if let id = parseHandler.id, let responseHandler = subscribers[id] { if let payload = parseHandler.payload { responseHandler(.success(payload)) } else if let error = parseHandler.error { @@ -178,7 +178,8 @@ public class WebSocketTransport { acked = true writeQueue() - case .connectionKeepAlive: + case .connectionKeepAlive, + .startAck: writeQueue() case .connectionInit, @@ -267,22 +268,21 @@ public class WebSocketTransport { sendOperationIdentifiers: self.sendOperationIdentifiers, sendQueryDocument: true, autoPersistQuery: false) - let sequenceNumber = "\(sequenceNumberCounter.increment())" - - guard let message = OperationMessage(payload: body, id: sequenceNumber).rawMessage else { + let identifier = operationMessageIdCreator.requestId() + guard let message = OperationMessage(payload: body, id: identifier).rawMessage else { return nil } processingQueue.async { self.write(message) - self.subscribers[sequenceNumber] = resultHandler + self.subscribers[identifier] = resultHandler if operation.operationType == .subscription { - self.subscriptions[sequenceNumber] = message + self.subscriptions[identifier] = message } } - return sequenceNumber + return identifier } public func unsubscribe(_ subscriptionId: String) { diff --git a/Tests/ApolloTests/WebSocket/OperationMessageIdCreatorTests.swift b/Tests/ApolloTests/WebSocket/OperationMessageIdCreatorTests.swift new file mode 100644 index 0000000000..6c4381eaeb --- /dev/null +++ b/Tests/ApolloTests/WebSocket/OperationMessageIdCreatorTests.swift @@ -0,0 +1,31 @@ +import XCTest +@testable import ApolloWebSocket +import UploadAPI + +class OperationMessageIdCreatorTests: XCTestCase { + struct CustomOperationMessageIdCreator: OperationMessageIdCreator { + func requestId() -> String { + return "12345678" + } + } + + // MARK: - Tests + + func testOperationMessageIdCreatorWithApolloOperationMessageIdCreator() { + let apolloOperationMessageIdCreator = ApolloSequencedOperationMessageIdCreator(startAt: 5) + + let firstId = apolloOperationMessageIdCreator.requestId() + let secondId = apolloOperationMessageIdCreator.requestId() + + XCTAssertEqual(firstId, "5") + XCTAssertEqual(secondId, "6") + } + + func testOperationMessageIdCreatorWithCustomOperationMessageIdCreator() { + let customOperationMessageIdCreator = CustomOperationMessageIdCreator() + + let id = customOperationMessageIdCreator.requestId() + + XCTAssertEqual(id, "12345678") + } +} diff --git a/Tests/ApolloTests/WebSocket/WebSocketTests.swift b/Tests/ApolloTests/WebSocket/WebSocketTests.swift index 1464153362..a6d610aa99 100644 --- a/Tests/ApolloTests/WebSocket/WebSocketTests.swift +++ b/Tests/ApolloTests/WebSocket/WebSocketTests.swift @@ -18,6 +18,12 @@ class WebSocketTests: XCTestCase { var client: ApolloClient! var websocket: MockWebSocket! + struct CustomOperationMessageIdCreator: OperationMessageIdCreator { + func requestId() -> String { + return "12345678" + } + } + override func setUp() { super.setUp() @@ -122,4 +128,42 @@ class WebSocketTests: XCTestCase { waitForExpectations(timeout: 2, handler: nil) } + + func testSingleSubscriptionWithCustomOperationMessageIdCreator() throws { + let expectation = self.expectation(description: "Single Subscription with Custom Operation Message Id Creator") + + let store = ApolloStore() + let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url)) + networkTransport = WebSocketTransport(websocket: websocket, store: store, operationMessageIdCreator: CustomOperationMessageIdCreator()) + client = ApolloClient(networkTransport: networkTransport!, store: store) + + client.subscribe(subscription: ReviewAddedSubscription()) { result in + defer { expectation.fulfill() } + switch result { + case .success(let graphQLResult): + XCTAssertEqual(graphQLResult.data?.reviewAdded?.stars, 5) + case .failure(let error): + XCTFail("Unexpected error: \(error)") + } + } + + let message : GraphQLMap = [ + "type": "data", + "id": "12345678", // subscribing on id = 12345678 from custom operation id + "payload": [ + "data": [ + "reviewAdded": [ + "__typename": "ReviewAdded", + "episode": "JEDI", + "stars": 5, + "commentary": "A great movie" + ] + ] + ] + ] + + networkTransport.write(message: message) + + waitForExpectations(timeout: 2, handler: nil) + } }