Skip to content

Commit

Permalink
Add safety around missing ssh drivers.
Browse files Browse the repository at this point in the history
  • Loading branch information
eseglem committed Nov 20, 2023
1 parent 4ec6344 commit d1595f5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pywattbox"
version = "0.7.1"
version = "0.7.2"
description = "A python wrapper for WattBox APIs."
license = "MIT"
readme = "README.md"
Expand Down
35 changes: 27 additions & 8 deletions pywattbox/ip_wattbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Union,
)

from scrapli.exceptions import ScrapliTransportPluginError
from scrapli.response import Response

from .base import BaseWattBox, Commands, Outlet, _async_create_wattbox, _create_wattbox
Expand Down Expand Up @@ -100,6 +101,10 @@ class UpdateBaseResponses(NamedTuple):
_Responses = TypeVar("_Responses", bound=Union[InitialResponses, UpdateBaseResponses])


class DriverUnavailableError(Exception):
pass


class IpWattBox(BaseWattBox):
def __init__(
self,
Expand Down Expand Up @@ -129,18 +134,26 @@ def __init__(
else:
raise ValueError("Non Standard Port, Transport must be set.")

self.driver = WattBoxDriver(
**conninfo,
transport="ssh2" if transport == "ssh" else "telnet",
)
self.async_driver = WattBoxAsyncDriver(
**conninfo,
transport="asyncssh" if transport == "ssh" else "asynctelnet",
)
try:
self.driver: Optional[WattBoxDriver] = WattBoxDriver(
**conninfo,
transport="ssh2" if transport == "ssh" else "telnet",
)
except ScrapliTransportPluginError:
self.driver = None
try:
self.async_driver: Optional[WattBoxAsyncDriver] = WattBoxAsyncDriver(
**conninfo,
transport="asyncssh" if transport == "ssh" else "asynctelnet",
)
except ScrapliTransportPluginError:
self.async_driver = None

def send_requests(
self, requests: Iterable[Union[REQUEST_MESSAGES, str]]
) -> List[Response]:
if not self.driver:
raise DriverUnavailableError
responses: List[Response] = []
for request in requests:
responses.append(
Expand All @@ -153,6 +166,8 @@ def send_requests(
async def async_send_requests(
self, requests: Iterable[Union[REQUEST_MESSAGES, str]]
) -> List[Response]:
if not self.async_driver:
raise DriverUnavailableError
responses: List[Response] = []
for request in requests:
responses.append(
Expand Down Expand Up @@ -266,6 +281,8 @@ async def async_update(self) -> None:

def send_command(self, outlet: int, command: Commands) -> None:
logger.debug("Send Command")
if not self.driver:
raise DriverUnavailableError
self.driver._send_command(
CONTROL_MESSAGES.OUTLET_SET.value.format(
outlet=outlet, action=command.name, delay=0
Expand All @@ -275,6 +292,8 @@ def send_command(self, outlet: int, command: Commands) -> None:

async def async_send_command(self, outlet: int, command: Commands) -> None:
logger.debug("Async Send Command")
if not self.async_driver:
raise DriverUnavailableError
await self.async_driver._send_command(
CONTROL_MESSAGES.OUTLET_SET.value.format(
outlet=outlet, action=command.name, delay=0
Expand Down

0 comments on commit d1595f5

Please sign in to comment.