diff --git a/api/v1/__init__.py b/api/v1/__init__.py deleted file mode 100644 index 4c23146..0000000 --- a/api/v1/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from fastapi import APIRouter - -import api.v1.countries as countries -import api.v1.node as node -import api.v1.photos as photos -import api.v1.tile as tile - -router = APIRouter(prefix='/api/v1') -router.include_router(countries.router) -router.include_router(node.router) -router.include_router(photos.router) -router.include_router(tile.router) diff --git a/api/v1/countries.py b/api/v1/countries.py index 2725c33..851ceee 100644 --- a/api/v1/countries.py +++ b/api/v1/countries.py @@ -2,20 +2,23 @@ from typing import Annotated from anyio import create_task_group -from fastapi import APIRouter, Path, Response +from fastapi import APIRouter, Path from sentry_sdk import start_span from shapely.geometry import mapping from middlewares.cache_middleware import configure_cache +from middlewares.skip_serialization import skip_serialization from models.country import Country from states.aed_state import AEDState from states.country_state import CountryState +from utils import simple_point_mapping router = APIRouter(prefix='/countries') @router.get('/names') @configure_cache(timedelta(hours=1), stale=timedelta(days=7)) +@skip_serialization() async def get_names(language: str | None = None): countries = await CountryState.get_all_countries() country_count_map: dict[str, int] = {} @@ -55,23 +58,24 @@ def limit_country_names(names: dict[str, str]): @router.get('/{country_code}.geojson') @configure_cache(timedelta(hours=1), stale=timedelta(seconds=0)) -async def get_geojson( - response: Response, - country_code: Annotated[str, Path(min_length=2, max_length=5)], -): +@skip_serialization( + { + 'Content-Disposition': 'attachment', + 'Content-Type': 'application/geo+json; charset=utf-8', + } +) +async def get_geojson(country_code: Annotated[str, Path(min_length=2, max_length=5)]): if country_code == 'WORLD': aeds = await AEDState.get_all_aeds() else: aeds = await AEDState.get_aeds_by_country_code(country_code) - response.headers['Content-Disposition'] = 'attachment' - response.headers['Content-Type'] = 'application/geo+json; charset=utf-8' return { 'type': 'FeatureCollection', 'features': [ { 'type': 'Feature', - 'geometry': mapping(aed.position), + 'geometry': simple_point_mapping(aed.position), 'properties': { '@osm_type': 'node', '@osm_id': aed.id, diff --git a/api/v1/node.py b/api/v1/node.py index a28feaf..c47f16e 100644 --- a/api/v1/node.py +++ b/api/v1/node.py @@ -8,6 +8,7 @@ from tzfpy import get_tz from middlewares.cache_middleware import configure_cache +from middlewares.skip_serialization import skip_serialization from states.aed_state import AEDState from states.photo_state import PhotoState from utils import get_wikimedia_commons_url @@ -72,6 +73,7 @@ async def _get_image_data(tags: dict[str, str]) -> dict: @router.get('/node/{node_id}') @configure_cache(timedelta(minutes=1), stale=timedelta(minutes=5)) +@skip_serialization() async def get_node(node_id: int): aed = await AEDState.get_aed_by_id(node_id) diff --git a/api/v1/photos.py b/api/v1/photos.py index 309f8d7..d542e36 100644 --- a/api/v1/photos.py +++ b/api/v1/photos.py @@ -51,7 +51,7 @@ async def _fetch_image(url: str) -> tuple[bytes, str]: @router.get('/view/{id}.webp') @configure_cache(timedelta(days=365), stale=timedelta(days=365)) -async def view(id: str) -> FileResponse: +async def view(id: str): info = await PhotoState.get_photo_by_id(id) if info is None: @@ -62,7 +62,7 @@ async def view(id: str) -> FileResponse: @router.get('/proxy/direct/{url_encoded:path}') @configure_cache(timedelta(days=7), stale=timedelta(days=7)) -async def proxy_direct(url_encoded: str) -> FileResponse: +async def proxy_direct(url_encoded: str): url = unquote_plus(url_encoded) file, content_type = await _fetch_image(url) return Response(file, media_type=content_type) @@ -70,7 +70,7 @@ async def proxy_direct(url_encoded: str) -> FileResponse: @router.get('/proxy/wikimedia-commons/{path_encoded:path}') @configure_cache(timedelta(days=7), stale=timedelta(days=7)) -async def proxy_wikimedia_commons(path_encoded: str) -> FileResponse: +async def proxy_wikimedia_commons(path_encoded: str): meta_url = get_wikimedia_commons_url(unquote_plus(path_encoded)) async with get_http_client() as http: @@ -145,12 +145,12 @@ async def upload( @router.post('/report') -async def report(id: Annotated[str, Form()]) -> bool: +async def report(id: Annotated[str, Form()]): return await PhotoReportState.report_by_photo_id(id) @router.get('/report/rss.xml') -async def report_rss(request: Request) -> Response: +async def report_rss(request: Request): fg = FeedGenerator() fg.title('AED Photo Reports') fg.description('This feed contains a list of recent AED photo reports') diff --git a/config.py b/config.py index 9439c04..a2b021c 100644 --- a/config.py +++ b/config.py @@ -12,7 +12,7 @@ from sentry_sdk.integrations.pymongo import PyMongoIntegration NAME = 'openaedmap-backend' -VERSION = '2.7.5' +VERSION = '2.7.6' CREATED_BY = f'{NAME} {VERSION}' WEBSITE = 'https://openaedmap.org' diff --git a/main.py b/main.py index 88f12b5..6e7602b 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,17 @@ +import importlib +import logging +import pathlib from contextlib import asynccontextmanager from datetime import timedelta from anyio import create_task_group -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.base import BaseHTTPMiddleware -import api.v1 as v1 from config import DEFAULT_CACHE_MAX_AGE, DEFAULT_CACHE_STALE, startup_setup from middlewares.cache_middleware import CacheMiddleware -from middlewares.profiler_middleware import profiler_middleware -from middlewares.version_middleware import version_middleware +from middlewares.profiler_middleware import ProfilerMiddleware +from middlewares.version_middleware import VersionMiddleware from orjson_response import CustomORJSONResponse from states.aed_state import AEDState from states.country_state import CountryState @@ -39,9 +40,8 @@ async def lifespan(_): app = FastAPI(lifespan=lifespan, default_response_class=CustomORJSONResponse) -app.include_router(v1.router) -app.add_middleware(BaseHTTPMiddleware, dispatch=profiler_middleware) -app.add_middleware(BaseHTTPMiddleware, dispatch=version_middleware) +app.add_middleware(ProfilerMiddleware) +app.add_middleware(VersionMiddleware) app.add_middleware( CacheMiddleware, max_age=DEFAULT_CACHE_MAX_AGE, @@ -54,3 +54,28 @@ async def lifespan(_): allow_methods=['GET'], max_age=int(timedelta(days=1).total_seconds()), ) + + +def _make_router(path: str, prefix: str) -> APIRouter: + """ + Create a router from all modules in the given path. + """ + router = APIRouter(prefix=prefix) + counter = 0 + + for p in pathlib.Path(path).glob('*.py'): + module_name = p.with_suffix('').as_posix().replace('/', '.') + module = importlib.import_module(module_name) + router_attr = getattr(module, 'router', None) + + if router_attr is not None: + router.include_router(router_attr) + counter += 1 + else: + logging.warning('Missing router in %s', module_name) + + logging.info('Loaded %d routers from %s as %r', counter, path, prefix) + return router + + +app.include_router(_make_router('api/v1', '/api/v1')) diff --git a/middlewares/profiler_middleware.py b/middlewares/profiler_middleware.py index 19f1b2e..7aa2a36 100644 --- a/middlewares/profiler_middleware.py +++ b/middlewares/profiler_middleware.py @@ -1,16 +1,18 @@ from fastapi import Request from pyinstrument import Profiler +from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import HTMLResponse -# https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi -async def profiler_middleware(request: Request, call_next): - profiling = request.query_params.get('profile', False) - if profiling: - profiler = Profiler() - profiler.start() - await call_next(request) - profiler.stop() - return HTMLResponse(profiler.output_html()) - else: - return await call_next(request) +class ProfilerMiddleware(BaseHTTPMiddleware): + # https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi + async def dispatch(self, request: Request, call_next): + profiling = request.query_params.get('profile', False) + if profiling: + profiler = Profiler() + profiler.start() + await call_next(request) + profiler.stop() + return HTMLResponse(profiler.output_html()) + else: + return await call_next(request) diff --git a/middlewares/skip_serialization.py b/middlewares/skip_serialization.py new file mode 100644 index 0000000..5df9d8f --- /dev/null +++ b/middlewares/skip_serialization.py @@ -0,0 +1,20 @@ +import functools +from collections.abc import Mapping + +from fastapi import Response + +from orjson_response import CustomORJSONResponse + + +def skip_serialization(headers: Mapping[str, str] | None = None): + def decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + raw_response = await func(*args, **kwargs) + if isinstance(raw_response, Response): + return raw_response + return CustomORJSONResponse(raw_response, headers=headers) + + return wrapper + + return decorator diff --git a/middlewares/version_middleware.py b/middlewares/version_middleware.py index 718c979..47564e8 100644 --- a/middlewares/version_middleware.py +++ b/middlewares/version_middleware.py @@ -1,9 +1,11 @@ from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware from config import VERSION -async def version_middleware(request: Request, call_next): - response = await call_next(request) - response.headers['X-Version'] = VERSION - return response +class VersionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + response.headers['X-Version'] = VERSION + return response diff --git a/states/country_state.py b/states/country_state.py index 928e6b9..271a501 100644 --- a/states/country_state.py +++ b/states/country_state.py @@ -13,7 +13,7 @@ from osm_countries import get_osm_countries from state_utils import get_state_doc, set_state_doc from transaction import Transaction -from utils import retry_exponential +from utils import retry_exponential, simple_point_mapping from validators.geometry import geometry_validator @@ -164,11 +164,9 @@ async def get_countries_within(cls, bbox_or_pos: BBox | Point) -> Sequence[Count { 'geometry': { '$geoIntersects': { - '$geometry': mapping( - bbox_or_pos.to_polygon(nodes_per_edge=8) - if isinstance(bbox_or_pos, BBox) # - else bbox_or_pos - ) + '$geometry': mapping(bbox_or_pos.to_polygon(nodes_per_edge=8)) + if isinstance(bbox_or_pos, BBox) + else simple_point_mapping(bbox_or_pos) } } } diff --git a/utils.py b/utils.py index f1a3aa6..b207e12 100644 --- a/utils.py +++ b/utils.py @@ -5,6 +5,7 @@ import anyio import httpx +from shapely import Point, get_coordinates from config import USER_AGENT @@ -56,3 +57,8 @@ def abbreviate(num: int) -> str: def get_wikimedia_commons_url(path: str) -> str: return f'https://commons.wikimedia.org/wiki/{path}' + + +def simple_point_mapping(point: Point) -> dict: + x, y = get_coordinates(point)[0] + return {'type': 'Point', 'coordinates': (x, y)}