Skip to content

Commit

Permalink
raise from client when attempting to subscribe or unsubscribe while c…
Browse files Browse the repository at this point in the history
…onnection is not running
  • Loading branch information
deeleeramone committed Dec 1, 2024
1 parent 161398a commit 020becf
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ async def read_stdin(broadcast_server):
if line.strip().startswith("{") or line.strip().startswith("[")
else line.strip()
)
msg = (
"BROADCAST INFO: Message received from parent process and relayed to active listeners ->"
+ f" {json.dumps(command)}"
)
await broadcast_server.broadcast(json.dumps(command))
broadcast_server.logger.info(msg)
except json.JSONDecodeError:
broadcast_server.logger.error("Invalid JSON received from stdin")

Expand Down Expand Up @@ -270,10 +265,12 @@ def main():
**kwargs,
)
except TypeError as e:
msg = f"Invalid keyword argument passed to unvicorn. -> {e.args[0]}\n"
msg = (
f"ERROR: Invalid keyword argument passed to unvicorn. -> {e.args[0]}\n"
)
broadcast_server.logger.error(msg)
except KeyboardInterrupt:
broadcast_server.logger.info("Broadcast server terminated.")
broadcast_server.logger.info("INFO: Broadcast server terminated.")
finally:
sys.exit(0)

Expand Down
42 changes: 27 additions & 15 deletions openbb_platform/extensions/websockets/openbb_websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ def subscribe(self, symbol) -> None:
import time
from openbb_core.app.model.abstract.error import OpenBBError

if not self.is_running:
raise OpenBBError("Provider connection is not running.")

ticker = symbol if isinstance(symbol, list) else symbol.split(",")
msg = {"event": "subscribe", "symbol": ticker}
self.send_message(json.dumps(msg))
Expand All @@ -492,15 +495,21 @@ def subscribe(self, symbol) -> None:
def unsubscribe(self, symbol) -> None:
"""Unsubscribe from a symbol or list of symbols."""
# pylint: disable=import-outside-toplevel
import json
import json # noqa
import time
from openbb_core.app.model.abstract.error import OpenBBError

if not self.symbol:
self.logger.info("No subscribed symbols.")
return

if not self.is_running:
raise OpenBBError("Provider connection is not running.")

ticker = symbol if isinstance(symbol, list) else symbol.split(",")
msg = {"event": "unsubscribe", "symbol": ticker}
self.send_message(json.dumps(msg))
time.sleep(0.1)
old_symbols = self.symbol.split(",")
new_symbols = list(set(old_symbols) - set(ticker))
self._symbol = ",".join(new_symbols)
Expand Down Expand Up @@ -782,22 +791,25 @@ def read_message_queue(
"""Read messages from the queue and send them to the WebSocketConnection process."""
while not message_queue.empty():
try:
if target == "provider":
while not client._stop_log_thread_event.is_set():
message = message_queue.get(timeout=1)
if message:
message = message_queue.get(timeout=1)
if message:
try:
if (
target == "provider"
and not client._stop_log_thread_event.is_set()
):
send_message(client, message, target="provider")
elif target == "broadcast":
while not client._stop_broadcasting_event.is_set():
message = message_queue.get(timeout=1)
if message:
elif (
target == "broadcast"
and not client._stop_broadcasting_event.is_set()
):
send_message(client, message, target="broadcast")
except Exception as e:
err = (
f"Error while attempting to transmit from the outgoing message queue: {e.__class__.__name__} "
f"-> {e} -> {message}"
)
client.logger.error(err)
except Exception as e:
err = (
f"Error while attempting to transmit from the outgoing message queue: {e.__class__.__name__} "
f"-> {e} -> {message}"
)
client.logger.error(err)
finally:
break

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,14 @@ async def setup_database(results_path, table_name):
import os # noqa
import aiosqlite

async with aiosqlite.connect(results_path) as conn:
async with aiosqlite.connect(results_path, check_same_thread=False) as conn:
if os.path.exists(results_path):
try:
await conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
except aiosqlite.DatabaseError:
os.remove(results_path)

async with aiosqlite.connect(results_path) as conn:
async with aiosqlite.connect(results_path, check_same_thread=False) as conn:
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_name} (
Expand All @@ -194,7 +194,7 @@ async def write_to_db(message, results_path, table_name, limit):
import json # noqa
import aiosqlite

conn = await aiosqlite.connect(results_path)
conn = await aiosqlite.connect(results_path, check_same_thread=False)

# Check if the table exists and create it if it doesn't
await conn.execute(
Expand Down

0 comments on commit 020becf

Please sign in to comment.