diff --git a/api/v1/__init__.py b/api/v1/__init__.py deleted file mode 100644 index 4c23146..0000000 --- a/api/v1/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from fastapi import APIRouter - -import api.v1.countries as countries -import api.v1.node as node -import api.v1.photos as photos -import api.v1.tile as tile - -router = APIRouter(prefix='/api/v1') -router.include_router(countries.router) -router.include_router(node.router) -router.include_router(photos.router) -router.include_router(tile.router) diff --git a/api/v1/countries.py b/api/v1/countries.py index 2725c33..1a6cd5f 100644 --- a/api/v1/countries.py +++ b/api/v1/countries.py @@ -2,7 +2,8 @@ from typing import Annotated from anyio import create_task_group -from fastapi import APIRouter, Path, Response +from fastapi import APIRouter, Path +from fastapi.responses import ORJSONResponse from sentry_sdk import start_span from shapely.geometry import mapping @@ -35,50 +36,54 @@ def limit_country_names(names: dict[str, str]): return {language: name} return names - return [ - { - 'country_code': country.code, - 'country_names': limit_country_names(country.names), - 'feature_count': country_count_map[country.name], - 'data_path': f'/api/v1/countries/{country.code}.geojson', - } - for country in countries - ] + [ - { - 'country_code': 'WORLD', - 'country_names': {'default': 'World'}, - 'feature_count': sum(country_count_map.values()), - 'data_path': '/api/v1/countries/WORLD.geojson', - } - ] + return ORJSONResponse( + [ + { + 'country_code': country.code, + 'country_names': limit_country_names(country.names), + 'feature_count': country_count_map[country.name], + 'data_path': f'/api/v1/countries/{country.code}.geojson', + } + for country in countries + ] + + [ + { + 'country_code': 'WORLD', + 'country_names': {'default': 'World'}, + 'feature_count': sum(country_count_map.values()), + 'data_path': '/api/v1/countries/WORLD.geojson', + } + ] + ) @router.get('/{country_code}.geojson') @configure_cache(timedelta(hours=1), stale=timedelta(seconds=0)) -async def get_geojson( - response: Response, - country_code: Annotated[str, Path(min_length=2, max_length=5)], -): +async def get_geojson(country_code: Annotated[str, Path(min_length=2, max_length=5)]): if country_code == 'WORLD': aeds = await AEDState.get_all_aeds() else: aeds = await AEDState.get_aeds_by_country_code(country_code) - response.headers['Content-Disposition'] = 'attachment' - response.headers['Content-Type'] = 'application/geo+json; charset=utf-8' - return { - 'type': 'FeatureCollection', - 'features': [ - { - 'type': 'Feature', - 'geometry': mapping(aed.position), - 'properties': { - '@osm_type': 'node', - '@osm_id': aed.id, - '@osm_version': aed.version, - **aed.tags, - }, - } - for aed in aeds - ], - } + return ORJSONResponse( + { + 'type': 'FeatureCollection', + 'features': [ + { + 'type': 'Feature', + 'geometry': mapping(aed.position), + 'properties': { + '@osm_type': 'node', + '@osm_id': aed.id, + '@osm_version': aed.version, + **aed.tags, + }, + } + for aed in aeds + ], + }, + headers={ + 'Content-Disposition': 'attachment', + 'Content-Type': 'application/geo+json; charset=utf-8', + }, + ) diff --git a/api/v1/node.py b/api/v1/node.py index a28feaf..4b0e5af 100644 --- a/api/v1/node.py +++ b/api/v1/node.py @@ -3,6 +3,7 @@ from urllib.parse import quote_plus from fastapi import APIRouter, HTTPException +from fastapi.responses import ORJSONResponse from pytz import timezone from shapely import get_coordinates from tzfpy import get_tz @@ -88,21 +89,23 @@ async def get_node(node_id: int): '@timezone_offset': timezone_offset, } - return { - 'version': 0.6, - 'copyright': 'OpenStreetMap and contributors', - 'attribution': 'https://www.openstreetmap.org/copyright', - 'license': 'https://opendatacommons.org/licenses/odbl/1-0/', - 'elements': [ - { - **photo_dict, - **timezone_dict, - 'type': 'node', - 'id': aed.id, - 'lat': y, - 'lon': x, - 'tags': aed.tags, - 'version': aed.version, - } - ], - } + return ORJSONResponse( + { + 'version': 0.6, + 'copyright': 'OpenStreetMap and contributors', + 'attribution': 'https://www.openstreetmap.org/copyright', + 'license': 'https://opendatacommons.org/licenses/odbl/1-0/', + 'elements': [ + { + **photo_dict, + **timezone_dict, + 'type': 'node', + 'id': aed.id, + 'lat': y, + 'lon': x, + 'tags': aed.tags, + 'version': aed.version, + } + ], + } + ) diff --git a/api/v1/photos.py b/api/v1/photos.py index 309f8d7..d542e36 100644 --- a/api/v1/photos.py +++ b/api/v1/photos.py @@ -51,7 +51,7 @@ async def _fetch_image(url: str) -> tuple[bytes, str]: @router.get('/view/{id}.webp') @configure_cache(timedelta(days=365), stale=timedelta(days=365)) -async def view(id: str) -> FileResponse: +async def view(id: str): info = await PhotoState.get_photo_by_id(id) if info is None: @@ -62,7 +62,7 @@ async def view(id: str) -> FileResponse: @router.get('/proxy/direct/{url_encoded:path}') @configure_cache(timedelta(days=7), stale=timedelta(days=7)) -async def proxy_direct(url_encoded: str) -> FileResponse: +async def proxy_direct(url_encoded: str): url = unquote_plus(url_encoded) file, content_type = await _fetch_image(url) return Response(file, media_type=content_type) @@ -70,7 +70,7 @@ async def proxy_direct(url_encoded: str) -> FileResponse: @router.get('/proxy/wikimedia-commons/{path_encoded:path}') @configure_cache(timedelta(days=7), stale=timedelta(days=7)) -async def proxy_wikimedia_commons(path_encoded: str) -> FileResponse: +async def proxy_wikimedia_commons(path_encoded: str): meta_url = get_wikimedia_commons_url(unquote_plus(path_encoded)) async with get_http_client() as http: @@ -145,12 +145,12 @@ async def upload( @router.post('/report') -async def report(id: Annotated[str, Form()]) -> bool: +async def report(id: Annotated[str, Form()]): return await PhotoReportState.report_by_photo_id(id) @router.get('/report/rss.xml') -async def report_rss(request: Request) -> Response: +async def report_rss(request: Request): fg = FeedGenerator() fg.title('AED Photo Reports') fg.description('This feed contains a list of recent AED photo reports') diff --git a/config.py b/config.py index 9439c04..a2b021c 100644 --- a/config.py +++ b/config.py @@ -12,7 +12,7 @@ from sentry_sdk.integrations.pymongo import PyMongoIntegration NAME = 'openaedmap-backend' -VERSION = '2.7.5' +VERSION = '2.7.6' CREATED_BY = f'{NAME} {VERSION}' WEBSITE = 'https://openaedmap.org' diff --git a/main.py b/main.py index 88f12b5..6e7602b 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,17 @@ +import importlib +import logging +import pathlib from contextlib import asynccontextmanager from datetime import timedelta from anyio import create_task_group -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.base import BaseHTTPMiddleware -import api.v1 as v1 from config import DEFAULT_CACHE_MAX_AGE, DEFAULT_CACHE_STALE, startup_setup from middlewares.cache_middleware import CacheMiddleware -from middlewares.profiler_middleware import profiler_middleware -from middlewares.version_middleware import version_middleware +from middlewares.profiler_middleware import ProfilerMiddleware +from middlewares.version_middleware import VersionMiddleware from orjson_response import CustomORJSONResponse from states.aed_state import AEDState from states.country_state import CountryState @@ -39,9 +40,8 @@ async def lifespan(_): app = FastAPI(lifespan=lifespan, default_response_class=CustomORJSONResponse) -app.include_router(v1.router) -app.add_middleware(BaseHTTPMiddleware, dispatch=profiler_middleware) -app.add_middleware(BaseHTTPMiddleware, dispatch=version_middleware) +app.add_middleware(ProfilerMiddleware) +app.add_middleware(VersionMiddleware) app.add_middleware( CacheMiddleware, max_age=DEFAULT_CACHE_MAX_AGE, @@ -54,3 +54,28 @@ async def lifespan(_): allow_methods=['GET'], max_age=int(timedelta(days=1).total_seconds()), ) + + +def _make_router(path: str, prefix: str) -> APIRouter: + """ + Create a router from all modules in the given path. + """ + router = APIRouter(prefix=prefix) + counter = 0 + + for p in pathlib.Path(path).glob('*.py'): + module_name = p.with_suffix('').as_posix().replace('/', '.') + module = importlib.import_module(module_name) + router_attr = getattr(module, 'router', None) + + if router_attr is not None: + router.include_router(router_attr) + counter += 1 + else: + logging.warning('Missing router in %s', module_name) + + logging.info('Loaded %d routers from %s as %r', counter, path, prefix) + return router + + +app.include_router(_make_router('api/v1', '/api/v1')) diff --git a/middlewares/profiler_middleware.py b/middlewares/profiler_middleware.py index 19f1b2e..7aa2a36 100644 --- a/middlewares/profiler_middleware.py +++ b/middlewares/profiler_middleware.py @@ -1,16 +1,18 @@ from fastapi import Request from pyinstrument import Profiler +from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import HTMLResponse -# https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi -async def profiler_middleware(request: Request, call_next): - profiling = request.query_params.get('profile', False) - if profiling: - profiler = Profiler() - profiler.start() - await call_next(request) - profiler.stop() - return HTMLResponse(profiler.output_html()) - else: - return await call_next(request) +class ProfilerMiddleware(BaseHTTPMiddleware): + # https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi + async def dispatch(self, request: Request, call_next): + profiling = request.query_params.get('profile', False) + if profiling: + profiler = Profiler() + profiler.start() + await call_next(request) + profiler.stop() + return HTMLResponse(profiler.output_html()) + else: + return await call_next(request) diff --git a/middlewares/version_middleware.py b/middlewares/version_middleware.py index 718c979..47564e8 100644 --- a/middlewares/version_middleware.py +++ b/middlewares/version_middleware.py @@ -1,9 +1,11 @@ from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware from config import VERSION -async def version_middleware(request: Request, call_next): - response = await call_next(request) - response.headers['X-Version'] = VERSION - return response +class VersionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + response.headers['X-Version'] = VERSION + return response