From 3ad42de4c928228ca8a0b582dda3af37e08b3a51 Mon Sep 17 00:00:00 2001 From: JarbasAI <33701864+JarbasAl@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:59:14 +0100 Subject: [PATCH] fix/session_mapping (#94) * fix/session_mapping ensure session is tracked properly per client, allow it to be sent in initial handshake * fix/session_mapping ensure session is tracked properly per client, allow it to be sent in initial handshake * fix/session_mapping ensure session is tracked properly per client, allow it to be sent in initial handshake * fix/session_mapping ensure session is tracked properly per client, allow it to be sent in initial handshake --- hivemind_core/protocol.py | 40 +++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/hivemind_core/protocol.py b/hivemind_core/protocol.py index f4bad8d..3a604a8 100644 --- a/hivemind_core/protocol.py +++ b/hivemind_core/protocol.py @@ -1,4 +1,5 @@ import json +import uuid from dataclasses import dataclass, field from enum import Enum, IntEnum from typing import List, Dict, Optional @@ -260,7 +261,6 @@ def get_bus(self, client: HiveMindClientConnection): def handle_new_client(self, client: HiveMindClientConnection): LOG.debug(f"new client: {client.peer}") - self.clients[client.peer] = client message = Message( "hive.client.connect", {"ip": client.ip, "session_id": client.sess.session_id}, @@ -398,6 +398,8 @@ def handle_handshake_message( ): LOG.debug("handshake received, generating session key") payload = message.payload + if "session" in payload: + client.sess = Session.deserialize(payload["session"]) if "site_id" in payload: client.sess.site_id = client.site_id = payload["site_id"] if "pubkey" in payload and client.handshake is not None: @@ -445,22 +447,36 @@ def handle_handshake_message( client.socket.close() return + LOG.debug(f"client site_id: {client.sess.site_id}") + if client.sess.session_id != "default": + LOG.debug(f"client session_id: {client.sess.session_id}") + self.clients[client.peer] = client + else: + LOG.warning("client did not send a session in it's handshake") + msg = HiveMessage(HiveMessageType.HANDSHAKE, payload) client.send(msg) # client can recreate crypto_key on his side now def handle_bus_message( self, message: HiveMessage, client: HiveMindClientConnection ): - # update the session as received by the client - old = client.peer - client.sess = Session.from_message(message.payload) - LOG.debug(f"Client session updated: {client.sess.serialize()}") - if old != client.peer: - LOG.debug(f"Client session_id changed! new peer_id: {client.peer}") - if old in self.clients: - self.clients[client.peer] = self.clients.pop(old) + # track any Session updates from client side + sess = Session.from_message(message.payload) + if client.sess.session_id == "default": + LOG.warning(f"{client.peer} did not send a Session via handshake") + if sess.session_id == "default": + client.sess.session_id = str(uuid.uuid4()) + LOG.debug(f"Client session_id randomly generated: {client.sess.session_id}") else: - self.clients[client.peer] = client + client.sess.session_id = sess.session_id + LOG.debug(f"Client session_id assigned via client first message: {client.sess.session_id}") + self.clients[client.peer] = client + + if sess.session_id == "default": + sess.session_id = client.sess.session_id + if client.sess.session_id == sess.session_id: + client.sess = sess + LOG.debug(f"Client session updated from payload: {sess.serialize()}") self.handle_inject_mycroft_msg(message.payload, client) if self.mycroft_bus_callback: @@ -672,8 +688,8 @@ def handle_inject_mycroft_msg( return # ensure client specific session data is injected in query to ovos - if "session" not in message.context: - message.context["session"] = client.sess.serialize() + LOG.debug("replacing message metadata with hivemind client session") + message.context["session"] = client.sess.serialize() if message.msg_type == "speak": message.context["destination"] = ["audio"] # make audible, this is injected "speak" command elif message.context.get("destination") is None: