Skip to content

Commit

Permalink
Merge pull request #880 from aivcec/fix/subscriptions-data-races
Browse files Browse the repository at this point in the history
Fixing data races in subscriptions
  • Loading branch information
designatednerd authored Nov 11, 2019
2 parents f0ceece + 33c3649 commit 08ae91a
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 33 deletions.
12 changes: 12 additions & 0 deletions ApolloWebSocket.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
9F28B6D520720F2F00144A00 /* Apollo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D420720F2F00144A00 /* Apollo.framework */; };
9F28B6D920720FD200144A00 /* ApolloTestSupport.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */; };
9F28B6DB2072101200144A00 /* StarWarsAPI.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6DA2072101200144A00 /* StarWarsAPI.framework */; };
D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */ = {isa = PBXBuildFile; fileRef = D1ACF61B23715AF30042E200 /* Atomic.swift */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand Down Expand Up @@ -73,6 +74,7 @@
9F28B6D420720F2F00144A00 /* Apollo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = Apollo.framework; sourceTree = BUILT_PRODUCTS_DIR; };
9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = ApolloTestSupport.framework; sourceTree = BUILT_PRODUCTS_DIR; };
9F28B6DA2072101200144A00 /* StarWarsAPI.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StarWarsAPI.framework; sourceTree = BUILT_PRODUCTS_DIR; };
D1ACF61B23715AF30042E200 /* Atomic.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Atomic.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -141,6 +143,7 @@
9B1CCDE223611606007C9032 /* WebSocketTask.swift */,
7270746B206D111A00C131F6 /* WebSocketTransport.swift */,
7270746C206D111A00C131F6 /* Info.plist */,
D1ACF61923715AF30042E200 /* Utilities */,
);
name = ApolloWebSocket;
path = Sources/ApolloWebSocket;
Expand Down Expand Up @@ -181,6 +184,14 @@
name = Products;
sourceTree = "<group>";
};
D1ACF61923715AF30042E200 /* Utilities */ = {
isa = PBXGroup;
children = (
D1ACF61B23715AF30042E200 /* Atomic.swift */,
);
path = Utilities;
sourceTree = "<group>";
};
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
Expand Down Expand Up @@ -310,6 +321,7 @@
9B1CCDDF236110C3007C9032 /* WebSocketError.swift in Sources */,
7270746D206D111A00C131F6 /* SplitNetworkTransport.swift in Sources */,
9B1CCDE323611606007C9032 /* WebSocketTask.swift in Sources */,
D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */,
7270746E206D111A00C131F6 /* WebSocketTransport.swift in Sources */,
9B1CCDE123611580007C9032 /* OperationMessage.swift in Sources */,
9B1CCDDB23610CDC007C9032 /* ApolloWebSocket.swift in Sources */,
Expand Down
3 changes: 3 additions & 0 deletions Sources/ApolloWebSocket/ApolloWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ public protocol ApolloWebSocketClient: WebSocketClient {

/// The URLRequest used on connection.
var request: URLRequest { get set }

/// Queue where the callbacks are executed
var callbackQueue: DispatchQueue { get set }
}

// MARK: - WebSocket
Expand Down
36 changes: 36 additions & 0 deletions Sources/ApolloWebSocket/Utilities/Atomic.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import Foundation

class Atomic<T> {
private let lock = NSLock()
private var _value: T

init(_ value: T) {
_value = value
}

var value: T {
get {
lock.lock()
defer { lock.unlock() }

return _value
}
set {
lock.lock()
defer { lock.unlock() }

_value = newValue
}
}
}

extension Atomic where T == Int {

func increment() -> T {
lock.lock()
defer { lock.unlock() }

_value += 1
return _value
}
}
68 changes: 38 additions & 30 deletions Sources/ApolloWebSocket/WebSocketTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ public class WebSocketTransport {
public static var provider: ApolloWebSocketClient.Type = ApolloWebSocket.self
public weak var delegate: WebSocketTransportDelegate?

var reconnect = false
let reconnect: Atomic<Bool> = Atomic(false)
var websocket: ApolloWebSocketClient
var error: Error? = nil
let error: Atomic<Error?> = Atomic(nil)
let serializationFormat = JSONSerializationFormat.self
private let requestCreator: RequestCreator

Expand All @@ -40,10 +40,11 @@ public class WebSocketTransport {

private var subscribers = [String: (Result<JSONObject, Error>) -> Void]()
private var subscriptions : [String: String] = [:]
private let processingQueue = DispatchQueue(label: "com.apollographql.WebSocketTransport")

private let sendOperationIdentifiers: Bool
private let reconnectionInterval: TimeInterval
fileprivate var sequenceNumber = 0
fileprivate let sequenceNumberCounter = Atomic<Int>(0)
fileprivate var reconnected = false

/// NOTE: Setting this won't override immediately if the socket is still connected, only on reconnection.
Expand Down Expand Up @@ -87,6 +88,7 @@ public class WebSocketTransport {
self.websocket.request.setValue(self.clientVersion, forHTTPHeaderField: WebSocketTransport.headerFieldNameClientVersion)
self.websocket.delegate = self
self.websocket.connect()
self.websocket.callbackQueue = processingQueue
}

public func isConnected() -> Bool {
Expand Down Expand Up @@ -174,7 +176,7 @@ public class WebSocketTransport {
}

public func initServer(reconnect: Bool = true) {
self.reconnect = reconnect
self.reconnect.value = reconnect
self.acked = false

if let str = OperationMessage(payload: self.connectingPayload, type: .connectionInit).rawMessage {
Expand All @@ -184,12 +186,17 @@ public class WebSocketTransport {
}

public func closeConnection() {
self.reconnect = false
if let str = OperationMessage(type: .connectionTerminate).rawMessage {
write(str)
self.reconnect.value = false

let str = OperationMessage(type: .connectionTerminate).rawMessage
processingQueue.async {
if let str = str {
self.write(str)
}

self.queue.removeAll()
self.subscriptions.removeAll()
}
self.queue.removeAll()
self.subscriptions.removeAll()
}

private func write(_ str: String, force forced: Bool = false, id: Int? = nil) {
Expand All @@ -213,43 +220,44 @@ public class WebSocketTransport {
websocket.delegate = nil
}

private func nextSequenceNumber() -> Int {
sequenceNumber += 1
return sequenceNumber
}

func sendHelper<Operation: GraphQLOperation>(operation: Operation, resultHandler: @escaping (_ result: Result<JSONObject, Error>) -> Void) -> String? {
let body = requestCreator.requestBody(for: operation, sendOperationIdentifiers: self.sendOperationIdentifiers)
let sequenceNumber = "\(nextSequenceNumber())"
let sequenceNumber = "\(sequenceNumberCounter.increment())"

guard let message = OperationMessage(payload: body, id: sequenceNumber).rawMessage else {
return nil
}

write(message)

processingQueue.async {
self.write(message)

subscribers[sequenceNumber] = resultHandler
if operation.operationType == .subscription {
subscriptions[sequenceNumber] = message
self.subscribers[sequenceNumber] = resultHandler
if operation.operationType == .subscription {
self.subscriptions[sequenceNumber] = message
}
}

return sequenceNumber
}

public func unsubscribe(_ subscriptionId: String) {
if let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage {
write(str)
let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage

processingQueue.async {
if let str = str {
self.write(str)
}
self.subscribers.removeValue(forKey: subscriptionId)
self.subscriptions.removeValue(forKey: subscriptionId)
}
subscribers.removeValue(forKey: subscriptionId)
subscriptions.removeValue(forKey: subscriptionId)
}
}

// MARK: - HTTPNetworkTransport conformance

extension WebSocketTransport: NetworkTransport {
public func send<Operation>(operation: Operation, completionHandler: @escaping (_ result: Result<GraphQLResponse<Operation>,Error>) -> Void) -> Cancellable {
if let error = self.error {
if let error = self.error.value {
completionHandler(.failure(error))
return EmptyCancellable()
}
Expand All @@ -271,7 +279,7 @@ extension WebSocketTransport: NetworkTransport {
extension WebSocketTransport: WebSocketDelegate {

public func websocketDidConnect(socket: WebSocketClient) {
self.error = nil
self.error.value = nil
initServer()
if reconnected {
self.delegate?.webSocketTransportDidReconnect(self)
Expand All @@ -290,16 +298,16 @@ extension WebSocketTransport: WebSocketDelegate {
public func websocketDidDisconnect(socket: WebSocketClient, error: Error?) {
// report any error to all subscribers
if let error = error {
self.error = WebSocketError(payload: nil, error: error, kind: .networkError)
self.error.value = WebSocketError(payload: nil, error: error, kind: .networkError)
self.notifyErrorAllHandlers(error)
} else {
self.error = nil
self.error.value = nil
}

self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error)
self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error.value)
acked = false // need new connect and ack before sending

if reconnect {
if reconnect.value {
DispatchQueue.main.asyncAfter(deadline: .now() + reconnectionInterval) {
self.websocket.connect()
}
Expand Down
13 changes: 12 additions & 1 deletion Tests/ApolloWebsocketTests/MockWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import Starscream
@testable import ApolloWebSocket

class MockWebSocket: ApolloWebSocketClient {

var callbackQueue: DispatchQueue = DispatchQueue.main

var pongDelegate: WebSocketPongDelegate?
var request: URLRequest

Expand All @@ -15,8 +18,16 @@ class MockWebSocket: ApolloWebSocketClient {
self.request = URLRequest(url: URL(string: "http://localhost:8080")!)
}

open func reportDidConnect() {
callbackQueue.async {
self.delegate?.websocketDidConnect(socket: self)
}
}

open func write(string: String, completion: (() -> ())?) {
delegate?.websocketDidReceiveMessage(socket: self, text: string)
callbackQueue.async {
self.delegate?.websocketDidReceiveMessage(socket: self, text: string)
}
}

open func write(data: Data, completion: (() -> ())?) {
Expand Down
Loading

0 comments on commit 08ae91a

Please sign in to comment.