From ceba87ce304a79d2f89e19065831b1ff648df926 Mon Sep 17 00:00:00 2001 From: Bug Magnet Date: Fri, 13 Oct 2023 16:47:51 +0200 Subject: [PATCH] PR Feedback part 3 --- ios/MullvadVPN/TunnelManager/Tunnel.swift | 14 ++--- .../TunnelManager/TunnelManager.swift | 4 +- .../TunnelManager/TunnelStore.swift | 27 +++++----- ios/MullvadVPNTests/TunnelStore+Stubs.swift | 11 ++-- .../Actor/State+Extensions.swift | 2 +- .../PacketTunnelActorTests.swift | 53 ++++++++++++------- 6 files changed, 67 insertions(+), 44 deletions(-) diff --git a/ios/MullvadVPN/TunnelManager/Tunnel.swift b/ios/MullvadVPN/TunnelManager/Tunnel.swift index 9edb3f1bc700..da4359c3c5d9 100644 --- a/ios/MullvadVPN/TunnelManager/Tunnel.swift +++ b/ios/MullvadVPN/TunnelManager/Tunnel.swift @@ -25,8 +25,10 @@ protocol TunnelProtocol: AnyObject { var isOnDemandEnabled: Bool { get set } var startDate: Date? { get } - func addObserver(_ observer: TunnelStatusObserver) - func removeObserver(_ observer: TunnelStatusObserver) + init(tunnelProvider: TunnelProviderManagerType) + + func addObserver(_ observer: any TunnelStatusObserver) + func removeObserver(_ observer: any TunnelStatusObserver) func addBlockObserver( queue: DispatchQueue?, handler: @escaping (any TunnelProtocol, NEVPNStatus) -> Void @@ -105,7 +107,7 @@ final class Tunnel: TunnelProtocol, Equatable { } private let lock = NSLock() - private var observerList = ObserverList() + private var observerList = ObserverList() private var _startDate: Date? internal let tunnelProvider: TunnelProviderManagerType @@ -169,11 +171,11 @@ final class Tunnel: TunnelProtocol, Equatable { return observer } - func addObserver(_ observer: TunnelStatusObserver) { + func addObserver(_ observer: any TunnelStatusObserver) { observerList.append(observer) } - func removeObserver(_ observer: TunnelStatusObserver) { + func removeObserver(_ observer: any TunnelStatusObserver) { observerList.remove(observer) } @@ -185,7 +187,7 @@ final class Tunnel: TunnelProtocol, Equatable { handleVPNStatus(newStatus) observerList.forEach { observer in - observer.tunnel(self, didReceiveStatus: newStatus) +// observer.tunnel(self, didReceiveStatus: newStatus) } } diff --git a/ios/MullvadVPN/TunnelManager/TunnelManager.swift b/ios/MullvadVPN/TunnelManager/TunnelManager.swift index 138d52d7896b..768b4fb3aee8 100644 --- a/ios/MullvadVPN/TunnelManager/TunnelManager.swift +++ b/ios/MullvadVPN/TunnelManager/TunnelManager.swift @@ -45,7 +45,7 @@ final class TunnelManager: StorePaymentObserver { // MARK: - Internal variables private let application: BackgroundTaskProvider - fileprivate let tunnelStore: TunnelStoreProtocol + fileprivate let tunnelStore: any TunnelStoreProtocol private let relayCacheTracker: RelayCacheTrackerProtocol private let accountsProxy: RESTAccountHandling private let devicesProxy: DeviceHandling @@ -82,7 +82,7 @@ final class TunnelManager: StorePaymentObserver { init( application: BackgroundTaskProvider, - tunnelStore: TunnelStoreProtocol, + tunnelStore: any TunnelStoreProtocol, relayCacheTracker: RelayCacheTrackerProtocol, accountsProxy: RESTAccountHandling, devicesProxy: DeviceHandling, diff --git a/ios/MullvadVPN/TunnelManager/TunnelStore.swift b/ios/MullvadVPN/TunnelManager/TunnelStore.swift index f750ea9c4f8e..b93c33ac45c2 100644 --- a/ios/MullvadVPN/TunnelManager/TunnelStore.swift +++ b/ios/MullvadVPN/TunnelManager/TunnelStore.swift @@ -12,20 +12,22 @@ import NetworkExtension import UIKit protocol TunnelStoreProtocol { - func getPersistentTunnels() -> [any TunnelProtocol] - func createNewTunnel() -> any TunnelProtocol + associatedtype TunnelType: TunnelProtocol, Equatable + func getPersistentTunnels() -> [TunnelType] + func createNewTunnel() -> TunnelType } /// Wrapper around system VPN tunnels. final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver { + typealias TunnelType = Tunnel private let logger = Logger(label: "TunnelStore") private let lock = NSLock() /// Persistent tunnels registered with the system. - private var persistentTunnels: [any TunnelProtocol] = [] + private var persistentTunnels: [TunnelType] = [] /// Newly created tunnels, stored as collection of weak boxes. - private var newTunnels: [WeakBox] = [] + private var newTunnels: [WeakBox] = [] init(application: UIApplication) { NotificationCenter.default.addObserver( @@ -36,7 +38,7 @@ final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver { ) } - func getPersistentTunnels() -> [any TunnelProtocol] { + func getPersistentTunnels() -> [TunnelType] { lock.lock() defer { lock.unlock() } @@ -71,12 +73,12 @@ final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver { } } - func createNewTunnel() -> any TunnelProtocol { + func createNewTunnel() -> TunnelType { lock.lock() defer { lock.unlock() } let tunnelProviderManager = TunnelProviderManagerType() - let tunnel = Tunnel(tunnelProvider: tunnelProviderManager) + let tunnel = TunnelType(tunnelProvider: tunnelProviderManager) tunnel.addObserver(self) newTunnels = newTunnels.filter { $0.value != nil } @@ -91,20 +93,19 @@ final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver { lock.lock() defer { lock.unlock() } - handleTunnelStatus(tunnel: tunnel, status: status) + // swiftlint:disable:next force_cast + handleTunnelStatus(tunnel: tunnel as! TunnelType, status: status) } - private func handleTunnelStatus(tunnel: any TunnelProtocol, status: NEVPNStatus) { - guard let tunnel = tunnel as? Tunnel else { return } - + private func handleTunnelStatus(tunnel: TunnelType, status: NEVPNStatus) { if status == .invalid, - let index = persistentTunnels.map({ $0 as? Tunnel }).firstIndex(of: tunnel) { + let index = persistentTunnels.firstIndex(of: tunnel) { persistentTunnels.remove(at: index) logger.debug("Persistent tunnel was removed: \(tunnel.logFormat()).") } if status != .invalid, - let index = newTunnels.map({ $0.value as? Tunnel }).firstIndex(where: { $0 == tunnel }) { + let index = newTunnels.compactMap({ $0.value }).firstIndex(where: { $0 == tunnel }) { newTunnels.remove(at: index) persistentTunnels.append(tunnel) logger.debug("New tunnel became persistent: \(tunnel.logFormat()).") diff --git a/ios/MullvadVPNTests/TunnelStore+Stubs.swift b/ios/MullvadVPNTests/TunnelStore+Stubs.swift index d9a15f68146b..546583942628 100644 --- a/ios/MullvadVPNTests/TunnelStore+Stubs.swift +++ b/ios/MullvadVPNTests/TunnelStore+Stubs.swift @@ -10,11 +10,12 @@ import Foundation import NetworkExtension struct TunnelStoreStub: TunnelStoreProtocol { - func getPersistentTunnels() -> [any TunnelProtocol] { + typealias TunnelType = TunnelStub + func getPersistentTunnels() -> [TunnelType] { [] } - func createNewTunnel() -> any TunnelProtocol { + func createNewTunnel() -> TunnelType { TunnelStub(status: .invalid, isOnDemandEnabled: false) } } @@ -23,7 +24,11 @@ class DummyTunnelStatusObserver: TunnelStatusObserver { func tunnel(_ tunnel: any TunnelProtocol, didReceiveStatus status: NEVPNStatus) {} } -class TunnelStub: TunnelProtocol { +final class TunnelStub: TunnelProtocol, Equatable { + convenience init(tunnelProvider: TunnelProviderManagerType) { + self.init(status: .invalid, isOnDemandEnabled: false) + } + static func == (lhs: TunnelStub, rhs: TunnelStub) -> Bool { ObjectIdentifier(lhs) == ObjectIdentifier(rhs) } diff --git a/ios/PacketTunnelCore/Actor/State+Extensions.swift b/ios/PacketTunnelCore/Actor/State+Extensions.swift index 4a232fb84dd1..9bceaa91ee54 100644 --- a/ios/PacketTunnelCore/Actor/State+Extensions.swift +++ b/ios/PacketTunnelCore/Actor/State+Extensions.swift @@ -1,5 +1,5 @@ // -// State+.swift +// State+Extensions.swift // PacketTunnelCore // // Created by pronebird on 08/09/2023. diff --git a/ios/PacketTunnelCoreTests/PacketTunnelActorTests.swift b/ios/PacketTunnelCoreTests/PacketTunnelActorTests.swift index 39b6078f2b58..969d9c70df95 100644 --- a/ios/PacketTunnelCoreTests/PacketTunnelActorTests.swift +++ b/ios/PacketTunnelCoreTests/PacketTunnelActorTests.swift @@ -185,14 +185,18 @@ final class PacketTunnelActorTests: XCTestCase { func testStopIsNoopBeforeStart() async throws { let actor = PacketTunnelActor.mock() + let disconnectedExpectation = expectation(description: "Disconnected state") + disconnectedExpectation.isInverted = true + + await expect(.disconnected, on: actor) { + disconnectedExpectation.fulfill() + } + actor.stop() actor.stop() actor.stop() - switch await actor.state { - case .initial: break - default: XCTFail("Actor did not start, should be in .initial state") - } + await fulfillment(of: [disconnectedExpectation], timeout: Duration.milliseconds(100).timeInterval) } func testStopCancelsDefaultPathObserver() async throws { @@ -226,6 +230,8 @@ final class PacketTunnelActorTests: XCTestCase { let actor = PacketTunnelActor.mock() let connectingStateExpectation = expectation(description: "Connecting state") let disconnectedStateExpectation = expectation(description: "Disconnected state") + let errorStateExpectation = expectation(description: "Should not enter error state") + errorStateExpectation.isInverted = true stateSink = await actor.$state .receive(on: DispatchQueue.main) @@ -235,7 +241,7 @@ final class PacketTunnelActorTests: XCTestCase { actor.setErrorState(reason: .readSettings) connectingStateExpectation.fulfill() case .error: - XCTFail("Should not go to error state") + errorStateExpectation.fulfill() case .disconnected: disconnectedStateExpectation.fulfill() default: @@ -245,24 +251,28 @@ final class PacketTunnelActorTests: XCTestCase { actor.start(options: launchOptions) actor.stop() + await fulfillment(of: [connectingStateExpectation, disconnectedStateExpectation], timeout: 1) + await fulfillment(of: [errorStateExpectation], timeout: Duration.milliseconds(100).timeInterval) } func testReconnectIsNoopBeforeConnecting() async throws { let actor = PacketTunnelActor.mock() - let initialStateExpectation = expectation(description: "Expect initial state") + let reconnectingStateExpectation = expectation(description: "Expect initial state") + reconnectingStateExpectation.isInverted = true - stateSink = await actor.$state.receive(on: DispatchQueue.main).sink { newState in - if case .initial = newState { - initialStateExpectation.fulfill() - return - } - XCTFail("Should not change states before starting the actor") + let expression: (State) -> Bool = { if case .reconnecting = $0 { true } else { false } } + + await expect(expression, on: actor) { + reconnectingStateExpectation.fulfill() } actor.reconnect(to: .random) - await fulfillment(of: [initialStateExpectation], timeout: 1) + await fulfillment( + of: [reconnectingStateExpectation], + timeout: Duration.milliseconds(100).timeInterval + ) } func testCannotReconnectAfterStopping() async throws { @@ -270,19 +280,24 @@ final class PacketTunnelActorTests: XCTestCase { let disconnectedStateExpectation = expectation(description: "Expect disconnected state") - await expect(.disconnected, on: actor) { - disconnectedStateExpectation.fulfill() - } + await expect(.disconnected, on: actor) { disconnectedStateExpectation.fulfill() } actor.start(options: launchOptions) actor.stop() await fulfillment(of: [disconnectedStateExpectation], timeout: 1) - await expect(.initial, on: actor) { - XCTFail("Should not be trying to reconnect after stopping") - } + let reconnectingStateExpectation = expectation(description: "Expect initial state") + reconnectingStateExpectation.isInverted = true + let expression: (State) -> Bool = { if case .reconnecting = $0 { true } else { false } } + + await expect(expression, on: actor) { reconnectingStateExpectation.fulfill() } + actor.reconnect(to: .random) + await fulfillment( + of: [reconnectingStateExpectation], + timeout: Duration.milliseconds(100).timeInterval + ) } func testReconnectionStopsTunnelMonitor() async throws {