Skip to content

Commit

Permalink
watch mode changes (squash this commit when merging the PR!)
Browse files Browse the repository at this point in the history
  • Loading branch information
tumidi committed Aug 1, 2024
1 parent d9b6c14 commit da764d7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 67 deletions.
84 changes: 42 additions & 42 deletions questionpy_sdk/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager
from functools import cached_property
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Self
Expand All @@ -30,7 +29,7 @@
from questionpy_server.worker.runtime.package_location import DirPackageLocation

if TYPE_CHECKING:
from watchdog.observers.api import BaseObserver
from watchdog.observers.api import BaseObserver, ObservedWatch

log = logging.getLogger("questionpy-sdk:watcher")

Expand Down Expand Up @@ -99,61 +98,63 @@ def __init__(self, source_path: Path, pkg_location: DirPackageLocation, host: st
self._event_handler: _EventHandler | None
self._observer: BaseObserver | None
self._webserver: WebServer | None = None
self._on_change_condition = asyncio.Condition()
self._on_change_event = asyncio.Event()
self._watch: ObservedWatch | None = None

async def __aenter__(self) -> Self:
self._loop = asyncio.get_running_loop()

self.start()
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
self.stop_watching()

def start(self) -> None:
self.start_watching()
log.info("Watching '%s' for changes...", self._source_path)

def start_watching(self) -> None:
log.debug("Starting file watching...")

self._event_handler = _EventHandler(self._loop, self._notify, self._source_path)
self._observer = Observer()
self._event_handler.start()
self._observer.schedule(self._event_handler, self._source_path, recursive=True)
self._observer.start()
log.info("Watching '%s' for changes...", self._source_path)

def stop_watching(self) -> None:
log.debug("Stopping file watching...")
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
if self._observer and self._observer.is_alive():
self._observer.stop()
if self._event_handler:
self._event_handler.stop()
if self._webserver:
await self._webserver.stop_server()

def _schedule(self) -> None:
if self._observer and self._watch is None:
log.debug("Starting file watching...")
self._watch = self._observer.schedule(self._event_handler, self._source_path, recursive=True)

def _unschedule(self) -> None:
if self._watch and self._observer:
log.debug("Stopping file watching...")
self._observer.unschedule(self._watch)
self._watch = None

async def _notify(self) -> None:
async with self._on_change_condition:
self._on_change_condition.notify()
self._on_change_event.set()

async def run_forever(self) -> None:
try:
await self._start_webserver()
self._webserver = WebServer(self._pkg_location, host=self._host, port=self._port)
await self._webserver.start_server()
except Exception:
log.exception("Failed to start webserver. The exception was:")
# When user messed up the their package on initial run, we just bail out.
return

self._schedule()

while True:
async with self._on_change_condition:
# Wait for file changes.
await self._on_change_condition.wait()
await self._on_change_event.wait()

# Try to rebuild package and restart web server which might fail.
self._unschedule()
await self._rebuild_and_restart()
self._schedule()

# Try to rebuild package and restart web server which might fail.
self.stop_watching()
await self._rebuild_and_restart()
self.start_watching()
self._on_change_event.clear()

async def _rebuild_and_restart(self) -> None:
if self._webserver:
Expand All @@ -166,9 +167,17 @@ async def _rebuild_and_restart(self) -> None:
log.exception("Failed to stop web server. The exception was:")
raise # Should not happen, thus we're propagating.

# Determine module name.
try:
pkg_config = PackageSource(self._source_path).config
pkg_module_name = f"{pkg_config.namespace}.{pkg_config.short_name}"
except Exception:
log.exception("Failed to build package. The exception was:")
return

# Remove package modules from Python.
for name in sys.modules.copy():
if name.startswith(self._pkg_module_name):
if name == pkg_module_name or name.startswith(f"{pkg_module_name}."):
del sys.modules[name]

# Build package.
Expand All @@ -185,12 +194,3 @@ async def _rebuild_and_restart(self) -> None:
await self._webserver.start_server()
except Exception:
log.exception("Failed to start web server. The exception was:")

async def _start_webserver(self) -> None:
self._webserver = WebServer(self._pkg_location, host=self._host, port=self._port)
await self._webserver.start_server()

@cached_property
def _pkg_module_name(self) -> str:
config = PackageSource(self._source_path).config
return f"{config.namespace}.{config.short_name}"
61 changes: 36 additions & 25 deletions questionpy_sdk/webserver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,45 +68,30 @@ def __init__(
self._host = host
self._port = port

self._web_app: web.Application | None = None
self._runner: web.AppRunner | None = None
self.worker_pool: WorkerPool = WorkerPool(1, 500 * MiB, worker_type=ThreadWorker)

async def start_server(self) -> None:
self.create_webapp()
self._runner = web.AppRunner(self.web_app)
if self._web_app:
msg = "Web app is already running"
raise RuntimeError(msg)

self._web_app = self._create_webapp()
self._runner = web.AppRunner(self._web_app)
await self._runner.setup()
await web.TCPSite(self._runner, self._host, self._port).start()

async def stop_server(self) -> None:
if self._runner:
await self._runner.cleanup()
self._web_app = None
self._runner = None

async def run_forever(self) -> None:
await self.start_server()
await asyncio.Event().wait() # run forever

def create_webapp(self) -> web.Application:
# We import here, so we don't have to work around circular imports.
from questionpy_sdk.webserver.routes.attempt import routes as attempt_routes # noqa: PLC0415
from questionpy_sdk.webserver.routes.options import routes as options_routes # noqa: PLC0415
from questionpy_sdk.webserver.routes.worker import routes as worker_routes # noqa: PLC0415

self.web_app = web.Application()
self.web_app[SDK_WEBSERVER_APP_KEY] = self

self.web_app.add_routes(attempt_routes)
self.web_app.add_routes(options_routes)
self.web_app.add_routes(worker_routes)
self.web_app.router.add_static("/static", Path(__file__).parent / "static")

self.web_app.on_startup.append(_extract_manifest)
self.web_app.middlewares.append(_invalid_question_state_middleware)

jinja2_extensions = ["jinja2.ext.do"]
aiohttp_jinja2.setup(self.web_app, loader=PackageLoader(__package__), extensions=jinja2_extensions)

return self.web_app

def read_state_file(self, filename: StateFilename) -> str | None:
try:
return (self._package_state_dir / filename).read_text()
Expand All @@ -124,9 +109,35 @@ def delete_state_files(self, filename_1: StateFilename, *filenames: StateFilenam
# Remove package state dir if it's now empty.
self._package_state_dir.rmdir()

def _create_webapp(self) -> web.Application:
# We import here, so we don't have to work around circular imports.
from questionpy_sdk.webserver.routes.attempt import routes as attempt_routes # noqa: PLC0415
from questionpy_sdk.webserver.routes.options import routes as options_routes # noqa: PLC0415
from questionpy_sdk.webserver.routes.worker import routes as worker_routes # noqa: PLC0415

app = web.Application()
app[SDK_WEBSERVER_APP_KEY] = self

app.add_routes(attempt_routes)
app.add_routes(options_routes)
app.add_routes(worker_routes)
app.router.add_static("/static", Path(__file__).parent / "static")

app.on_startup.append(_extract_manifest)
app.middlewares.append(_invalid_question_state_middleware)

jinja2_extensions = ["jinja2.ext.do"]
aiohttp_jinja2.setup(app, loader=PackageLoader(__package__), extensions=jinja2_extensions)

return app

@cached_property
def _package_state_dir(self) -> Path:
manifest = self.web_app[MANIFEST_APP_KEY]
if self._web_app is None:
msg = "Web app not initialized"
raise RuntimeError(msg)

manifest = self._web_app[MANIFEST_APP_KEY]
return self._state_storage_root / f"{manifest.namespace}-{manifest.short_name}-{manifest.version}"


Expand Down

0 comments on commit da764d7

Please sign in to comment.