Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing data races in subscriptions #880

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions ApolloWebSocket.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
9F28B6D520720F2F00144A00 /* Apollo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D420720F2F00144A00 /* Apollo.framework */; };
9F28B6D920720FD200144A00 /* ApolloTestSupport.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */; };
9F28B6DB2072101200144A00 /* StarWarsAPI.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6DA2072101200144A00 /* StarWarsAPI.framework */; };
D143FF0D236EBD6B00E20A5C /* AtomicCounter.swift in Sources */ = {isa = PBXBuildFile; fileRef = D143FF09236EBD6B00E20A5C /* AtomicCounter.swift */; };
D143FF0E236EBD6B00E20A5C /* Atomic.swift in Sources */ = {isa = PBXBuildFile; fileRef = D143FF0A236EBD6B00E20A5C /* Atomic.swift */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand Down Expand Up @@ -73,6 +75,8 @@
9F28B6D420720F2F00144A00 /* Apollo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = Apollo.framework; sourceTree = BUILT_PRODUCTS_DIR; };
9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = ApolloTestSupport.framework; sourceTree = BUILT_PRODUCTS_DIR; };
9F28B6DA2072101200144A00 /* StarWarsAPI.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StarWarsAPI.framework; sourceTree = BUILT_PRODUCTS_DIR; };
D143FF09236EBD6B00E20A5C /* AtomicCounter.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AtomicCounter.swift; sourceTree = "<group>"; };
D143FF0A236EBD6B00E20A5C /* Atomic.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Atomic.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -102,6 +106,7 @@
isa = PBXGroup;
children = (
72707469206D111A00C131F6 /* ApolloWebSocket */,
D143FF08236EBD6B00E20A5C /* Utilities */,
72707470206D112900C131F6 /* ApolloWebSocketTests */,
90690D0A2243342000FC2E54 /* Configuration */,
720F147A206AB52F00D061DB /* Frameworks */,
Expand Down Expand Up @@ -181,6 +186,15 @@
name = Products;
sourceTree = "<group>";
};
D143FF08236EBD6B00E20A5C /* Utilities */ = {
isa = PBXGroup;
children = (
D143FF0A236EBD6B00E20A5C /* Atomic.swift */,
D143FF09236EBD6B00E20A5C /* AtomicCounter.swift */,
);
path = Utilities;
sourceTree = "<group>";
};
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
Expand Down Expand Up @@ -307,9 +321,11 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
D143FF0D236EBD6B00E20A5C /* AtomicCounter.swift in Sources */,
9B1CCDDF236110C3007C9032 /* WebSocketError.swift in Sources */,
7270746D206D111A00C131F6 /* SplitNetworkTransport.swift in Sources */,
9B1CCDE323611606007C9032 /* WebSocketTask.swift in Sources */,
D143FF0E236EBD6B00E20A5C /* Atomic.swift in Sources */,
7270746E206D111A00C131F6 /* WebSocketTransport.swift in Sources */,
9B1CCDE123611580007C9032 /* OperationMessage.swift in Sources */,
9B1CCDDB23610CDC007C9032 /* ApolloWebSocket.swift in Sources */,
Expand Down
3 changes: 3 additions & 0 deletions Sources/ApolloWebSocket/ApolloWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ public protocol ApolloWebSocketClient: WebSocketClient {

/// The URLRequest used on connection.
var request: URLRequest { get set }

/// Queue where the callbacks are executed
var callbackQueue: DispatchQueue { get set }
designatednerd marked this conversation as resolved.
Show resolved Hide resolved
}

// MARK: - WebSocket
Expand Down
68 changes: 38 additions & 30 deletions Sources/ApolloWebSocket/WebSocketTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ public class WebSocketTransport {
public static var provider: ApolloWebSocketClient.Type = ApolloWebSocket.self
public weak var delegate: WebSocketTransportDelegate?

var reconnect = false
var reconnect: Atomic<Bool> = Atomic(false)
var websocket: ApolloWebSocketClient
var error: Error? = nil
var error: Atomic<Error?> = Atomic(nil)
let serializationFormat = JSONSerializationFormat.self
private let requestCreator: RequestCreator

Expand All @@ -40,10 +40,11 @@ public class WebSocketTransport {

private var subscribers = [String: (Result<JSONObject, Error>) -> Void]()
private var subscriptions : [String: String] = [:]
private let processingQueue = DispatchQueue(label: "com.apollographql.WebSocketTransport")

private let sendOperationIdentifiers: Bool
private let reconnectionInterval: TimeInterval
fileprivate var sequenceNumber = 0
fileprivate var sequenceNumberCounter = AtomicCounter()
fileprivate var reconnected = false

/// NOTE: Setting this won't override immediately if the socket is still connected, only on reconnection.
Expand Down Expand Up @@ -87,6 +88,7 @@ public class WebSocketTransport {
self.websocket.request.setValue(self.clientVersion, forHTTPHeaderField: WebSocketTransport.headerFieldNameClientVersion)
self.websocket.delegate = self
self.websocket.connect()
self.websocket.callbackQueue = processingQueue
}

public func isConnected() -> Bool {
Expand Down Expand Up @@ -174,7 +176,7 @@ public class WebSocketTransport {
}

public func initServer(reconnect: Bool = true) {
self.reconnect = reconnect
self.reconnect.value = reconnect
self.acked = false

if let str = OperationMessage(payload: self.connectingPayload, type: .connectionInit).rawMessage {
Expand All @@ -184,12 +186,17 @@ public class WebSocketTransport {
}

public func closeConnection() {
self.reconnect = false
if let str = OperationMessage(type: .connectionTerminate).rawMessage {
write(str)
self.reconnect.value = false

let str = OperationMessage(type: .connectionTerminate).rawMessage
processingQueue.async {
if let str = str {
self.write(str)
}

self.queue.removeAll()
self.subscriptions.removeAll()
}
self.queue.removeAll()
self.subscriptions.removeAll()
}

private func write(_ str: String, force forced: Bool = false, id: Int? = nil) {
Expand All @@ -213,43 +220,44 @@ public class WebSocketTransport {
websocket.delegate = nil
}

private func nextSequenceNumber() -> Int {
sequenceNumber += 1
return sequenceNumber
}

func sendHelper<Operation: GraphQLOperation>(operation: Operation, resultHandler: @escaping (_ result: Result<JSONObject, Error>) -> Void) -> String? {
let body = requestCreator.requestBody(for: operation, sendOperationIdentifiers: self.sendOperationIdentifiers)
let sequenceNumber = "\(nextSequenceNumber())"
let sequenceNumber = "\(sequenceNumberCounter.next())"

guard let message = OperationMessage(payload: body, id: sequenceNumber).rawMessage else {
return nil
}

write(message)

processingQueue.async {
self.write(message)

subscribers[sequenceNumber] = resultHandler
if operation.operationType == .subscription {
subscriptions[sequenceNumber] = message
self.subscribers[sequenceNumber] = resultHandler
if operation.operationType == .subscription {
self.subscriptions[sequenceNumber] = message
}
}

return sequenceNumber
}

public func unsubscribe(_ subscriptionId: String) {
if let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage {
write(str)
let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage

processingQueue.async {
if let str = str {
self.write(str)
}
self.subscribers.removeValue(forKey: subscriptionId)
self.subscriptions.removeValue(forKey: subscriptionId)
}
subscribers.removeValue(forKey: subscriptionId)
subscriptions.removeValue(forKey: subscriptionId)
}
}

// MARK: - HTTPNetworkTransport conformance

extension WebSocketTransport: NetworkTransport {
public func send<Operation>(operation: Operation, completionHandler: @escaping (_ result: Result<GraphQLResponse<Operation>,Error>) -> Void) -> Cancellable {
if let error = self.error {
if let error = self.error.value {
completionHandler(.failure(error))
return EmptyCancellable()
}
Expand All @@ -271,7 +279,7 @@ extension WebSocketTransport: NetworkTransport {
extension WebSocketTransport: WebSocketDelegate {

public func websocketDidConnect(socket: WebSocketClient) {
self.error = nil
self.error.value = nil
initServer()
if reconnected {
self.delegate?.webSocketTransportDidReconnect(self)
Expand All @@ -290,16 +298,16 @@ extension WebSocketTransport: WebSocketDelegate {
public func websocketDidDisconnect(socket: WebSocketClient, error: Error?) {
// report any error to all subscribers
if let error = error {
self.error = WebSocketError(payload: nil, error: error, kind: .networkError)
self.error.value = WebSocketError(payload: nil, error: error, kind: .networkError)
self.notifyErrorAllHandlers(error)
} else {
self.error = nil
self.error.value = nil
}

self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error)
self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error.value)
acked = false // need new connect and ack before sending

if reconnect {
if reconnect.value {
DispatchQueue.main.asyncAfter(deadline: .now() + reconnectionInterval) {
self.websocket.connect()
}
Expand Down
13 changes: 12 additions & 1 deletion Tests/ApolloWebsocketTests/MockWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import Starscream
@testable import ApolloWebSocket

class MockWebSocket: ApolloWebSocketClient {

var callbackQueue: DispatchQueue = DispatchQueue.main

var pongDelegate: WebSocketPongDelegate?
var request: URLRequest

Expand All @@ -15,8 +18,16 @@ class MockWebSocket: ApolloWebSocketClient {
self.request = URLRequest(url: URL(string: "http://localhost:8080")!)
}

open func reportDidConnect() {
callbackQueue.async {
self.delegate?.websocketDidConnect(socket: self)
}
}

open func write(string: String, completion: (() -> ())?) {
delegate?.websocketDidReceiveMessage(socket: self, text: string)
callbackQueue.async {
self.delegate?.websocketDidReceiveMessage(socket: self, text: string)
}
}

open func write(data: Data, completion: (() -> ())?) {
Expand Down
109 changes: 107 additions & 2 deletions Tests/ApolloWebsocketTests/StarWarsSubscriptionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ class StarWarsSubscriptionTests: XCTestCase {
let SERVER: String = "ws://localhost:8080/websocket"

var client: ApolloClient!
var webSocketTransport: WebSocketTransport!

override func setUp() {
super.setUp()

let networkTransport = WebSocketTransport(request: URLRequest(url: URL(string: SERVER)!))
client = ApolloClient(networkTransport: networkTransport)
WebSocketTransport.provider = ApolloWebSocket.self
webSocketTransport = WebSocketTransport(request: URLRequest(url: URL(string: SERVER)!))
client = ApolloClient(networkTransport: webSocketTransport)
}

// MARK: Subscriptions
Expand Down Expand Up @@ -252,4 +254,107 @@ class StarWarsSubscriptionTests: XCTestCase {
subJedi.cancel()
subNewHope.cancel()
}

// MARK: Data races tests

func testConcurrentSubscribing() {
let firstSubscription = ReviewAddedSubscription(episode: .empire)
let secondSubscription = ReviewAddedSubscription(episode: .empire)

let expectation = self.expectation(description: "Subscribers connected and received events")
expectation.expectedFulfillmentCount = 2

var sub1: Cancellable?
var sub2: Cancellable?

let queue = DispatchQueue(label: "com.apollographql.testing", attributes: .concurrent)

queue.async {
sub1 = self.client.subscribe(subscription: firstSubscription) { _ in
expectation.fulfill()
}
}

queue.async {
sub2 = self.client.subscribe(subscription: secondSubscription) { _ in
expectation.fulfill()
}
}

// dispatched with a barrier flag to make sure
// this is performed after subscription calls
queue.sync(flags: .barrier) {
// dispatched on the processing queue to make sure
// this is performed after subscribers are processed
self.webSocketTransport.websocket.callbackQueue.async {
_ = self.client.perform(mutation: CreateReviewForEpisodeMutation(episode: .empire, review: ReviewInput(stars: 5, commentary: "The greatest movie ever!")))
}
}

waitForExpectations(timeout: 10, handler: nil)
sub1?.cancel()
sub2?.cancel()
}

func testConcurrentSubscriptionCancellations() {
let firstSubscription = ReviewAddedSubscription(episode: .empire)
let secondSubscription = ReviewAddedSubscription(episode: .empire)

let expectation = self.expectation(description: "Subscriptions cancelled")
expectation.expectedFulfillmentCount = 2

let sub1 = client.subscribe(subscription: firstSubscription) { _ in }
let sub2 = client.subscribe(subscription: secondSubscription) { _ in }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth adding a post of a review, then failing either of these if something comes through since they're both being cancelled immediately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do that if you think it looks better, but I think it wouldn't really matter.

I feel like these tests only make sense if TSAN is turned on anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not particularly concerned about the way it looks, but it'd be nice to have a way to test these without TSAN turn on since right now we're not able to turn it on for the library as a whole (yet)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^still thinking on this - is this something where we should at least be able to say "These should never get called, even without TSAN on?'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I see what you mean. Kinda misunderstood you before.

I will perform a mutation after calling cancel to confirm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I'm trying to do:

let firstSubscription = ReviewAddedSubscription(episode: .empire)
let secondSubscription = ReviewAddedSubscription(episode: .empire)
let expectation = self.expectation(description: "Subscriptions cancelled")
expectation.expectedFulfillmentCount = 2
    
let sub1 = client.subscribe(subscription: firstSubscription) { _ in
  XCTFail("Unexpected subscription response")
}
let sub2 = client.subscribe(subscription: secondSubscription) { _ in
  XCTFail("Unexpected subscription response")
}

let queue = DispatchQueue(label: "com.apollographql.testing", attributes: .concurrent)

queue.async {
  sub1.cancel()
  expectation.fulfill()
}
queue.async {
  sub2.cancel()
  expectation.fulfill()
}
queue.async(flags: .barrier) {
  _ = self.client.perform(mutation: CreateReviewForEpisodeMutation(episode: .empire, review: ReviewInput(stars: 5, commentary: "The greatest movie ever!")))
}

waitForExpectations(timeout: 10, handler: nil)

But what happens is that final mutation sometimes triggers a subscription response in another test - testMultipleSubscriptions. Its callbacks get called too many times and the whole test fails. Testing this way can be achieved if the test sleeps after waitForExpectations for a while, but I feel thats just messy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think waiting for expectations is totally reasonable - you also can do an inverted expectation where the test fails if it actually gets called - so instead of having the explicit XCTFails you can just have an invertedExpectation.fulfill(). You'd probably only need to wait that one 1-2 seconds, since the perform mutation should happen a lot faster than that.


DispatchQueue.global().async {
sub1.cancel()
expectation.fulfill()
designatednerd marked this conversation as resolved.
Show resolved Hide resolved
}
DispatchQueue.global().async {
sub2.cancel()
expectation.fulfill()
}

waitForExpectations(timeout: 10, handler: nil)
}

func testConcurrentSubscriptionAndConnectionClose() {
let empireReviewSubscription = ReviewAddedSubscription(episode: .empire)
let expectation = self.expectation(description: "Connection closed")

DispatchQueue.global().async {
let sub = self.client.subscribe(subscription: empireReviewSubscription) { _ in }
aivcec marked this conversation as resolved.
Show resolved Hide resolved
sub.cancel()
}

_ = self.client.perform(mutation: CreateReviewForEpisodeMutation(episode: .empire, review: ReviewInput(stars: 5, commentary: "The greatest movie ever!")))

DispatchQueue.global().async {
self.webSocketTransport.closeConnection()
expectation.fulfill()
designatednerd marked this conversation as resolved.
Show resolved Hide resolved
}

waitForExpectations(timeout: 10, handler: nil)
}

func testConcurrentConnectAndCloseConnection() {
WebSocketTransport.provider = MockWebSocket.self
let webSocketTransport = WebSocketTransport(request: URLRequest(url: URL(string: SERVER)!))
let expectation = self.expectation(description: "Connection closed")
expectation.expectedFulfillmentCount = 2

DispatchQueue.global().async {
if let websocket = webSocketTransport.websocket as? MockWebSocket {
websocket.reportDidConnect()
expectation.fulfill()
}
}

DispatchQueue.global().async {
webSocketTransport.closeConnection()
expectation.fulfill()
}

waitForExpectations(timeout: 10, handler: nil)
}
}
25 changes: 25 additions & 0 deletions Utilities/Atomic.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import Foundation
designatednerd marked this conversation as resolved.
Show resolved Hide resolved

class Atomic<T> {
private let lock = NSLock()
private var _value: T

init(_ value: T) {
_value = value
}

var value: T {
get {
lock.lock()
defer { lock.unlock() }

return _value
}
set {
lock.lock()
defer { lock.unlock() }

_value = newValue
}
}
}
Loading