Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

fix/session_mapping #94

Merged
merged 4 commits into from
Jul 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import uuid
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from typing import List, Dict, Optional
Expand Down Expand Up @@ -260,7 +261,6 @@ def get_bus(self, client: HiveMindClientConnection):

def handle_new_client(self, client: HiveMindClientConnection):
LOG.debug(f"new client: {client.peer}")
self.clients[client.peer] = client
message = Message(
"hive.client.connect",
{"ip": client.ip, "session_id": client.sess.session_id},
Expand Down Expand Up @@ -398,6 +398,8 @@ def handle_handshake_message(
):
LOG.debug("handshake received, generating session key")
payload = message.payload
if "session" in payload:
client.sess = Session.deserialize(payload["session"])
if "site_id" in payload:
client.sess.site_id = client.site_id = payload["site_id"]
if "pubkey" in payload and client.handshake is not None:
Expand Down Expand Up @@ -445,22 +447,36 @@ def handle_handshake_message(
client.socket.close()
return

LOG.debug(f"client site_id: {client.sess.site_id}")
if client.sess.session_id != "default":
LOG.debug(f"client session_id: {client.sess.session_id}")
self.clients[client.peer] = client
else:
LOG.warning("client did not send a session in it's handshake")

msg = HiveMessage(HiveMessageType.HANDSHAKE, payload)
client.send(msg) # client can recreate crypto_key on his side now

def handle_bus_message(
self, message: HiveMessage, client: HiveMindClientConnection
):
# update the session as received by the client
old = client.peer
client.sess = Session.from_message(message.payload)
LOG.debug(f"Client session updated: {client.sess.serialize()}")
if old != client.peer:
LOG.debug(f"Client session_id changed! new peer_id: {client.peer}")
if old in self.clients:
self.clients[client.peer] = self.clients.pop(old)
# track any Session updates from client side
sess = Session.from_message(message.payload)
if client.sess.session_id == "default":
LOG.warning(f"{client.peer} did not send a Session via handshake")
if sess.session_id == "default":
client.sess.session_id = str(uuid.uuid4())
LOG.debug(f"Client session_id randomly generated: {client.sess.session_id}")
else:
self.clients[client.peer] = client
client.sess.session_id = sess.session_id
LOG.debug(f"Client session_id assigned via client first message: {client.sess.session_id}")
self.clients[client.peer] = client

if sess.session_id == "default":
sess.session_id = client.sess.session_id
if client.sess.session_id == sess.session_id:
client.sess = sess
LOG.debug(f"Client session updated from payload: {sess.serialize()}")

self.handle_inject_mycroft_msg(message.payload, client)
if self.mycroft_bus_callback:
Expand Down Expand Up @@ -672,8 +688,8 @@ def handle_inject_mycroft_msg(
return

# ensure client specific session data is injected in query to ovos
if "session" not in message.context:
message.context["session"] = client.sess.serialize()
LOG.debug("replacing message metadata with hivemind client session")
message.context["session"] = client.sess.serialize()
if message.msg_type == "speak":
message.context["destination"] = ["audio"] # make audible, this is injected "speak" command
elif message.context.get("destination") is None:
Expand Down
Loading