Skip to content

Commit

Permalink
feat: add support for background command executes
Browse files Browse the repository at this point in the history
* chore: add support for bg as true

* chore: add timeout support in command
  • Loading branch information
rrkumarshikhar authored Nov 22, 2024
1 parent 3f379ee commit db34d48
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
40 changes: 39 additions & 1 deletion rapyuta_io/clients/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import subprocess

from six.moves.urllib.parse import urlencode
from time import sleep
import enum

import requests
Expand Down Expand Up @@ -567,7 +568,44 @@ def execute_command(self, command, retry_limit=0):
if response.status_code == requests.codes.BAD_REQUEST:
raise ParameterMissingException(get_error(response.text))
execution_result = get_api_response_data(response)
return execution_result[self.uuid]
if not command.bg:
return execution_result[self.uuid]
jid = execution_result.get('jid')
if not jid:
raise ValueError("Job ID not found in the response")
return self.fetch_command_result(jid, [self.uuid], timeout=command.timeout)

def fetch_command_result(self, jid: str, deviceids: list, timeout: int):
"""
Fetch the result of the command execution using the job ID (jid) and the first device ID from the list.
Args:
jobid (str): The job ID of the executed command.
deviceids (list): A list of device IDs on which the command was executed.
timeout (int): The maximum time to wait for the result (in seconds). Default is 300 seconds.
Returns:
dict: The result of the command execution.
Raises:
TimeoutError: If the result is not available within the timeout period.
APIError: If the API returns an error.
"""

if not deviceids or not isinstance(deviceids, list):
raise ValueError("Device IDs must be provided as a non-empty list.")
url = self._device_api_host + DEVICE_COMMAND_API_PATH + "jobid"
payload = {
"jid": jid,
"device_id": deviceids[0]
}
total_time_waited = 0
wait_interval = 10
while total_time_waited < timeout:
response = self._execute_api(url, HttpMethod.POST, payload)
if response.status_code == requests.codes.OK:
result = get_api_response_data(response)
return result[deviceids[0]]
sleep(wait_interval)
total_time_waited += wait_interval
raise TimeoutError(f"Command result not available after {timeout} seconds")

def get_config_variables(self):
"""
Expand Down
5 changes: 4 additions & 1 deletion rapyuta_io/clients/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Command(ObjDict):
"""

def __init__(self, cmd, shell=None, env=None, bg=False, runas=None, pwd=None, cwd=None):
def __init__(self, cmd, shell=None, env=None, bg=False, runas=None, pwd=None, cwd=None, timeout=300):
super(ObjDict, self).__init__()
if env is None:
env = dict()
Expand All @@ -73,6 +73,7 @@ def __init__(self, cmd, shell=None, env=None, bg=False, runas=None, pwd=None, cw
self.cwd = pwd
if cwd is not None:
self.cwd = cwd
self.timeout = timeout
self.validate()

def validate(self):
Expand All @@ -93,6 +94,8 @@ def validate(self):
raise InvalidCommandException('Invalid environment variables')
return
raise InvalidCommandException('Invalid environment variables')
if self.timeout <= 0:
raise InvalidCommandException("Invalid timeout value")

def to_json(self):
# TODO: we need to rewrite this function.
Expand Down

0 comments on commit db34d48

Please sign in to comment.