Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

Commit

Permalink
feat:binary handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Oct 30, 2024
1 parent b47e078 commit de25aad
Showing 1 changed file with 80 additions and 12 deletions.
92 changes: 80 additions & 12 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tornado.websocket import WebSocketHandler

from hivemind_bus_client.identity import NodeIdentity
from hivemind_bus_client.message import HiveMessage, HiveMessageType
from hivemind_bus_client.message import HiveMessage, HiveMessageType, HiveMindBinaryPayloadType
from hivemind_bus_client.serialization import decode_bitstring, get_bitstring
from hivemind_bus_client.util import (
decrypt_bin,
Expand Down Expand Up @@ -111,13 +111,16 @@ def send(self, message: HiveMessage):
HiveMessageType.HANDSHAKE,
HiveMessageType.HELLO,
]:
if self.binarize:
payload = get_bitstring(message.msg_type, message.payload).bytes
if self.binarize or is_bin:
payload = get_bitstring(hive_type=message.msg_type,
payload=message.payload,
hivemeta=message.metadata,
binary_type=message.bin_type).bytes
payload = encrypt_bin(self.crypto_key, payload)
is_bin = True
else:
payload = encrypt_as_json(
self.crypto_key, message.serialize() # json string
self.crypto_key, message.serialize() # json string
) # json string
LOG.debug(f"encrypted payload: {len(payload)}")
else:
Expand Down Expand Up @@ -388,7 +391,7 @@ def handle_message(self, message: HiveMessage, client: HiveMindClientConnection)

# HiveMind protocol messages - from slave -> master
def handle_unknown_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""message handler for non default message types, subclasses can
handle their own types here
Expand All @@ -397,13 +400,78 @@ def handle_unknown_message(
"""

def handle_binary_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
assert message.msg_type == HiveMessageType.BINARY
# TODO
bin_data = message.payload
if message.bin_type == HiveMindBinaryPayloadType.RAW_AUDIO:
sr = message.metadata.get("sample_rate", 16000)
sw = message.metadata.get("sample_width", 2)
self.handle_microphone_input(bin_data, sr, sw, client)
elif message.bin_type == HiveMindBinaryPayloadType.STT_AUDIO_TRANSCRIBE:
lang = message.metadata.get("lang")
sr = message.metadata.get("sample_rate", 16000)
sw = message.metadata.get("sample_width", 2)
self.handle_stt_transcribe_request(bin_data, sr, sw, lang, client)
elif message.bin_type == HiveMindBinaryPayloadType.STT_AUDIO_HANDLE:
lang = message.metadata.get("lang")
sr = message.metadata.get("sample_rate", 16000)
sw = message.metadata.get("sample_width", 2)
self.handle_stt_handle_request(bin_data, sr, sw, lang, client)
elif message.bin_type == HiveMindBinaryPayloadType.TTS_AUDIO:
lang = message.metadata.get("lang")
utt = message.metadata.get("utterance")
file_name = message.metadata.get("file_name")
self.handle_receive_tts(bin_data, utt, lang, file_name, client)
elif message.bin_type == HiveMindBinaryPayloadType.FILE:
file_name = message.metadata.get("file_name")
self.handle_receive_file(bin_data, file_name, client)
elif message.bin_type == HiveMindBinaryPayloadType.NUMPY_IMAGE:
# TODO - convert to numpy array
camera_id = message.metadata.get("camera_id")
self.handle_numpy_image(bin_data, camera_id, client)
else:
LOG.warning(f"Ignoring received untyped binary data: {len(bin_data)} bytes")

def handle_microphone_input(self, bin_data: bytes,
sample_rate: int,
sample_width: int,
client: HiveMindClientConnection):
LOG.warning(f"Ignoring received binary audio input: {len(bin_data)} bytes at sample_rate: {sample_rate}")

def handle_stt_transcribe_request(self, bin_data: bytes,
sample_rate: int,
sample_width: int,
lang: str,
client: HiveMindClientConnection):
LOG.warning(f"Ignoring received binary STT input: {len(bin_data)} bytes")

def handle_stt_handle_request(self, bin_data: bytes,
sample_rate: int,
sample_width: int,
lang: str,
client: HiveMindClientConnection):
LOG.warning(f"Ignoring received binary STT input: {len(bin_data)} bytes")

def handle_numpy_image(self, bin_data: bytes,
camera_id: str,
client: HiveMindClientConnection):
LOG.warning(f"Ignoring received binary image: {len(bin_data)} bytes")

def handle_receive_tts(self, bin_data: bytes,
utterance: str,
lang: str,
file_name: str,
client: HiveMindClientConnection):
LOG.warning(f"Ignoring received binary TTS audio: {utterance} with {len(bin_data)} bytes")

def handle_receive_file(self, bin_data: bytes,
file_name: str,
client: HiveMindClientConnection):
LOG.warning(f"Ignoring received binary file: {file_name} with {len(bin_data)} bytes")

def handle_handshake_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
LOG.debug("handshake received, generating session key")
payload = message.payload
Expand Down Expand Up @@ -467,7 +535,7 @@ def handle_handshake_message(
client.send(msg) # client can recreate crypto_key on his side now

def handle_bus_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
# track any Session updates from client side
sess = Session.from_message(message.payload)
Expand All @@ -492,7 +560,7 @@ def handle_bus_message(
self.mycroft_bus_callback(message.payload)

def handle_broadcast_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""
message (HiveMessage): HiveMind message object
Expand Down Expand Up @@ -536,7 +604,7 @@ def _unpack_message(self, message: HiveMessage, client: HiveMindClientConnection
return pload

def handle_propagate_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""
message (HiveMessage): HiveMind message object
Expand Down Expand Up @@ -587,7 +655,7 @@ def handle_propagate_message(
bus.emit(message)

def handle_escalate_message(
self, message: HiveMessage, client: HiveMindClientConnection
self, message: HiveMessage, client: HiveMindClientConnection
):
"""
message (HiveMessage): HiveMind message object
Expand Down

0 comments on commit de25aad

Please sign in to comment.