Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing data races in subscriptions #880

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ApolloWebSocket.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
9F28B6D520720F2F00144A00 /* Apollo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D420720F2F00144A00 /* Apollo.framework */; };
9F28B6D920720FD200144A00 /* ApolloTestSupport.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */; };
9F28B6DB2072101200144A00 /* StarWarsAPI.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6DA2072101200144A00 /* StarWarsAPI.framework */; };
D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */ = {isa = PBXBuildFile; fileRef = D1ACF61B23715AF30042E200 /* Atomic.swift */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand Down Expand Up @@ -73,6 +74,7 @@
9F28B6D420720F2F00144A00 /* Apollo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = Apollo.framework; sourceTree = BUILT_PRODUCTS_DIR; };
9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = ApolloTestSupport.framework; sourceTree = BUILT_PRODUCTS_DIR; };
9F28B6DA2072101200144A00 /* StarWarsAPI.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StarWarsAPI.framework; sourceTree = BUILT_PRODUCTS_DIR; };
D1ACF61B23715AF30042E200 /* Atomic.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Atomic.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -141,6 +143,7 @@
9B1CCDE223611606007C9032 /* WebSocketTask.swift */,
7270746B206D111A00C131F6 /* WebSocketTransport.swift */,
7270746C206D111A00C131F6 /* Info.plist */,
D1ACF61923715AF30042E200 /* Utilities */,
);
name = ApolloWebSocket;
path = Sources/ApolloWebSocket;
Expand Down Expand Up @@ -181,6 +184,14 @@
name = Products;
sourceTree = "<group>";
};
D1ACF61923715AF30042E200 /* Utilities */ = {
isa = PBXGroup;
children = (
D1ACF61B23715AF30042E200 /* Atomic.swift */,
);
path = Utilities;
sourceTree = "<group>";
};
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
Expand Down Expand Up @@ -310,6 +321,7 @@
9B1CCDDF236110C3007C9032 /* WebSocketError.swift in Sources */,
7270746D206D111A00C131F6 /* SplitNetworkTransport.swift in Sources */,
9B1CCDE323611606007C9032 /* WebSocketTask.swift in Sources */,
D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */,
7270746E206D111A00C131F6 /* WebSocketTransport.swift in Sources */,
9B1CCDE123611580007C9032 /* OperationMessage.swift in Sources */,
9B1CCDDB23610CDC007C9032 /* ApolloWebSocket.swift in Sources */,
Expand Down
3 changes: 3 additions & 0 deletions Sources/ApolloWebSocket/ApolloWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ public protocol ApolloWebSocketClient: WebSocketClient {

/// The URLRequest used on connection.
var request: URLRequest { get set }

/// Queue where the callbacks are executed
var callbackQueue: DispatchQueue { get set }
designatednerd marked this conversation as resolved.
Show resolved Hide resolved
}

// MARK: - WebSocket
Expand Down
36 changes: 36 additions & 0 deletions Sources/ApolloWebSocket/Utilities/Atomic.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import Foundation

class Atomic<T> {
private let lock = NSLock()
private var _value: T

init(_ value: T) {
_value = value
}

var value: T {
get {
lock.lock()
defer { lock.unlock() }

return _value
}
set {
lock.lock()
defer { lock.unlock() }

_value = newValue
}
}
}

extension Atomic where T == Int {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love it!


func increment() -> T {
lock.lock()
defer { lock.unlock() }

_value += 1
return _value
}
}
68 changes: 38 additions & 30 deletions Sources/ApolloWebSocket/WebSocketTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ public class WebSocketTransport {
public static var provider: ApolloWebSocketClient.Type = ApolloWebSocket.self
public weak var delegate: WebSocketTransportDelegate?

var reconnect = false
let reconnect: Atomic<Bool> = Atomic(false)
var websocket: ApolloWebSocketClient
var error: Error? = nil
let error: Atomic<Error?> = Atomic(nil)
let serializationFormat = JSONSerializationFormat.self
private let requestCreator: RequestCreator

Expand All @@ -40,10 +40,11 @@ public class WebSocketTransport {

private var subscribers = [String: (Result<JSONObject, Error>) -> Void]()
private var subscriptions : [String: String] = [:]
private let processingQueue = DispatchQueue(label: "com.apollographql.WebSocketTransport")

private let sendOperationIdentifiers: Bool
private let reconnectionInterval: TimeInterval
fileprivate var sequenceNumber = 0
fileprivate let sequenceNumberCounter = Atomic<Int>(0)
fileprivate var reconnected = false

/// NOTE: Setting this won't override immediately if the socket is still connected, only on reconnection.
Expand Down Expand Up @@ -87,6 +88,7 @@ public class WebSocketTransport {
self.websocket.request.setValue(self.clientVersion, forHTTPHeaderField: WebSocketTransport.headerFieldNameClientVersion)
self.websocket.delegate = self
self.websocket.connect()
self.websocket.callbackQueue = processingQueue
}

public func isConnected() -> Bool {
Expand Down Expand Up @@ -174,7 +176,7 @@ public class WebSocketTransport {
}

public func initServer(reconnect: Bool = true) {
self.reconnect = reconnect
self.reconnect.value = reconnect
self.acked = false

if let str = OperationMessage(payload: self.connectingPayload, type: .connectionInit).rawMessage {
Expand All @@ -184,12 +186,17 @@ public class WebSocketTransport {
}

public func closeConnection() {
self.reconnect = false
if let str = OperationMessage(type: .connectionTerminate).rawMessage {
write(str)
self.reconnect.value = false

let str = OperationMessage(type: .connectionTerminate).rawMessage
processingQueue.async {
if let str = str {
self.write(str)
}

self.queue.removeAll()
self.subscriptions.removeAll()
}
self.queue.removeAll()
self.subscriptions.removeAll()
}

private func write(_ str: String, force forced: Bool = false, id: Int? = nil) {
Expand All @@ -213,43 +220,44 @@ public class WebSocketTransport {
websocket.delegate = nil
}

private func nextSequenceNumber() -> Int {
sequenceNumber += 1
return sequenceNumber
}

func sendHelper<Operation: GraphQLOperation>(operation: Operation, resultHandler: @escaping (_ result: Result<JSONObject, Error>) -> Void) -> String? {
let body = requestCreator.requestBody(for: operation, sendOperationIdentifiers: self.sendOperationIdentifiers)
let sequenceNumber = "\(nextSequenceNumber())"
let sequenceNumber = "\(sequenceNumberCounter.increment())"

guard let message = OperationMessage(payload: body, id: sequenceNumber).rawMessage else {
return nil
}

write(message)

processingQueue.async {
self.write(message)

subscribers[sequenceNumber] = resultHandler
if operation.operationType == .subscription {
subscriptions[sequenceNumber] = message
self.subscribers[sequenceNumber] = resultHandler
if operation.operationType == .subscription {
self.subscriptions[sequenceNumber] = message
}
}

return sequenceNumber
}

public func unsubscribe(_ subscriptionId: String) {
if let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage {
write(str)
let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage

processingQueue.async {
if let str = str {
self.write(str)
}
self.subscribers.removeValue(forKey: subscriptionId)
self.subscriptions.removeValue(forKey: subscriptionId)
}
subscribers.removeValue(forKey: subscriptionId)
subscriptions.removeValue(forKey: subscriptionId)
}
}

// MARK: - HTTPNetworkTransport conformance

extension WebSocketTransport: NetworkTransport {
public func send<Operation>(operation: Operation, completionHandler: @escaping (_ result: Result<GraphQLResponse<Operation>,Error>) -> Void) -> Cancellable {
if let error = self.error {
if let error = self.error.value {
completionHandler(.failure(error))
return EmptyCancellable()
}
Expand All @@ -271,7 +279,7 @@ extension WebSocketTransport: NetworkTransport {
extension WebSocketTransport: WebSocketDelegate {

public func websocketDidConnect(socket: WebSocketClient) {
self.error = nil
self.error.value = nil
initServer()
if reconnected {
self.delegate?.webSocketTransportDidReconnect(self)
Expand All @@ -290,16 +298,16 @@ extension WebSocketTransport: WebSocketDelegate {
public func websocketDidDisconnect(socket: WebSocketClient, error: Error?) {
// report any error to all subscribers
if let error = error {
self.error = WebSocketError(payload: nil, error: error, kind: .networkError)
self.error.value = WebSocketError(payload: nil, error: error, kind: .networkError)
self.notifyErrorAllHandlers(error)
} else {
self.error = nil
self.error.value = nil
}

self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error)
self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error.value)
acked = false // need new connect and ack before sending

if reconnect {
if reconnect.value {
DispatchQueue.main.asyncAfter(deadline: .now() + reconnectionInterval) {
self.websocket.connect()
}
Expand Down
13 changes: 12 additions & 1 deletion Tests/ApolloWebsocketTests/MockWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import Starscream
@testable import ApolloWebSocket

class MockWebSocket: ApolloWebSocketClient {

var callbackQueue: DispatchQueue = DispatchQueue.main

var pongDelegate: WebSocketPongDelegate?
var request: URLRequest

Expand All @@ -15,8 +18,16 @@ class MockWebSocket: ApolloWebSocketClient {
self.request = URLRequest(url: URL(string: "http://localhost:8080")!)
}

open func reportDidConnect() {
callbackQueue.async {
self.delegate?.websocketDidConnect(socket: self)
}
}

open func write(string: String, completion: (() -> ())?) {
delegate?.websocketDidReceiveMessage(socket: self, text: string)
callbackQueue.async {
self.delegate?.websocketDidReceiveMessage(socket: self, text: string)
}
}

open func write(data: Data, completion: (() -> ())?) {
Expand Down
Loading