Skip to content

Commit

Permalink
Shutdown all active peer multiplexers when worker is stopping (#259)
Browse files Browse the repository at this point in the history
* Shutdown all active peer multiplexers when worker is stopping

* Wait for them to actually disconnect

* adjust

* fix tests

* Assert that cleanup is working

---------

Co-authored-by: Pascal Vizeli <[email protected]>
  • Loading branch information
ludeeus and pvizeli authored Apr 30, 2024
1 parent 8ffce1f commit 4f7e4d9
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 8 deletions.
25 changes: 23 additions & 2 deletions snitun/server/peer_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Manage peer connections."""

from __future__ import annotations

import asyncio
Expand All @@ -8,6 +9,7 @@
import logging
from typing import Callable

import async_timeout
from cryptography.fernet import Fernet, InvalidToken, MultiFernet

from ..exceptions import SniTunInvalidPeer
Expand Down Expand Up @@ -37,7 +39,7 @@ def __init__(
self._loop = asyncio.get_event_loop()
self._throttling = throttling
self._event_callback = event_callback
self._peers = {}
self._peers: dict[str, Peer] = {}

@property
def connections(self) -> int:
Expand Down Expand Up @@ -96,7 +98,9 @@ def remove_peer(self, peer: Peer) -> None:

if self._event_callback:
self._loop.call_soon(
self._event_callback, peer, PeerManagerEvent.DISCONNECTED,
self._event_callback,
peer,
PeerManagerEvent.DISCONNECTED,
)

def peer_available(self, hostname: str) -> bool:
Expand All @@ -108,3 +112,20 @@ def peer_available(self, hostname: str) -> bool:
def get_peer(self, hostname: str) -> Peer | None:
"""Get peer."""
return self._peers.get(hostname)

async def close_connections(self, timeout: int = 10) -> None:
"""Close all peer connections.
Use this function only if you do not controll the server socket.
"""
peers = list(self._peers.values())
for peer in peers:
if peer.is_connected:
peer.multiplexer.shutdown()

if waiters := [peer.wait_disconnect() for peer in peers]:
try:
async with async_timeout.timeout(timeout):
await asyncio.gather(*waiters, return_exceptions=True)
except asyncio.TimeoutError:
_LOGGER.error("Timeout while waiting for peer disconnect")
20 changes: 16 additions & 4 deletions snitun/server/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SniTun worker for traffics."""

from __future__ import annotations

import asyncio
Expand Down Expand Up @@ -83,7 +84,10 @@ def shutdown(self) -> None:
self.join(10)

def handover_connection(
self, con: socket, data: bytes, sni: str | None = None,
self,
con: socket,
data: bytes,
sni: str | None = None,
) -> None:
"""Move new connection to worker."""
self._new.put_nowait((con, data, sni))
Expand All @@ -110,16 +114,24 @@ def run(self) -> None:

new[0].setblocking(False)
asyncio.run_coroutine_threadsafe(
self._async_new_connection(*new), loop=self._loop,
self._async_new_connection(*new),
loop=self._loop,
)

# Shutdown worker
_LOGGER.info("Stop worker: %s", self.name)
_LOGGER.info("Stoping worker: %s", self.name)
asyncio.run_coroutine_threadsafe(
self._peers.close_connections(),
loop=self._loop,
).result()
self._loop.call_soon_threadsafe(self._loop.stop)
running_loop.join(10)

async def _async_new_connection(
self, con: socket, data: bytes, sni: str | None,
self,
con: socket,
data: bytes,
sni: str | None,
) -> None:
"""Handle incoming connection."""
try:
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_server_sync(event_loop):
shutdown = False

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", 8366))
sock.listen(2)
sock.setblocking(False)
Expand Down
17 changes: 15 additions & 2 deletions tests/server/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for the server worker."""

from datetime import datetime, timedelta, timezone
import hashlib
import os
Expand Down Expand Up @@ -49,6 +50,8 @@ def test_peer_connection(test_server_sync, test_client_sync, event_loop):

worker.shutdown()

assert worker.peer_size == 0


def test_peer_connection_disconnect(test_server_sync, test_client_sync, event_loop):
"""Run a full flow of with a peer & disconnect."""
Expand Down Expand Up @@ -81,7 +84,10 @@ def test_peer_connection_disconnect(test_server_sync, test_client_sync, event_lo


def test_sni_connection(
test_server_sync, test_client_sync, test_client_ssl_sync, event_loop
test_server_sync,
test_client_sync,
test_client_ssl_sync,
event_loop,
):
"""Run a full flow of with a peer."""
worker = ServerWorker(FERNET_TOKENS)
Expand All @@ -91,7 +97,11 @@ def test_sni_connection(
hostname = "localhost"
alias = ["localhost.custom"]
fernet_token = create_peer_config(
valid.timestamp(), hostname, aes_key, aes_iv, alias=alias
valid.timestamp(),
hostname,
aes_key,
aes_iv,
alias=alias,
)

worker.start()
Expand All @@ -111,4 +121,7 @@ def test_sni_connection(
worker.handover_connection(test_server_sync[1], TLS_1_2, hostname)
assert len(test_client_sync.recv(1048)) == 32

assert worker.peer_size == 1
worker.shutdown()

assert worker.peer_size == 0

0 comments on commit 4f7e4d9

Please sign in to comment.