diff --git a/openadapt/a11y/__init__.py b/openadapt/a11y/__init__.py new file mode 100644 index 000000000..56cb8e23f --- /dev/null +++ b/openadapt/a11y/__init__.py @@ -0,0 +1,49 @@ +"""This module provides platform-specific implementations for window and element + interactions using accessibility APIs. It abstracts the platform differences + and provides a unified interface for retrieving the active window, finding + display elements, and getting element values. +""" + +import sys + +from loguru import logger + +if sys.platform == "darwin": + from . import _macos as impl + + role = "AXStaticText" +elif sys.platform in ("win32", "linux"): + from . import _windows as impl + + role = "Text" +else: + raise Exception(f"Unsupported platform: {sys.platform}") + + +def get_active_window(): + """Get the active window object. + + Returns: + The active window object. + """ + try: + return impl.get_active_window() + except Exception as exc: + logger.warning(f"{exc=}") + return None + + +def get_element_value(active_window, role=role): + """Find the display of active_window. + + Args: + active_window: The parent window to search within. + + Returns: + The found active_window. + """ + try: + return impl.get_element_value(active_window, role) + except Exception as exc: + logger.warning(f"{exc=}") + return None diff --git a/openadapt/a11y/_macos.py b/openadapt/a11y/_macos.py new file mode 100644 index 000000000..5a6bade71 --- /dev/null +++ b/openadapt/a11y/_macos.py @@ -0,0 +1,61 @@ +import AppKit +import ApplicationServices + + +def get_attribute(element, attribute): + result, value = ApplicationServices.AXUIElementCopyAttributeValue( + element, attribute, None + ) + if result == 0: + return value + return None + + +def find_element_by_attribute(element, attribute, value): + if get_attribute(element, attribute) == value: + return element + children = get_attribute(element, ApplicationServices.kAXChildrenAttribute) or [] + for child in children: + found = find_element_by_attribute(child, attribute, value) + if found: + return found + return None + + +def get_active_window(): + """Get the active window object. + + Returns: + AXUIElement: The active window object. + """ + workspace = AppKit.NSWorkspace.sharedWorkspace() + active_app = workspace.frontmostApplication() + app_element = ApplicationServices.AXUIElementCreateApplication( + active_app.processIdentifier() + ) + + error_code, focused_window = ApplicationServices.AXUIElementCopyAttributeValue( + app_element, ApplicationServices.kAXFocusedWindowAttribute, None + ) + if error_code: + raise Exception("Could not get the active window.") + return focused_window + + +def get_element_value(element, role="AXStaticText"): + """Get the value of a specific element . + + Args: + element: The AXUIElement to search within. + + Returns: + str: The value of the element, or an error message if not found. + """ + target_element = find_element_by_attribute( + element, ApplicationServices.kAXRoleAttribute, role + ) + if not target_element: + return f"AXStaticText element not found." + + value = get_attribute(target_element, ApplicationServices.kAXValueAttribute) + return value if value else f"No value for AXStaticText element." diff --git a/openadapt/a11y/_windows.py b/openadapt/a11y/_windows.py new file mode 100644 index 000000000..974e5f3ec --- /dev/null +++ b/openadapt/a11y/_windows.py @@ -0,0 +1,44 @@ +from loguru import logger +import pywinauto +import re + + +def get_active_window() -> pywinauto.application.WindowSpecification: + """Get the active window object. + + Returns: + pywinauto.application.WindowSpecification: The active window object. + """ + app = pywinauto.application.Application(backend="uia").connect(active_only=True) + window = app.top_window() + return window.wrapper_object() + + +def get_element_value(active_window, role="Text"): + """Find the display element. + + Args: + active_window: The parent window to search within. + role (str): The role of the element to search for. + + Returns: + The found display element value. + + Raises: + ValueError: If the element is not found. + """ + try: + elements = active_window.descendants() # Retrieve all descendants + for elem in elements: + if ( + elem.element_info.control_type == role + and elem.element_info.name.startswith("Display is") + ): + # Extract the number from the element's name + match = re.search(r"[-+]?\d*\.?\d+", elem.element_info.name) + if match: + return str(match.group()) + raise ValueError("Display element not found") + except Exception as exc: + logger.warning(f"Error in get_element_value: {exc}") + return None diff --git a/openadapt/app/dashboard/api/recordings.py b/openadapt/app/dashboard/api/recordings.py index 32b13debf..74a8c8252 100644 --- a/openadapt/app/dashboard/api/recordings.py +++ b/openadapt/app/dashboard/api/recordings.py @@ -36,7 +36,7 @@ def attach_routes(self) -> APIRouter: def get_recordings() -> dict[str, list[Recording]]: """Get all recordings.""" session = crud.get_new_session(read_only=True) - recordings = crud.get_all_recordings(session) + recordings = crud.get_recordings(session) return {"recordings": recordings} @staticmethod diff --git a/openadapt/app/tray.py b/openadapt/app/tray.py index 0a6381192..17f7ca45e 100644 --- a/openadapt/app/tray.py +++ b/openadapt/app/tray.py @@ -463,7 +463,7 @@ def populate_menu(self, menu: QMenu, action: Callable, action_type: str) -> None action_type (str): The type of action to perform ["visualize", "replay"] """ session = crud.get_new_session(read_only=True) - recordings = crud.get_all_recordings(session) + recordings = crud.get_recordings(session) self.recording_actions[action_type] = [] diff --git a/openadapt/config.py b/openadapt/config.py index 742aeaa88..6215397dd 100644 --- a/openadapt/config.py +++ b/openadapt/config.py @@ -1,6 +1,5 @@ """Configuration module for OpenAdapt.""" - from enum import Enum from typing import Any, ClassVar, Type, Union import json @@ -33,6 +32,7 @@ CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute() VIDEO_DIR_PATH = DATA_DIR_PATH / "videos" DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock" +DB_FILE_PATH = (DATA_DIR_PATH / "openadapt.db").absolute() STOP_STRS = [ "oa.stop", @@ -124,7 +124,8 @@ class SegmentationAdapter(str, Enum): # Database DB_ECHO: bool = False - DB_URL: ClassVar[str] = f"sqlite:///{(DATA_DIR_PATH / 'openadapt.db').absolute()}" + DB_FILE_PATH: str = str(DB_FILE_PATH) + DB_URL: ClassVar[str] = f"sqlite:///{DB_FILE_PATH}" # Error reporting ERROR_REPORTING_ENABLED: bool = True @@ -428,11 +429,13 @@ def show_alert() -> None: """Show an alert to the user.""" msg = QMessageBox() msg.setIcon(QMessageBox.Warning) - msg.setText(""" + msg.setText( + """ An error has occurred. The development team has been notified. Please join the discord server to get help or send an email to help@openadapt.ai - """) + """ + ) discord_button = QPushButton("Join the discord server") discord_button.clicked.connect( lambda: webbrowser.open("https://discord.gg/yF527cQbDG") diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 1e6bc2649..9728681ae 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -281,22 +281,27 @@ def delete_recording(session: SaSession, recording: Recording) -> None: delete_video_file(recording_timestamp) -def get_all_recordings(session: SaSession) -> list[Recording]: +def get_recordings(session: SaSession, max_rows=None) -> list[Recording]: """Get all recordings. Args: session (sa.orm.Session): The database session. + max_rows: The number of recordings to return, starting from the most recent. + Defaults to all if max_rows is not specified. Returns: list[Recording]: A list of all original recordings. """ - return ( + query = ( session.query(Recording) .filter(Recording.original_recording_id == None) # noqa: E711 .order_by(sa.desc(Recording.timestamp)) - .all() ) + if max_rows: + query = query.limit(max_rows) + return query.all() + def get_all_scrubbed_recordings( session: SaSession, @@ -352,6 +357,21 @@ def get_recording(session: SaSession, timestamp: float) -> Recording: return session.query(Recording).filter(Recording.timestamp == timestamp).first() +def get_recordings_by_desc(session: SaSession, description_str: str) -> list[Recording]: + """Get recordings by task description. + Args: + session (sa.orm.Session): The database session. + task_description (str): The task description to search for. + Returns: + list[Recording]: A list of recordings whose task descriptions contain the given string. + """ + return ( + session.query(Recording) + .filter(Recording.task_description.contains(description_str)) + .all() + ) + + BaseModelType = TypeVar("BaseModelType") diff --git a/openadapt/scripts/generate_db_fixtures.py b/openadapt/scripts/generate_db_fixtures.py new file mode 100644 index 000000000..e3a4da8b9 --- /dev/null +++ b/openadapt/scripts/generate_db_fixtures.py @@ -0,0 +1,151 @@ +from sqlalchemy import create_engine, inspect +from openadapt.db.db import Base +from openadapt.config import PARENT_DIR_PATH, RECORDING_DIR_PATH +import openadapt.db.crud as crud +from loguru import logger + + +def get_session(): + """ + Establishes a database connection and returns a session and engine. + + Returns: + tuple: A tuple containing the SQLAlchemy session and engine. + """ + db_url = RECORDING_DIR_PATH / "recording.db" + logger.info(f"Database URL: {db_url}") + engine = create_engine(f"sqlite:///{db_url}") + Base.metadata.create_all(bind=engine) + session = crud.get_new_session(read_only=True) + logger.info("Database connection established.") + return session, engine + + +def check_tables_exist(engine): + """ + Checks if the expected tables exist in the database. + + Args: + engine: SQLAlchemy engine object. + + Returns: + list: A list of table names in the database. + """ + inspector = inspect(engine) + tables = inspector.get_table_names() + expected_tables = [ + "recording", + "action_event", + "screenshot", + "window_event", + "performance_stat", + "memory_stat", + ] + for table_name in expected_tables: + table_exists = table_name in tables + logger.info(f"{table_name=} {table_exists=}") + return tables + + +def fetch_data(session): + """ + Fetches the most recent recordings and related data from the database. + + Args: + session: SQLAlchemy session object. + + Returns: + dict: A dictionary containing fetched data. + """ + # get the most recent three recordings + recordings = crud.get_recordings(session, max_rows=3) + + action_events = [] + screenshots = [] + window_events = [] + performance_stats = [] + memory_stats = [] + + for recording in recordings: + action_events.extend(crud.get_action_events(session, recording)) + screenshots.extend(crud.get_screenshots(session, recording)) + window_events.extend(crud.get_window_events(session, recording)) + performance_stats.extend(crud.get_perf_stats(session, recording)) + memory_stats.extend(crud.get_memory_stats(session, recording)) + + data = { + "recordings": recordings, + "action_events": action_events, + "screenshots": screenshots, + "window_events": window_events, + "performance_stats": performance_stats, + "memory_stats": memory_stats, + } + + # Debug prints to verify data fetching + logger.info(f"Recordings: {len(data['recordings'])} found.") + logger.info(f"Action Events: {len(data['action_events'])} found.") + logger.info(f"Screenshots: {len(data['screenshots'])} found.") + logger.info(f"Window Events: {len(data['window_events'])} found.") + logger.info(f"Performance Stats: {len(data['performance_stats'])} found.") + logger.info(f"Memory Stats: {len(data['memory_stats'])} found.") + + return data + + +def format_sql_insert(table_name, rows): + """ + Formats SQL insert statements for a given table and rows. + + Args: + table_name (str): The name of the table. + rows (list): A list of SQLAlchemy ORM objects representing the rows. + + Returns: + str: A string containing the SQL insert statements. + """ + if not rows: + return "" + + columns = rows[0].__table__.columns.keys() + sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES\n" + values = [] + + for row in rows: + row_values = [getattr(row, col) for col in columns] + row_values = [ + f"'{value}'" if isinstance(value, str) else str(value) + for value in row_values + ] + values.append(f"({', '.join(row_values)})") + + sql += ",\n".join(values) + ";\n" + return sql + + +def dump_to_fixtures(filepath): + """ + Dumps the fetched data into an SQL file. + + Args: + filepath (str): The path to the SQL file. + """ + session, engine = get_session() + check_tables_exist(engine) + rows_by_table_name = fetch_data(session) + + for table_name, rows in rows_by_table_name.items(): + if not rows: + logger.warning(f"No rows for {table_name=}") + continue + with open(filepath, "a", encoding="utf-8") as file: + logger.info(f"Writing {len(rows)=} to {filepath=} for {table_name=}") + file.write(f"-- Insert sample rows for {table_name}\n") + file.write(format_sql_insert(table_name, rows)) + + +if __name__ == "__main__": + + fixtures_path = PARENT_DIR_PATH / "tests/assets/fixtures.sql" + + dump_to_fixtures(fixtures_path) diff --git a/openadapt/scripts/reset_db.py b/openadapt/scripts/reset_db.py index 8bba91be2..a0e39fb08 100644 --- a/openadapt/scripts/reset_db.py +++ b/openadapt/scripts/reset_db.py @@ -14,8 +14,8 @@ def reset_db() -> None: """Clears the database by removing the db file and running a db migration.""" - if os.path.exists(config.DB_FPATH): - os.remove(config.DB_FPATH) + if os.path.exists(config.DB_FILE_PATH): + os.remove(config.DB_FILE_PATH) # Prevents duplicate logging of config values by piping stderr # and filtering the output. diff --git a/tests/openadapt/test_performance.py b/tests/openadapt/test_performance.py new file mode 100644 index 000000000..ca58abccd --- /dev/null +++ b/tests/openadapt/test_performance.py @@ -0,0 +1,54 @@ +import pytest +from loguru import logger +from openadapt.db.crud import ( + get_recordings_by_desc, + get_new_session, +) +from openadapt.replay import replay +from openadapt.a11y import ( + get_active_window, + get_element_value, +) + + +# parametrized tests +@pytest.mark.parametrize( + "task_description, replay_strategy, expected_value, instructions", + [ + ("test_calculator", "VisualReplayStrategy", "6", " "), + ("test_calculator", "VisualReplayStrategy", "8", "calculate 9-8+7"), + # ("test_spreadsheet", "NaiveReplayStrategy"), + # ("test_powerpoint", "NaiveReplayStrategy") + ], +) +def test_replay(task_description, replay_strategy, expected_value, instructions): + # Get recordings which contain the string "test_calculator" + session = get_new_session(read_only=True) + recordings = get_recordings_by_desc(session, task_description) + + assert ( + len(recordings) > 0 + ), f"No recordings found with task description: {task_description}" + recording = recordings[0] + + result = replay( + strategy_name=replay_strategy, + recording=recording, + instructions=instructions, + ) + assert result is True, f"Replay failed for recording: {recording.id}" + + active_window = get_active_window() + element_value = get_element_value(active_window) + logger.info(element_value) + + assert ( + element_value == expected_value + ), f"Value mismatch: expected '{expected_value}', got '{element_value}'" + + result_message = f"Value match: '{element_value}' == '{expected_value}'" + logger.info(result_message) + + +if __name__ == "__main__": + pytest.main()