From 5c4ab0f44c490d1cedf97fddfe69db3aee03b5f2 Mon Sep 17 00:00:00 2001 From: Shahar Glazner Date: Fri, 27 Dec 2024 10:31:52 +0200 Subject: [PATCH] feat(api): memory consumption and performance improvements (#2908) Signed-off-by: Shahar Glazner Signed-off-by: Tal Co-authored-by: Tal --- .gitignore | 2 + .../alert_deduplicator/alert_deduplicator.py | 12 +- keep/api/core/demo_mode.py | 7 +- keep/api/core/metrics.py | 54 +++- keep/api/logging.py | 15 +- keep/api/observability.py | 34 ++- keep/api/routes/alerts.py | 93 ++++--- keep/cli/cli.py | 119 -------- keep/workflowmanager/workflowmanager.py | 41 +-- keep/workflowmanager/workflowscheduler.py | 257 ++++++++++++------ poetry.lock | 21 +- pyproject.toml | 4 +- tests/test_workflow_execution.py | 33 ++- tests/test_workflowmanager.py | 95 ++++--- 14 files changed, 419 insertions(+), 368 deletions(-) diff --git a/.gitignore b/.gitignore index 6bcb9b6e7..38d405e05 100644 --- a/.gitignore +++ b/.gitignore @@ -214,3 +214,5 @@ oauth2.cfg scripts/keep_slack_bot.py keepnew.db providers_cache.json + +tests/provision/* diff --git a/keep/api/alert_deduplicator/alert_deduplicator.py b/keep/api/alert_deduplicator/alert_deduplicator.py index 09e30ab9f..57b4e12d0 100644 --- a/keep/api/alert_deduplicator/alert_deduplicator.py +++ b/keep/api/alert_deduplicator/alert_deduplicator.py @@ -25,7 +25,6 @@ DeduplicationRuleRequestDto, ) from keep.providers.providers_factory import ProvidersFactory -from keep.searchengine.searchengine import SearchEngine DEFAULT_RULE_UUID = "00000000-0000-0000-0000-000000000000" @@ -42,7 +41,6 @@ class AlertDeduplicator: def __init__(self, tenant_id): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id - self.search_engine = SearchEngine(self.tenant_id) def _apply_deduplication_rule( self, alert: AlertDto, rule: DeduplicationRuleDto @@ -264,7 +262,7 @@ def _get_default_full_deduplication_rule( ingested=0, dedup_ratio=0.0, enabled=True, - is_provisioned=False + is_provisioned=False, ) def get_deduplications(self) -> list[DeduplicationRuleDto]: @@ -502,15 +500,15 @@ def update_deduplication_rule( rule_dto = self.create_deduplication_rule(rule, updated_by) self.logger.info("Default rule updated") return rule_dto - + rule_before_update = get_deduplication_rule_by_id(self.tenant_id, rule_id) - + if not rule_before_update: raise HTTPException( status_code=404, detail="Deduplication rule not found", ) - + if rule_before_update.is_provisioned: raise HTTPException( status_code=409, @@ -557,7 +555,7 @@ def delete_deduplication_rule(self, rule_id: str) -> bool: status_code=404, detail="Deduplication rule not found", ) - + if deduplication_rule_to_be_deleted.is_provisioned: raise HTTPException( status_code=409, diff --git a/keep/api/core/demo_mode.py b/keep/api/core/demo_mode.py index e18ea90db..4ea72a1ab 100644 --- a/keep/api/core/demo_mode.py +++ b/keep/api/core/demo_mode.py @@ -568,7 +568,6 @@ async def simulate_alerts_async( logger.info( "Sleeping for {} seconds before next iteration".format(sleep_interval) ) - await asyncio.sleep(sleep_interval) def launch_demo_mode_thread( @@ -623,11 +622,14 @@ async def simulate_alerts_worker(worker_id, keep_api_key, rps=1): url, alert = await REQUESTS_QUEUE.get() async with session.post(url, json=alert, headers=headers) as response: + response_time = time.time() - start total_requests += 1 if not response.ok: logger.error("Failed to send alert: {}".format(response.text)) else: - logger.info("Alert sent successfully") + logger.info( + f"Alert sent successfully in {response_time:.3f} 
seconds" + ) if rps: delay = 1 / rps - (time.time() - start) @@ -639,6 +641,7 @@ async def simulate_alerts_worker(worker_id, keep_api_key, rps=1): worker_id, total_requests / (time.time() - total_start), ) + logger.info("Total requests: %d", total_requests) if __name__ == "__main__": diff --git a/keep/api/core/metrics.py b/keep/api/core/metrics.py index 8ff667855..df2ab8729 100644 --- a/keep/api/core/metrics.py +++ b/keep/api/core/metrics.py @@ -1,6 +1,6 @@ import os -from prometheus_client import Counter, Gauge, Summary +from prometheus_client import Counter, Gauge, Histogram, Summary PROMETHEUS_MULTIPROC_DIR = os.environ.get("PROMETHEUS_MULTIPROC_DIR", "/tmp/prometheus") os.makedirs(PROMETHEUS_MULTIPROC_DIR, exist_ok=True) @@ -37,3 +37,55 @@ labelnames=["pid"], multiprocess_mode="livesum", ) + +### WORKFLOWS +METRIC_PREFIX = "keep_workflows_" + +# Workflow execution metrics +workflow_executions_total = Counter( + f"{METRIC_PREFIX}executions_total", + "Total number of workflow executions", + labelnames=["tenant_id", "workflow_id", "trigger_type"], +) + +workflow_execution_errors_total = Counter( + f"{METRIC_PREFIX}execution_errors_total", + "Total number of workflow execution errors", + labelnames=["tenant_id", "workflow_id", "error_type"], +) + +workflow_execution_status = Counter( + f"{METRIC_PREFIX}execution_status_total", + "Total number of workflow executions by status", + labelnames=["tenant_id", "workflow_id", "status"], +) + +# Workflow performance metrics +workflow_execution_duration = Histogram( + f"{METRIC_PREFIX}execution_duration_seconds", + "Time spent executing workflows", + labelnames=["tenant_id", "workflow_id"], + buckets=(1, 5, 10, 30, 60, 120, 300, 600), # 1s, 5s, 10s, 30s, 1m, 2m, 5m, 10m +) + +workflow_execution_step_duration = Histogram( + f"{METRIC_PREFIX}execution_step_duration_seconds", + "Time spent executing individual workflow steps", + labelnames=["tenant_id", "workflow_id", "step_name"], + buckets=(0.1, 0.5, 1, 2, 5, 10, 30, 60), +) + +# Workflow state metrics +workflows_running = Gauge( + f"{METRIC_PREFIX}running", + "Number of currently running workflows", + labelnames=["tenant_id"], + multiprocess_mode="livesum", +) + +workflow_queue_size = Gauge( + f"{METRIC_PREFIX}queue_size", + "Number of workflows waiting to be executed", + labelnames=["tenant_id"], + multiprocess_mode="livesum", +) diff --git a/keep/api/logging.py b/keep/api/logging.py index 5d74adaff..a91659457 100644 --- a/keep/api/logging.py +++ b/keep/api/logging.py @@ -194,6 +194,7 @@ def process(self, msg, kwargs): LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO") +KEEP_LOG_FILE = os.environ.get("KEEP_LOG_FILE") LOG_FORMAT_OPEN_TELEMETRY = "open_telemetry" LOG_FORMAT_DEVELOPMENT_TERMINAL = "dev_terminal" @@ -234,7 +235,7 @@ def format(self, record): }, "dev_terminal": { "()": DevTerminalFormatter, - "format": "%(asctime)s - %(levelname)s - %(message)s", + "format": "%(asctime)s - %(thread)s %(threadName)s %(levelname)s - %(message)s", }, }, "handlers": { @@ -369,6 +370,18 @@ def _log( def setup_logging(): + # Add file handler if KEEP_LOG_FILE is set + if KEEP_LOG_FILE: + CONFIG["handlers"]["file"] = { + "level": "DEBUG", + "formatter": ("json"), + "class": "logging.FileHandler", + "filename": KEEP_LOG_FILE, + "mode": "a", + } + # Add file handler to root logger + CONFIG["loggers"][""]["handlers"].append("file") + logging.config.dictConfig(CONFIG) uvicorn_error_logger = logging.getLogger("uvicorn.error") uvicorn_error_logger.__class__ = CustomizedUvicornLogger diff --git 
a/keep/api/observability.py b/keep/api/observability.py index b5aa3e0fa..1987a7702 100644 --- a/keep/api/observability.py +++ b/keep/api/observability.py @@ -23,20 +23,26 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor + def get_protocol_from_endpoint(endpoint): - parsed_url = urlparse(endpoint) - if parsed_url.scheme == "http": - return HTTPOTLPSpanExporter - elif parsed_url.scheme == "grpc": - return GRPCOTLPSpanExporter - else: - raise ValueError(f"Unsupported protocol: {parsed_url.scheme}") + parsed_url = urlparse(endpoint) + if parsed_url.scheme == "http": + return HTTPOTLPSpanExporter + elif parsed_url.scheme == "grpc": + return GRPCOTLPSpanExporter + else: + raise ValueError(f"Unsupported protocol: {parsed_url.scheme}") + def setup(app: FastAPI): logger = logging.getLogger(__name__) # Configure the OpenTelemetry SDK - service_name = os.environ.get("OTEL_SERVICE_NAME", os.environ.get("SERVICE_NAME", "keep-api")) - otlp_collector_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", os.environ.get("OTLP_ENDPOINT", False)) + service_name = os.environ.get( + "OTEL_SERVICE_NAME", os.environ.get("SERVICE_NAME", "keep-api") + ) + otlp_collector_endpoint = os.environ.get( + "OTEL_EXPORTER_OTLP_ENDPOINT", os.environ.get("OTLP_ENDPOINT", False) + ) otlp_traces_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", None) otlp_logs_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", None) otlp_metrics_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", None) @@ -45,7 +51,7 @@ def setup(app: FastAPI): resource = Resource.create({"service.name": service_name}) provider = TracerProvider(resource=resource) - + if otlp_collector_endpoint: logger.info(f"OTLP endpoint set to {otlp_collector_endpoint}") @@ -53,13 +59,13 @@ def setup(app: FastAPI): if otlp_traces_endpoint: logger.info(f"OTLP Traces endpoint set to {otlp_traces_endpoint}") SpanExporter = get_protocol_from_endpoint(otlp_traces_endpoint) - processor = BatchSpanProcessor( - SpanExporter(endpoint=otlp_traces_endpoint) - ) + processor = BatchSpanProcessor(SpanExporter(endpoint=otlp_traces_endpoint)) provider.add_span_processor(processor) if metrics_enabled.lower() == "true" and otlp_metrics_endpoint: - logger.info(f"Metrics enabled. OTLP Metrics endpoint set to {otlp_metrics_endpoint}") + logger.info( + f"Metrics enabled. 
OTLP Metrics endpoint set to {otlp_metrics_endpoint}" + ) reader = PeriodicExportingMetricReader( OTLPMetricExporter(endpoint=otlp_metrics_endpoint) ) diff --git a/keep/api/routes/alerts.py b/keep/api/routes/alerts.py index dbc411a58..018211a70 100644 --- a/keep/api/routes/alerts.py +++ b/keep/api/routes/alerts.py @@ -1,17 +1,17 @@ -import asyncio import base64 +import concurrent.futures import hashlib import hmac import json import logging import os import time +from concurrent.futures import Future, ThreadPoolExecutor from typing import List, Optional import celpy from arq import ArqRedis from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request -from fastapi.concurrency import run_in_threadpool from fastapi.responses import JSONResponse from pusher import Pusher @@ -53,6 +53,12 @@ logger = logging.getLogger(__name__) REDIS = os.environ.get("REDIS", "false") == "true" +EVENT_WORKERS = int(config("KEEP_EVENT_WORKERS", default=50, cast=int)) + +# Create dedicated threadpool +process_event_executor = ThreadPoolExecutor( + max_workers=EVENT_WORKERS, thread_name_prefix="process_event_worker" +) @router.get( @@ -257,43 +263,52 @@ def assign_alert( return {"status": "ok"} -def discard_task( +def discard_future( trace_id: str, - task: asyncio.Task, + future: Future, running_tasks: set, started_time: float, ): try: - running_tasks.discard(task) - running_tasks_gauge.dec() # Decrease total counter - running_tasks_by_process_gauge.labels( - pid=os.getpid() - ).dec() # Decrease process counter - - # Log any exception that occurred in the task - if task.exception(): + running_tasks.discard(future) + running_tasks_gauge.dec() + running_tasks_by_process_gauge.labels(pid=os.getpid()).dec() + + # Log any exception that occurred in the future + try: + exception = future.exception() + if exception: + logger.error( + "Task failed with exception", + extra={ + "trace_id": trace_id, + "error": str(exception), + "processing_time": time.time() - started_time, + }, + ) + else: + logger.info( + "Task completed", + extra={ + "processing_time": time.time() - started_time, + "trace_id": trace_id, + }, + ) + except concurrent.futures.CancelledError: logger.error( - "Task failed with exception", + "Task was cancelled", extra={ "trace_id": trace_id, - "error": str(task.exception()), - "processing_time": time.time() - started_time, - }, - ) - else: - logger.info( - "Task completed", - extra={ "processing_time": time.time() - started_time, - "trace_id": trace_id, }, ) + except Exception: # Make sure we always decrement both counters even if something goes wrong running_tasks_gauge.dec() running_tasks_by_process_gauge.labels(pid=os.getpid()).dec() logger.exception( - "Error in discard_task callback", + "Error in discard_future callback", extra={ "trace_id": trace_id, }, @@ -317,26 +332,24 @@ def create_process_event_task( running_tasks_by_process_gauge.labels( pid=os.getpid() ).inc() # Increase process counter - task = asyncio.create_task( - run_in_threadpool( - process_event, - {}, - tenant_id, - provider_type, - provider_id, - fingerprint, - api_key_name, - trace_id, - event, - ) + future = process_event_executor.submit( + process_event, + {}, # ctx + tenant_id, + provider_type, + provider_id, + fingerprint, + api_key_name, + trace_id, + event, ) - task.add_done_callback( - lambda task: discard_task(trace_id, task, running_tasks, started_time) + running_tasks.add(future) + future.add_done_callback( + lambda task: discard_future(trace_id, task, running_tasks, started_time) ) - 
bg_tasks.add_task(task) - running_tasks.add(task) + logger.info("Task added", extra={"trace_id": trace_id}) - return task.get_name() + return str(id(future)) @router.post( diff --git a/keep/cli/cli.py b/keep/cli/cli.py index de2ddec83..02f3d4def 100644 --- a/keep/cli/cli.py +++ b/keep/cli/cli.py @@ -15,14 +15,9 @@ from dotenv import find_dotenv, load_dotenv from prettytable import PrettyTable -from keep.api.core.db_on_start import try_create_single_tenant -from keep.api.core.dependencies import SINGLE_TENANT_UUID from keep.api.core.posthog import posthog_client -from keep.cli.click_extensions import NotRequiredIf from keep.providers.models.provider_config import ProviderScope from keep.providers.providers_factory import ProvidersFactory -from keep.workflowmanager.workflowmanager import WorkflowManager -from keep.workflowmanager.workflowstore import WorkflowStore load_dotenv(find_dotenv()) @@ -349,120 +344,6 @@ def api(multi_tenant: bool, port: int, host: str): api.run(app) -@cli.command() -@click.option( - "--alerts-directory", - "--alerts-file", - "-af", - type=click.Path(exists=True, dir_okay=True, file_okay=True), - help="The path to the alert yaml/alerts directory", -) -@click.option( - "--alert-url", - "-au", - help="A url that can be used to download an alert yaml", - cls=NotRequiredIf, - multiple=True, - not_required_if="alerts_directory", -) -@click.option( - "--interval", - "-i", - type=int, - help="When interval is set, Keep will run the alert every INTERVAL seconds", - required=False, - default=0, -) -@click.option( - "--providers-file", - "-p", - type=click.Path(exists=False), - help="The path to the providers yaml", - required=False, - default="providers.yaml", -) -@click.option( - "--tenant-id", - "-t", - help="The tenant id", - required=False, - default=SINGLE_TENANT_UUID, -) -@click.option("--api-key", help="The API key for keep's API", required=False) -@click.option( - "--api-url", - help="The URL for keep's API", - required=False, - default="https://s.keephq.dev", -) -@pass_info -def run( - info: Info, - alerts_directory: str, - alert_url: list[str], - interval: int, - providers_file, - tenant_id, - api_key, - api_url, -): - """Run a workflow.""" - logger.debug(f"Running alert in {alerts_directory or alert_url}") - posthog_client.capture( - info.random_user_id, - "keep-run-alert-started", - properties={ - "args": sys.argv, - "keep_version": KEEP_VERSION, - }, - ) - # this should be fixed - workflow_manager = WorkflowManager.get_instance() - workflow_store = WorkflowStore() - if tenant_id == SINGLE_TENANT_UUID: - try_create_single_tenant(SINGLE_TENANT_UUID) - workflows = workflow_store.get_workflows_from_path( - tenant_id, alerts_directory or alert_url, providers_file - ) - try: - workflow_manager.run(workflows) - except KeyboardInterrupt: - logger.info("Keep stopped by user, stopping the scheduler") - posthog_client.capture( - info.random_user_id, - "keep-run-stopped-by-user", - properties={ - "args": sys.argv, - "keep_version": KEEP_VERSION, - }, - ) - workflow_manager.stop() - logger.info("Scheduler stopped") - except Exception as e: - posthog_client.capture( - info.random_user_id, - "keep-run-unexpected-error", - properties={ - "args": sys.argv, - "error": str(e), - "keep_version": KEEP_VERSION, - }, - ) - logger.error(f"Error running alert {alerts_directory or alert_url}: {e}") - if info.verbose: - raise e - sys.exit(1) - posthog_client.capture( - info.random_user_id, - "keep-run-alert-finished", - properties={ - "args": sys.argv, - "keep_version": KEEP_VERSION, - }, 
- ) - logger.debug(f"Alert in {alerts_directory or alert_url} ran successfully") - - @cli.group() @pass_info def workflow(info: Info): diff --git a/keep/workflowmanager/workflowmanager.py b/keep/workflowmanager/workflowmanager.py index 01d1f9f16..a8f7d0a07 100644 --- a/keep/workflowmanager/workflowmanager.py +++ b/keep/workflowmanager/workflowmanager.py @@ -10,11 +10,12 @@ get_previous_alert_by_fingerprint, save_workflow_results, ) +from keep.api.core.metrics import workflow_execution_duration from keep.api.models.alert import AlertDto, AlertSeverity, IncidentDto from keep.identitymanager.identitymanagerfactory import IdentityManagerTypes from keep.providers.providers_factory import ProviderConfigurationException from keep.workflowmanager.workflow import Workflow -from keep.workflowmanager.workflowscheduler import WorkflowScheduler +from keep.workflowmanager.workflowscheduler import WorkflowScheduler, timing_histogram from keep.workflowmanager.workflowstore import WorkflowStore @@ -33,6 +34,7 @@ def __init__(self): self.debug = config("WORKFLOW_MANAGER_DEBUG", default=False, cast=bool) if self.debug: self.logger.setLevel(logging.DEBUG) + self.scheduler = WorkflowScheduler(self) self.workflow_store = WorkflowStore() self.started = False @@ -42,13 +44,18 @@ async def start(self): if self.started: self.logger.info("Workflow manager already started") return + await self.scheduler.start() self.started = True def stop(self): """Stops the workflow manager""" + if not self.started: + return self.scheduler.stop() self.started = False + # Clear the scheduler reference + self.scheduler = None def _apply_filter(self, filter_val, value): # if it's a regex, apply it @@ -333,37 +340,6 @@ def _get_event_value(self, event, filter_key): else: return getattr(event, filter_key, None) - # TODO should be fixed to support the usual CLI - def run(self, workflows: list[Workflow]): - """ - Run list of workflows. - - Args: - workflow (str): Either an workflow yaml or a directory containing workflow yamls or a list of URLs to get the workflows from. - providers_file (str, optional): The path to the providers yaml. Defaults to None. - """ - self.logger.info("Running workflow(s)") - workflows_errors = [] - # If at least one workflow has an interval, run workflows using the scheduler, - # otherwise, just run it - if any([Workflow.workflow_interval for Workflow in workflows]): - # running workflows in scheduler mode - self.logger.info( - "Found at least one workflow with an interval, running in scheduler mode" - ) - self.scheduler_mode = True - # if the workflows doesn't have an interval, set the default interval - for workflow in workflows: - workflow.workflow_interval = workflow.workflow_interval - # This will halt until KeyboardInterrupt - self.scheduler.run_workflows(workflows) - self.logger.info("Workflow(s) scheduled") - else: - # running workflows in the regular mode - workflows_errors = self._run_workflows_from_cli(workflows) - - return workflows_errors - def _check_premium_providers(self, workflow: Workflow): """ Check if the workflow uses premium providers in multi tenant mode. 
@@ -428,6 +404,7 @@ def _run_workflow_on_failure( }, ) + @timing_histogram(workflow_execution_duration) def _run_workflow( self, workflow: Workflow, workflow_execution_id: str, test_run=False ): diff --git a/keep/workflowmanager/workflowscheduler.py b/keep/workflowmanager/workflowscheduler.py index 2fe05e1ed..f8c916781 100644 --- a/keep/workflowmanager/workflowscheduler.py +++ b/keep/workflowmanager/workflowscheduler.py @@ -2,10 +2,10 @@ import hashlib import logging import queue -import threading import time -import typing import uuid +from concurrent.futures import ThreadPoolExecutor +from functools import wraps from threading import Lock from sqlalchemy.exc import IntegrityError @@ -16,6 +16,13 @@ from keep.api.core.db import get_enrichment, get_previous_execution_id from keep.api.core.db import get_workflow as get_workflow_db from keep.api.core.db import get_workflows_that_should_run +from keep.api.core.metrics import ( + workflow_execution_errors_total, + workflow_execution_status, + workflow_executions_total, + workflow_queue_size, + workflows_running, +) from keep.api.models.alert import AlertDto, IncidentDto from keep.api.utils.email_utils import KEEP_EMAILS_ENABLED, EmailTemplates, send_email from keep.providers.providers_factory import ProviderConfigurationException @@ -23,6 +30,7 @@ from keep.workflowmanager.workflowstore import WorkflowStore READ_ONLY_MODE = config("KEEP_READ_ONLY", default="false") == "true" +MAX_WORKERS = config("WORKFLOWS_MAX_WORKERS", default="20") class WorkflowStatus(enum.Enum): @@ -31,12 +39,40 @@ class WorkflowStatus(enum.Enum): PROVIDERS_NOT_CONFIGURED = "providers_not_configured" +def timing_histogram(histogram): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + try: + result = func(*args, **kwargs) + return result + finally: + duration = time.time() - start_time + # Try to get tenant_id and workflow_id from self + try: + tenant_id = args[1].context_manager.tenant_id + except Exception: + tenant_id = "unknown" + try: + workflow_id = args[1].workflow_id + except Exception: + workflow_id = "unknown" + histogram.labels(tenant_id=tenant_id, workflow_id=workflow_id).observe( + duration + ) + + return wrapper + + return decorator + + class WorkflowScheduler: MAX_SIZE_SIGNED_INT = 2147483647 + MAX_WORKERS = config("KEEP_MAX_WORKFLOW_WORKERS", default="20", cast=int) def __init__(self, workflow_manager): self.logger = logging.getLogger(__name__) - self.threads = [] self.workflow_manager = workflow_manager self.workflow_store = WorkflowStore() # all workflows that needs to be run due to alert event @@ -46,14 +82,29 @@ def __init__(self, workflow_manager): self.interval_enabled = ( config("WORKFLOWS_INTERVAL_ENABLED", default="true") == "true" ) + self.executor = ThreadPoolExecutor( + max_workers=self.MAX_WORKERS, + thread_name_prefix="WorkflowScheduler", + ) + self.scheduler_future = None + self.futures = set() + # Initialize metrics for queue size + self._update_queue_metrics() + + def _update_queue_metrics(self): + """Update queue size metrics""" + with self.lock: + for workflow in self.workflows_to_run: + tenant_id = workflow.get("tenant_id", "unknown") + workflow_queue_size.labels(tenant_id=tenant_id).set( + len(self.workflows_to_run) + ) async def start(self): self.logger.info("Starting workflows scheduler") # Shahar: fix for a bug in unit tests self._stop = False - thread = threading.Thread(target=self._start) - thread.start() - self.threads.append(thread) + self.scheduler_future = 
self.executor.submit(self._start) self.logger.info("Workflows scheduler started") def _handle_interval_workflows(self): @@ -70,13 +121,12 @@ def _handle_interval_workflows(self): self.logger.exception("Error getting workflows that should run") pass for workflow in workflows: - self.logger.debug("Running workflow on background") - workflow_execution_id = workflow.get("workflow_execution_id") tenant_id = workflow.get("tenant_id") workflow_id = workflow.get("workflow_id") + try: - workflow = self.workflow_store.get_workflow(tenant_id, workflow_id) + workflow_obj = self.workflow_store.get_workflow(tenant_id, workflow_id) except ProviderConfigurationException: self.logger.exception( "Provider configuration is invalid", @@ -91,7 +141,7 @@ def _handle_interval_workflows(self): workflow_id=workflow_id, workflow_execution_id=workflow_execution_id, status=WorkflowStatus.PROVIDERS_NOT_CONFIGURED, - error=f"Providers are not configured for workflow {workflow_id}, please configure it so Keep will be able to run it", + error=f"Providers are not configured for workflow {workflow_id}", ) continue except Exception as e: @@ -104,12 +154,16 @@ def _handle_interval_workflows(self): error=f"Error getting workflow: {e}", ) continue - thread = threading.Thread( - target=self._run_workflow, - args=[tenant_id, workflow_id, workflow, workflow_execution_id], + + future = self.executor.submit( + self._run_workflow, + tenant_id, + workflow_id, + workflow_obj, + workflow_execution_id, ) - thread.start() - self.threads.append(thread) + self.futures.add(future) + future.add_done_callback(lambda f: self.futures.remove(f)) def _run_workflow( self, @@ -120,23 +174,52 @@ def _run_workflow( event_context=None, ): if READ_ONLY_MODE: - # This is because sometimes workflows takes 0 seconds and the executions chart is not updated properly. self.logger.debug("Sleeping for 3 seconds in favor of read only mode") time.sleep(3) + self.logger.info(f"Running workflow {workflow.workflow_id}...") try: + # Increment running workflows counter + workflows_running.labels(tenant_id=tenant_id).inc() + + # Track execution + # Shahar: currently incident doesn't have trigger so we will workaround it + if isinstance(event_context, AlertDto): + workflow_executions_total.labels( + tenant_id=tenant_id, + workflow_id=workflow_id, + trigger_type=event_context.trigger if event_context else "interval", + ).inc() + else: + # TODO: add trigger to incident + workflow_executions_total.labels( + tenant_id=tenant_id, + workflow_id=workflow_id, + trigger_type="incident", + ).inc() + + # Run the workflow if isinstance(event_context, AlertDto): - # set the event context, e.g. the event that triggered the workflow workflow.context_manager.set_event_context(event_context) else: - # set the incident context, e.g. 
the incident that triggered the workflow workflow.context_manager.set_incident_context(event_context) errors, _ = self.workflow_manager._run_workflow( workflow, workflow_execution_id ) except Exception as e: + # Track error metrics + workflow_execution_errors_total.labels( + tenant_id=tenant_id, + workflow_id=workflow_id, + error_type=type(e).__name__, + ).inc() + + workflow_execution_status.labels( + tenant_id=tenant_id, workflow_id=workflow_id, status="error" + ).inc() + self.logger.exception(f"Failed to run workflow {workflow.workflow_id}...") self._finish_workflow_execution( tenant_id=tenant_id, @@ -146,6 +229,10 @@ def _run_workflow( error=str(e), ) return + finally: + # Decrement running workflows counter + workflows_running.labels(tenant_id=tenant_id).dec() + self._update_queue_metrics() if any(errors): self.logger.info(msg=f"Workflow {workflow.workflow_id} ran with errors") @@ -164,10 +251,10 @@ def _run_workflow( status=WorkflowStatus.SUCCESS, error=None, ) + self.logger.info(f"Workflow {workflow.workflow_id} ran") def handle_workflow_test(self, workflow, tenant_id, triggered_by_user): - workflow_execution_id = self._get_unique_execution_number() self.logger.info( @@ -195,31 +282,35 @@ def run_workflow_wrapper( print(f"Exception in workflow: {e}") result_queue.put((str(e), None)) - thread = threading.Thread( - target=run_workflow_wrapper, - args=[ - self.workflow_manager._run_workflow, - workflow, - workflow_execution_id, - True, - result_queue, - ], + future = self.executor.submit( + run_workflow_wrapper, + self.workflow_manager._run_workflow, + workflow, + workflow_execution_id, + True, + result_queue, ) - thread.start() - thread.join() + future.result() # Wait for completion errors, results = result_queue.get() - self.logger.info( - f"Workflow {workflow.workflow_id} ran", - extra={"errors": errors, "results": results}, - ) - status = "success" error = None if any(errors): error = "\n".join(str(e) for e in errors) status = "error" + self.logger.info( + "Workflow test complete", + extra={ + "workflow_id": workflow.workflow_id, + "workflow_execution_id": workflow_execution_id, + "tenant_id": tenant_id, + "status": status, + "error": error, + "results": results, + }, + ) + return { "workflow_execution_id": workflow_execution_id, "status": status, @@ -320,6 +411,10 @@ def _handle_event_workflows(self): workflow = workflow_to_run.get("workflow") workflow_id = workflow_to_run.get("workflow_id") tenant_id = workflow_to_run.get("tenant_id") + # Update queue size metrics + workflow_queue_size.labels(tenant_id=tenant_id).set( + len(self.workflows_to_run) + ) workflow_execution_id = workflow_to_run.get("workflow_execution_id") if not workflow: self.logger.info("Loading workflow") @@ -499,18 +594,30 @@ def _handle_event_workflows(self): ) continue # Last, run the workflow - thread = threading.Thread( - target=self._run_workflow, - args=[tenant_id, workflow_id, workflow, workflow_execution_id, event], + future = self.executor.submit( + self._run_workflow, + tenant_id, + workflow_id, + workflow, + workflow_execution_id, + event, ) - thread.start() - self.threads.append(thread) + self.futures.add(future) + future.add_done_callback(lambda f: self.futures.remove(f)) + + self.logger.info( + "Event workflows handled", + extra={"current_number_of_workflows": len(self.futures)}, + ) def _start(self): self.logger.info("Starting workflows scheduler") while not self._stop: # get all workflows that should run now - self.logger.debug("Getting workflows that should run...") + self.logger.info( + "Starting 
workflow scheduler iteration", + extra={"current_number_of_workflows": len(self.futures)}, + ) try: self._handle_interval_workflows() self._handle_event_workflows() @@ -523,55 +630,41 @@ def _start(self): time.sleep(1) self.logger.info("Workflows scheduler stopped") - def run_workflows(self, workflows: typing.List[Workflow]): - for workflow in workflows: - thread = threading.Thread( - target=self._run_workflows_with_interval, - args=[workflow], - daemon=True, - ) - thread.start() - self.threads.append(thread) - # as long as the stop flag is not set, sleep - while not self._stop: - time.sleep(1) - def stop(self): self.logger.info("Stopping scheduled workflows") self._stop = True - # Now wait for the threads to finish - for thread in self.threads: - thread.join() - self.logger.info("Scheduled workflows stopped") - def _run_workflows_with_interval( - self, - workflow: Workflow, - ): - """Simple scheduling of workflows with interval + # Wait for scheduler to stop first + if self.scheduler_future: + try: + self.scheduler_future.result( + timeout=5 + ) # Add timeout to prevent hanging + except Exception: + self.logger.exception("Error waiting for scheduler to stop") - TODO: Use https://github.com/agronholm/apscheduler + # Cancel all running workflows with timeout + for future in list(self.futures): # Create a copy of futures set + try: + self.logger.info("Cancelling future") + future.cancel() + future.result(timeout=1) # Add timeout + self.logger.info("Future cancelled") + except Exception: + self.logger.exception("Error cancelling future") - Args: - workflow (Workflow): The workflow to run. - """ - while True and not self._stop: - self.logger.info(f"Running workflow {workflow.workflow_id}...") + # Shutdown the executor with timeout + if self.executor: try: - self.workflow_manager._run_workflow(workflow, uuid.uuid4()) + self.logger.info("Shutting down executor") + self.executor.shutdown(wait=True, cancel_futures=True) + self.executor = None + self.logger.info("Executor shut down") except Exception: - self.logger.exception( - f"Failed to run workflow {workflow.workflow_id}..." - ) - self.logger.info(f"Workflow {workflow.workflow_id} ran") - if workflow.workflow_interval > 0: - self.logger.info( - f"Sleeping for {workflow.workflow_interval} seconds..." 
- ) - time.sleep(workflow.workflow_interval) - else: - self.logger.info("Workflow will not run again") - break + self.logger.exception("Error shutting down executor") + + self.futures.clear() + self.logger.info("Scheduled workflows stopped") def _finish_workflow_execution( self, diff --git a/poetry.lock b/poetry.lock index 79d159014..56d053477 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1338,22 +1338,23 @@ testing = ["hatch", "pre-commit", "pytest", "tox"] [[package]] name = "fastapi" -version = "0.109.2" +version = "0.115.6" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"}, - {file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"}, + {file = "fastapi-0.115.6-py3-none-any.whl", hash = "sha256:e9240b29e36fa8f4bb7290316988e90c381e5092e0cbe84e7818cc3713bcf305"}, + {file = "fastapi-0.115.6.tar.gz", hash = "sha256:9ec46f7addc14ea472958a96aae5b5de65f39721a46aaf5705c480d9a8b76654"}, ] [package.dependencies] pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.36.3,<0.37.0" +starlette = ">=0.40.0,<0.42.0" typing-extensions = ">=4.8.0" [package.extras] -all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "filelock" @@ -2738,7 +2739,7 @@ name = "ndg-httpsclient" version = "0.5.1" description = "Provides enhanced HTTPS support for httplib and urllib2 using PyOpenSSL" optional = false -python-versions = ">=2.7,<3.0.dev0 || >=3.4.dev0" +python-versions = ">=2.7,<3.0.0 || >=3.4.0" files = [ {file = "ndg_httpsclient-0.5.1-py2-none-any.whl", hash = "sha256:d2c7225f6a1c6cf698af4ebc962da70178a99bcde24ee6d1961c4f3338130d57"}, {file = "ndg_httpsclient-0.5.1-py3-none-any.whl", hash = "sha256:dd174c11d971b6244a891f7be2b32ca9853d3797a72edb34fa5d7b07d8fff7d4"}, @@ -4981,13 +4982,13 @@ files = [ [[package]] name = "starlette" -version = "0.36.3" +version = "0.41.3" description = "The little ASGI library that shines." 
optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"}, - {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"}, + {file = "starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7"}, + {file = "starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835"}, ] [package.dependencies] @@ -5463,4 +5464,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "d1ecb84ec2278190d29b2131ef67b077971af74f076c0b4055c475073f36ad10" +content-hash = "089d3d061da28029a73cbe1b566b3d6ef531145407b322934821b1003ff9681d" diff --git a/pyproject.toml b/pyproject.toml index 0fef9c044..c69b12a7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "keep" -version = "0.33.0" +version = "0.33.1" description = "Alerting. for developers, by developers." authors = ["Keep Alerting LTD"] packages = [{include = "keep"}] @@ -24,7 +24,7 @@ python-json-logger = "^2.0.6" boto3 = "^1.26.72" validators = "^0.20.0" python-telegram-bot = "^20.1" -fastapi = "^0.109.1" +fastapi = "^0.115.6" uvicorn = "0.32.1" opsgenie-sdk = "^2.1.5" starlette-context = "^0.3.6" diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py index c9ae1cc53..0272f30cd 100644 --- a/tests/test_workflow_execution.py +++ b/tests/test_workflow_execution.py @@ -79,15 +79,26 @@ @pytest.fixture(scope="module") def workflow_manager(): """ - Fixture to create and manage a WorkflowManager instance for the duration of the module. - It starts the manager asynchronously and stops it after all tests are completed. + Fixture to create and manage a WorkflowManager instance. 
""" - manager = WorkflowManager.get_instance() - asyncio.run(manager.start()) - while not manager.started: - time.sleep(0.1) - yield manager - manager.stop() + manager = None + try: + from keep.workflowmanager.workflowscheduler import WorkflowScheduler + + scheduler = WorkflowScheduler(None) + manager = WorkflowManager.get_instance() + scheduler.workflow_manager = manager + manager.scheduler = scheduler + asyncio.run(manager.start()) + yield manager + finally: + if manager: + try: + manager.stop() + # Give some time for threads to clean up + time.sleep(1) + except Exception as e: + print(f"Error stopping workflow manager: {e}") @pytest.fixture @@ -690,7 +701,11 @@ def test_workflow_execution_with_disabled_workflow( count = 0 while ( - enabled_workflow_execution is None and disabled_workflow_execution is None + ( + enabled_workflow_execution is None + or enabled_workflow_execution.status == "in_progress" + ) + and disabled_workflow_execution is None ) and count < 30: enabled_workflow_execution = get_last_workflow_execution_by_workflow_id( SINGLE_TENANT_UUID, enabled_id diff --git a/tests/test_workflowmanager.py b/tests/test_workflowmanager.py index 2b4868146..eb43abe15 100644 --- a/tests/test_workflowmanager.py +++ b/tests/test_workflowmanager.py @@ -1,16 +1,17 @@ -import pytest +import queue +from pathlib import Path from unittest.mock import Mock, patch + +import pytest from fastapi import HTTPException -import threading -import queue + +from keep.parser.parser import Parser # Assuming WorkflowParser is the class containing the get_workflow_from_dict method from keep.workflowmanager.workflow import Workflow -from keep.workflowmanager.workflowstore import WorkflowStore from keep.workflowmanager.workflowmanager import WorkflowManager from keep.workflowmanager.workflowscheduler import WorkflowScheduler -from keep.parser.parser import Parser -from pathlib import Path +from keep.workflowmanager.workflowstore import WorkflowStore path_to_test_resources = Path(__file__).parent / "workflows" @@ -109,29 +110,27 @@ def test_handle_workflow_test(): tenant_id = "test_tenant" triggered_by_user = "test_user" - with patch.object(threading, "Thread", wraps=threading.Thread) as mock_thread: - with patch.object(queue, "Queue", wraps=queue.Queue) as mock_queue: - result = workflow_scheduler.handle_workflow_test( - workflow=mock_workflow, - tenant_id=tenant_id, - triggered_by_user=triggered_by_user, - ) + with patch.object(queue, "Queue", wraps=queue.Queue) as mock_queue: + result = workflow_scheduler.handle_workflow_test( + workflow=mock_workflow, + tenant_id=tenant_id, + triggered_by_user=triggered_by_user, + ) - mock_workflow_manager._run_workflow.assert_called_once_with( - mock_workflow, 123, True - ) + mock_workflow_manager._run_workflow.assert_called_once_with( + mock_workflow, 123, True + ) - assert mock_thread.call_count == 1 - assert mock_queue.call_count == 1 + assert mock_queue.call_count == 1 - expected_result = { - "workflow_execution_id": 123, - "status": "success", - "error": None, - "results": {"result": "value1"}, - } - assert result == expected_result - assert mock_logger.info.call_count == 2 + expected_result = { + "workflow_execution_id": 123, + "status": "success", + "error": None, + "results": {"result": "value1"}, + } + assert result == expected_result + assert mock_logger.info.call_count == 2 def test_handle_workflow_test_with_error(): @@ -152,26 +151,24 @@ def test_handle_workflow_test_with_error(): tenant_id = "test_tenant" triggered_by_user = "test_user" - with 
patch.object(threading, "Thread", wraps=threading.Thread) as mock_thread: - with patch.object(queue, "Queue", wraps=queue.Queue) as mock_queue: - result = workflow_scheduler.handle_workflow_test( - workflow=mock_workflow, - tenant_id=tenant_id, - triggered_by_user=triggered_by_user, - ) - - mock_workflow_manager._run_workflow.assert_called_once_with( - mock_workflow, 123, True - ) - - assert mock_thread.call_count == 1 - assert mock_queue.call_count == 1 - - expected_result = { - "workflow_execution_id": 123, - "status": "error", - "error": "Error occurred", - "results": None, - } - assert result == expected_result - assert mock_logger.info.call_count == 2 + with patch.object(queue, "Queue", wraps=queue.Queue) as mock_queue: + result = workflow_scheduler.handle_workflow_test( + workflow=mock_workflow, + tenant_id=tenant_id, + triggered_by_user=triggered_by_user, + ) + + mock_workflow_manager._run_workflow.assert_called_once_with( + mock_workflow, 123, True + ) + + assert mock_queue.call_count == 1 + + expected_result = { + "workflow_execution_id": 123, + "status": "error", + "error": "Error occurred", + "results": None, + } + assert result == expected_result + assert mock_logger.info.call_count == 2
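
Reviewer note, for reference only (not part of the patch): the memory and performance gains here come mainly from replacing unbounded thread creation with bounded pools. keep/api/routes/alerts.py now submits event processing to a dedicated ThreadPoolExecutor sized by KEEP_EVENT_WORKERS (default 50) and tracks completion through future done-callbacks, and keep/workflowmanager/workflowscheduler.py does the same for workflow runs with WORKFLOWS_MAX_WORKERS / KEEP_MAX_WORKFLOW_WORKERS (default 20), so concurrency — and therefore memory — is capped instead of growing with load.

The sketch below is a minimal, self-contained illustration of how the timing_histogram decorator added in workflowscheduler.py feeds the new keep_workflows_execution_duration_seconds histogram from keep/api/core/metrics.py. Only the decorator body and the histogram definition mirror the diff above; DummyContextManager, DummyWorkflow and DummyManager are hypothetical stand-ins for Keep's real ContextManager, Workflow and WorkflowManager objects, and the example uses the default prometheus_client registry rather than the multiprocess setup configured in metrics.py.

import time
from functools import wraps

from prometheus_client import Histogram

# Mirrors workflow_execution_duration from keep/api/core/metrics.py
workflow_execution_duration = Histogram(
    "keep_workflows_execution_duration_seconds",
    "Time spent executing workflows",
    labelnames=["tenant_id", "workflow_id"],
    buckets=(1, 5, 10, 30, 60, 120, 300, 600),
)


def timing_histogram(histogram):
    # Same logic as the decorator added in workflowscheduler.py:
    # time the wrapped call and record the duration labeled by tenant/workflow.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                duration = time.time() - start_time
                # args[0] is the manager instance, args[1] the workflow object
                try:
                    tenant_id = args[1].context_manager.tenant_id
                except Exception:
                    tenant_id = "unknown"
                try:
                    workflow_id = args[1].workflow_id
                except Exception:
                    workflow_id = "unknown"
                histogram.labels(
                    tenant_id=tenant_id, workflow_id=workflow_id
                ).observe(duration)

        return wrapper

    return decorator


# Illustrative stand-ins (hypothetical, not part of the patch)
class DummyContextManager:
    tenant_id = "tenant-a"


class DummyWorkflow:
    workflow_id = "wf-1"
    context_manager = DummyContextManager()


class DummyManager:
    @timing_histogram(workflow_execution_duration)
    def _run_workflow(self, workflow, workflow_execution_id):
        time.sleep(0.1)  # simulate workflow work
        return [], None


DummyManager()._run_workflow(DummyWorkflow(), "exec-1")

Running the example records one sample labeled tenant_id="tenant-a", workflow_id="wf-1" in the histogram, which is how the patched scheduler exposes per-tenant, per-workflow execution latency alongside the new execution, error and queue-size metrics.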