Skip to content

Commit

Permalink
feat: add support for custom actions
Browse files Browse the repository at this point in the history
  • Loading branch information
benthomasson committed Jan 18, 2024
1 parent 37344da commit 3456e39
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 75 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,5 @@ ENV/

# awx provision
tests/e2e/utils/awx/artifacts

.DS_Store
5 changes: 5 additions & 0 deletions ansible_rulebook/action/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ansible_rulebook.action.control import Control
from ansible_rulebook.action.helper import Helper
from ansible_rulebook.action.metadata import Metadata

__all__ = ["Control", "Helper", "Metadata"]
5 changes: 4 additions & 1 deletion ansible_rulebook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ async def run(parsed_args: argparse.Namespace) -> None:
startup_args.controller_url = parsed_args.controller_url
startup_args.controller_token = parsed_args.controller_token
startup_args.controller_ssl_verify = parsed_args.controller_ssl_verify
startup_args.source_dir = parsed_args.source_dir
startup_args.action_dir = parsed_args.action_dir

validate_actions(startup_args)

Expand All @@ -107,7 +109,7 @@ async def run(parsed_args: argparse.Namespace) -> None:
tasks, ruleset_queues = spawn_sources(
startup_args.rulesets,
startup_args.variables,
[parsed_args.source_dir],
[startup_args.source_dir],
parsed_args.shutdown_delay,
)

Expand All @@ -128,6 +130,7 @@ async def run(parsed_args: argparse.Namespace) -> None:
parsed_args,
startup_args.project_data_file,
file_monitor,
[startup_args.action_dir],
)

await event_log.put(dict(type="Exit"))
Expand Down
5 changes: 5 additions & 0 deletions ansible_rulebook/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def get_parser() -> argparse.ArgumentParser:
"--source-dir",
help="Source dir",
)
parser.add_argument(
"-A",
"--action-dir",
help="Action dir",
)
parser.add_argument(
"-i",
"--inventory",
Expand Down
21 changes: 21 additions & 0 deletions ansible_rulebook/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@

EDA_YAML_EXTENSIONS = [".yml", ".yaml"]

EDA_ACTION_PATHS = [
f"{EDA_PATH_PREFIX}/plugins/rule_action",
]
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -138,6 +141,24 @@ def load_rulebook(collection, rulebook):
return yaml.safe_load(f.read())


def has_action(collection, action):
return has_object(
collection,
action,
EDA_ACTION_PATHS,
".py",
)


def find_action(collection, action):
return find_object(
collection,
action,
EDA_ACTION_PATHS,
".py",
)


def has_source(collection, source):
return has_object(
collection,
Expand Down
2 changes: 2 additions & 0 deletions ansible_rulebook/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ class StartupArgs:
project_data_file: str = field(default="")
inventory: str = field(default="")
check_controller_connection: bool = field(default=False)
source_dir: str = field(default="")
action_dir: str = field(default="")
2 changes: 2 additions & 0 deletions ansible_rulebook/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ async def run_rulesets(
parsed_args: argparse.Namespace = None,
project_data_file: Optional[str] = None,
file_monitor: str = None,
action_directories: Optional[List[str]] = None,
) -> bool:
logger.debug("run_ruleset")
rulesets_queue_plans = rule_generator.generate_rulesets(
Expand Down Expand Up @@ -305,6 +306,7 @@ async def run_rulesets(
project_data_file=project_data_file,
parsed_args=parsed_args,
broadcast_method=broadcast,
action_directories=action_directories,
)
task_name = f"main_ruleset :: {ruleset_queue_plan.ruleset.name}"
ruleset_task = asyncio.create_task(
Expand Down
216 changes: 144 additions & 72 deletions ansible_rulebook/rule_set_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import asyncio
import gc
import logging
import os
import runpy
import uuid
from pprint import pformat
from types import MappingProxyType
Expand Down Expand Up @@ -42,6 +44,11 @@
from ansible_rulebook.action.run_workflow_template import RunWorkflowTemplate
from ansible_rulebook.action.set_fact import SetFact
from ansible_rulebook.action.shutdown import Shutdown as ShutdownAction
from ansible_rulebook.collection import (
find_action,
has_action,
split_collection_name,
)
from ansible_rulebook.conf import settings
from ansible_rulebook.exception import (
ShutdownException,
Expand Down Expand Up @@ -89,6 +96,7 @@ def __init__(
project_data_file: Optional[str] = None,
parsed_args=None,
broadcast_method=None,
action_directories=None,
):
self.action_loop_task = None
self.event_log = event_log
Expand All @@ -104,6 +112,14 @@ def __init__(
self.broadcast_method = broadcast_method
self.event_counter = 0
self.display = terminal.Display()
self.action_directories = action_directories

def find_action(self, action: str):
for action_dir in self.action_directories:
action_plugin_file = os.path.join(action_dir, f"{action}.py")
if os.path.exists(action_plugin_file):
return runpy.run_path(action_plugin_file)
return None

async def run_ruleset(self):
tasks = []
Expand Down Expand Up @@ -352,6 +368,89 @@ def _run_action(
task.add_done_callback(self._handle_action_completion)
return task

def _build_control(
self,
action,
action_args,
rules_engine_result,
variables,
metadata,
inventory,
hosts,
):
if action == "run_job_template" or action == "run_workflow_template":
limit = dpath.get(
action_args,
"job_args.limit",
separator=".",
default=None,
)
if isinstance(limit, list):
hosts = limit
elif isinstance(limit, str):
hosts = [limit]
elif action == "shutdown":
if self.parsed_args and "delay" not in action_args:
action_args["delay"] = self.parsed_args.shutdown_delay

single_match = None
keys = list(rules_engine_result.data.keys())
if len(keys) == 0:
single_match = {}
elif len(keys) == 1 and keys[0] == "m":
single_match = rules_engine_result.data[keys[0]]
else:
multi_match = rules_engine_result.data
variables_copy = variables.copy()
if single_match is not None:
variables_copy["event"] = single_match
event = single_match
if "meta" in event:
if "hosts" in event["meta"]:
hosts = parse_hosts(event["meta"]["hosts"])
else:
variables_copy["events"] = multi_match
new_hosts = []
for event in variables_copy["events"].values():
if "meta" in event:
if "hosts" in event["meta"]:
new_hosts.extend(parse_hosts(event["meta"]["hosts"]))
if new_hosts:
hosts = new_hosts

if "var_root" in action_args:
var_root = action_args.pop("var_root")
logger.debug(
"Update variables [%s] with new root [%s]",
variables_copy,
var_root,
)
_update_variables(variables_copy, var_root)

logger.debug(
"substitute_variables [%s] [%s]",
action_args,
variables_copy,
)
action_args = {
k: substitute_variables(v, variables_copy)
for k, v in action_args.items()
}
logger.debug("action args: %s", action_args)

if "ruleset" not in action_args:
action_args["ruleset"] = metadata.rule_set

control = Control(
queue=self.event_log,
inventory=inventory,
hosts=hosts,
variables=variables_copy,
project_data_file=self.project_data_file,
)

return control, action_args, variables_copy

async def _call_action(
self,
metadata: Metadata,
Expand All @@ -368,80 +467,15 @@ async def _call_action(
error = None
if action in ACTION_CLASSES:
try:
if (
action == "run_job_template"
or action == "run_workflow_template"
):
limit = dpath.get(
action_args,
"job_args.limit",
separator=".",
default=None,
)
if isinstance(limit, list):
hosts = limit
elif isinstance(limit, str):
hosts = [limit]
elif action == "shutdown":
if self.parsed_args and "delay" not in action_args:
action_args["delay"] = self.parsed_args.shutdown_delay

single_match = None
keys = list(rules_engine_result.data.keys())
if len(keys) == 0:
single_match = {}
elif len(keys) == 1 and keys[0] == "m":
single_match = rules_engine_result.data[keys[0]]
else:
multi_match = rules_engine_result.data
variables_copy = variables.copy()
if single_match is not None:
variables_copy["event"] = single_match
event = single_match
if "meta" in event:
if "hosts" in event["meta"]:
hosts = parse_hosts(event["meta"]["hosts"])
else:
variables_copy["events"] = multi_match
new_hosts = []
for event in variables_copy["events"].values():
if "meta" in event:
if "hosts" in event["meta"]:
new_hosts.extend(
parse_hosts(event["meta"]["hosts"])
)
if new_hosts:
hosts = new_hosts

if "var_root" in action_args:
var_root = action_args.pop("var_root")
logger.debug(
"Update variables [%s] with new root [%s]",
variables_copy,
var_root,
)
_update_variables(variables_copy, var_root)

logger.debug(
"substitute_variables [%s] [%s]",
control, action_args, variables_copy = self._build_control(
action,
action_args,
variables_copy,
)
action_args = {
k: substitute_variables(v, variables_copy)
for k, v in action_args.items()
}
logger.debug("action args: %s", action_args)

if "ruleset" not in action_args:
action_args["ruleset"] = metadata.rule_set

control = Control(
queue=self.event_log,
inventory=inventory,
hosts=hosts,
variables=variables_copy,
project_data_file=self.project_data_file,
rules_engine_result,
variables,
metadata,
inventory,
hosts,
)

await ACTION_CLASSES[action](
Expand Down Expand Up @@ -478,6 +512,44 @@ async def _call_action(
except Exception as e:
logger.error("Error calling action %s, err %s", action, str(e))
error = e
raise
except BaseException as e:
logger.error(e)
raise
elif action_plugin := self.find_action(action):
try:
control, action_args, variables_copy = self._build_control(
action,
action_args,
rules_engine_result,
variables,
metadata,
inventory,
hosts,
)
await action_plugin["main"](metadata, control, **action_args)
except Exception as e:
logger.error("Error calling action %s, err %s", action, str(e))
raise
except BaseException as e:
logger.error(e)
raise
elif has_action(*split_collection_name(action)):
action_plugin = find_action(*split_collection_name(action))
try:
control, action_args, variables_copy = self._build_control(
action,
action_args,
rules_engine_result,
variables,
metadata,
inventory,
hosts,
)
await action_plugin.main(metadata, control, **action_args)
except Exception as e:
logger.error("Error calling action %s, err %s", action, str(e))
raise
except BaseException as e:
logger.error(e)
raise
Expand Down
Loading

0 comments on commit 3456e39

Please sign in to comment.