Skip to content

Commit

Permalink
PR Feedback part 3
Browse files Browse the repository at this point in the history
  • Loading branch information
buggmagnet committed Oct 13, 2023
1 parent bbdd843 commit ceba87c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 44 deletions.
14 changes: 8 additions & 6 deletions ios/MullvadVPN/TunnelManager/Tunnel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ protocol TunnelProtocol: AnyObject {
var isOnDemandEnabled: Bool { get set }
var startDate: Date? { get }

func addObserver(_ observer: TunnelStatusObserver)
func removeObserver(_ observer: TunnelStatusObserver)
init(tunnelProvider: TunnelProviderManagerType)

func addObserver(_ observer: any TunnelStatusObserver)
func removeObserver(_ observer: any TunnelStatusObserver)
func addBlockObserver(
queue: DispatchQueue?,
handler: @escaping (any TunnelProtocol, NEVPNStatus) -> Void
Expand Down Expand Up @@ -105,7 +107,7 @@ final class Tunnel: TunnelProtocol, Equatable {
}

private let lock = NSLock()
private var observerList = ObserverList<TunnelStatusObserver>()
private var observerList = ObserverList<any TunnelStatusObserver>()

private var _startDate: Date?
internal let tunnelProvider: TunnelProviderManagerType
Expand Down Expand Up @@ -169,11 +171,11 @@ final class Tunnel: TunnelProtocol, Equatable {
return observer
}

func addObserver(_ observer: TunnelStatusObserver) {
func addObserver(_ observer: any TunnelStatusObserver) {
observerList.append(observer)
}

func removeObserver(_ observer: TunnelStatusObserver) {
func removeObserver(_ observer: any TunnelStatusObserver) {
observerList.remove(observer)
}

Expand All @@ -185,7 +187,7 @@ final class Tunnel: TunnelProtocol, Equatable {
handleVPNStatus(newStatus)

observerList.forEach { observer in
observer.tunnel(self, didReceiveStatus: newStatus)
// observer.tunnel(self, didReceiveStatus: newStatus)
}
}

Expand Down
4 changes: 2 additions & 2 deletions ios/MullvadVPN/TunnelManager/TunnelManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ final class TunnelManager: StorePaymentObserver {
// MARK: - Internal variables

private let application: BackgroundTaskProvider
fileprivate let tunnelStore: TunnelStoreProtocol
fileprivate let tunnelStore: any TunnelStoreProtocol
private let relayCacheTracker: RelayCacheTrackerProtocol
private let accountsProxy: RESTAccountHandling
private let devicesProxy: DeviceHandling
Expand Down Expand Up @@ -82,7 +82,7 @@ final class TunnelManager: StorePaymentObserver {

init(
application: BackgroundTaskProvider,
tunnelStore: TunnelStoreProtocol,
tunnelStore: any TunnelStoreProtocol,
relayCacheTracker: RelayCacheTrackerProtocol,
accountsProxy: RESTAccountHandling,
devicesProxy: DeviceHandling,
Expand Down
27 changes: 14 additions & 13 deletions ios/MullvadVPN/TunnelManager/TunnelStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@ import NetworkExtension
import UIKit

protocol TunnelStoreProtocol {
func getPersistentTunnels() -> [any TunnelProtocol]
func createNewTunnel() -> any TunnelProtocol
associatedtype TunnelType: TunnelProtocol, Equatable
func getPersistentTunnels() -> [TunnelType]
func createNewTunnel() -> TunnelType
}

/// Wrapper around system VPN tunnels.
final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver {
typealias TunnelType = Tunnel
private let logger = Logger(label: "TunnelStore")
private let lock = NSLock()

/// Persistent tunnels registered with the system.
private var persistentTunnels: [any TunnelProtocol] = []
private var persistentTunnels: [TunnelType] = []

/// Newly created tunnels, stored as collection of weak boxes.
private var newTunnels: [WeakBox<any TunnelProtocol>] = []
private var newTunnels: [WeakBox<TunnelType>] = []

init(application: UIApplication) {
NotificationCenter.default.addObserver(
Expand All @@ -36,7 +38,7 @@ final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver {
)
}

func getPersistentTunnels() -> [any TunnelProtocol] {
func getPersistentTunnels() -> [TunnelType] {
lock.lock()
defer { lock.unlock() }

Expand Down Expand Up @@ -71,12 +73,12 @@ final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver {
}
}

func createNewTunnel() -> any TunnelProtocol {
func createNewTunnel() -> TunnelType {
lock.lock()
defer { lock.unlock() }

let tunnelProviderManager = TunnelProviderManagerType()
let tunnel = Tunnel(tunnelProvider: tunnelProviderManager)
let tunnel = TunnelType(tunnelProvider: tunnelProviderManager)
tunnel.addObserver(self)

newTunnels = newTunnels.filter { $0.value != nil }
Expand All @@ -91,20 +93,19 @@ final class TunnelStore: TunnelStoreProtocol, TunnelStatusObserver {
lock.lock()
defer { lock.unlock() }

handleTunnelStatus(tunnel: tunnel, status: status)
// swiftlint:disable:next force_cast
handleTunnelStatus(tunnel: tunnel as! TunnelType, status: status)
}

private func handleTunnelStatus(tunnel: any TunnelProtocol, status: NEVPNStatus) {
guard let tunnel = tunnel as? Tunnel else { return }

private func handleTunnelStatus(tunnel: TunnelType, status: NEVPNStatus) {
if status == .invalid,
let index = persistentTunnels.map({ $0 as? Tunnel }).firstIndex(of: tunnel) {
let index = persistentTunnels.firstIndex(of: tunnel) {
persistentTunnels.remove(at: index)
logger.debug("Persistent tunnel was removed: \(tunnel.logFormat()).")
}

if status != .invalid,
let index = newTunnels.map({ $0.value as? Tunnel }).firstIndex(where: { $0 == tunnel }) {
let index = newTunnels.compactMap({ $0.value }).firstIndex(where: { $0 == tunnel }) {
newTunnels.remove(at: index)
persistentTunnels.append(tunnel)
logger.debug("New tunnel became persistent: \(tunnel.logFormat()).")
Expand Down
11 changes: 8 additions & 3 deletions ios/MullvadVPNTests/TunnelStore+Stubs.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ import Foundation
import NetworkExtension

struct TunnelStoreStub: TunnelStoreProtocol {
func getPersistentTunnels() -> [any TunnelProtocol] {
typealias TunnelType = TunnelStub
func getPersistentTunnels() -> [TunnelType] {
[]
}

func createNewTunnel() -> any TunnelProtocol {
func createNewTunnel() -> TunnelType {
TunnelStub(status: .invalid, isOnDemandEnabled: false)
}
}
Expand All @@ -23,7 +24,11 @@ class DummyTunnelStatusObserver: TunnelStatusObserver {
func tunnel(_ tunnel: any TunnelProtocol, didReceiveStatus status: NEVPNStatus) {}
}

class TunnelStub: TunnelProtocol {
final class TunnelStub: TunnelProtocol, Equatable {
convenience init(tunnelProvider: TunnelProviderManagerType) {
self.init(status: .invalid, isOnDemandEnabled: false)
}

static func == (lhs: TunnelStub, rhs: TunnelStub) -> Bool {
ObjectIdentifier(lhs) == ObjectIdentifier(rhs)
}
Expand Down
2 changes: 1 addition & 1 deletion ios/PacketTunnelCore/Actor/State+Extensions.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// State+.swift
// State+Extensions.swift
// PacketTunnelCore
//
// Created by pronebird on 08/09/2023.
Expand Down
53 changes: 34 additions & 19 deletions ios/PacketTunnelCoreTests/PacketTunnelActorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,18 @@ final class PacketTunnelActorTests: XCTestCase {
func testStopIsNoopBeforeStart() async throws {
let actor = PacketTunnelActor.mock()

let disconnectedExpectation = expectation(description: "Disconnected state")
disconnectedExpectation.isInverted = true

await expect(.disconnected, on: actor) {
disconnectedExpectation.fulfill()
}

actor.stop()
actor.stop()
actor.stop()

switch await actor.state {
case .initial: break
default: XCTFail("Actor did not start, should be in .initial state")
}
await fulfillment(of: [disconnectedExpectation], timeout: Duration.milliseconds(100).timeInterval)
}

func testStopCancelsDefaultPathObserver() async throws {
Expand Down Expand Up @@ -226,6 +230,8 @@ final class PacketTunnelActorTests: XCTestCase {
let actor = PacketTunnelActor.mock()
let connectingStateExpectation = expectation(description: "Connecting state")
let disconnectedStateExpectation = expectation(description: "Disconnected state")
let errorStateExpectation = expectation(description: "Should not enter error state")
errorStateExpectation.isInverted = true

stateSink = await actor.$state
.receive(on: DispatchQueue.main)
Expand All @@ -235,7 +241,7 @@ final class PacketTunnelActorTests: XCTestCase {
actor.setErrorState(reason: .readSettings)
connectingStateExpectation.fulfill()
case .error:
XCTFail("Should not go to error state")
errorStateExpectation.fulfill()
case .disconnected:
disconnectedStateExpectation.fulfill()
default:
Expand All @@ -245,44 +251,53 @@ final class PacketTunnelActorTests: XCTestCase {

actor.start(options: launchOptions)
actor.stop()

await fulfillment(of: [connectingStateExpectation, disconnectedStateExpectation], timeout: 1)
await fulfillment(of: [errorStateExpectation], timeout: Duration.milliseconds(100).timeInterval)
}

func testReconnectIsNoopBeforeConnecting() async throws {
let actor = PacketTunnelActor.mock()
let initialStateExpectation = expectation(description: "Expect initial state")
let reconnectingStateExpectation = expectation(description: "Expect initial state")
reconnectingStateExpectation.isInverted = true

stateSink = await actor.$state.receive(on: DispatchQueue.main).sink { newState in
if case .initial = newState {
initialStateExpectation.fulfill()
return
}
XCTFail("Should not change states before starting the actor")
let expression: (State) -> Bool = { if case .reconnecting = $0 { true } else { false } }

await expect(expression, on: actor) {
reconnectingStateExpectation.fulfill()
}

actor.reconnect(to: .random)

await fulfillment(of: [initialStateExpectation], timeout: 1)
await fulfillment(
of: [reconnectingStateExpectation],
timeout: Duration.milliseconds(100).timeInterval
)
}

func testCannotReconnectAfterStopping() async throws {
let actor = PacketTunnelActor.mock()

let disconnectedStateExpectation = expectation(description: "Expect disconnected state")

await expect(.disconnected, on: actor) {
disconnectedStateExpectation.fulfill()
}
await expect(.disconnected, on: actor) { disconnectedStateExpectation.fulfill() }

actor.start(options: launchOptions)
actor.stop()

await fulfillment(of: [disconnectedStateExpectation], timeout: 1)

await expect(.initial, on: actor) {
XCTFail("Should not be trying to reconnect after stopping")
}
let reconnectingStateExpectation = expectation(description: "Expect initial state")
reconnectingStateExpectation.isInverted = true
let expression: (State) -> Bool = { if case .reconnecting = $0 { true } else { false } }

await expect(expression, on: actor) { reconnectingStateExpectation.fulfill() }

actor.reconnect(to: .random)
await fulfillment(
of: [reconnectingStateExpectation],
timeout: Duration.milliseconds(100).timeInterval
)
}

func testReconnectionStopsTunnelMonitor() async throws {
Expand Down

0 comments on commit ceba87c

Please sign in to comment.