Skip to content

Commit

Permalink
Skip serialization for predictable responses
Browse files Browse the repository at this point in the history
  • Loading branch information
Zaczero committed Feb 21, 2024
1 parent 1b95988 commit 08be876
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 98 deletions.
12 changes: 0 additions & 12 deletions api/v1/__init__.py

This file was deleted.

83 changes: 44 additions & 39 deletions api/v1/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Annotated

from anyio import create_task_group
from fastapi import APIRouter, Path, Response
from fastapi import APIRouter, Path
from fastapi.responses import ORJSONResponse
from sentry_sdk import start_span
from shapely.geometry import mapping

Expand Down Expand Up @@ -35,50 +36,54 @@ def limit_country_names(names: dict[str, str]):
return {language: name}
return names

return [
{
'country_code': country.code,
'country_names': limit_country_names(country.names),
'feature_count': country_count_map[country.name],
'data_path': f'/api/v1/countries/{country.code}.geojson',
}
for country in countries
] + [
{
'country_code': 'WORLD',
'country_names': {'default': 'World'},
'feature_count': sum(country_count_map.values()),
'data_path': '/api/v1/countries/WORLD.geojson',
}
]
return ORJSONResponse(
[
{
'country_code': country.code,
'country_names': limit_country_names(country.names),
'feature_count': country_count_map[country.name],
'data_path': f'/api/v1/countries/{country.code}.geojson',
}
for country in countries
]
+ [
{
'country_code': 'WORLD',
'country_names': {'default': 'World'},
'feature_count': sum(country_count_map.values()),
'data_path': '/api/v1/countries/WORLD.geojson',
}
]
)


@router.get('/{country_code}.geojson')
@configure_cache(timedelta(hours=1), stale=timedelta(seconds=0))
async def get_geojson(
response: Response,
country_code: Annotated[str, Path(min_length=2, max_length=5)],
):
async def get_geojson(country_code: Annotated[str, Path(min_length=2, max_length=5)]):
if country_code == 'WORLD':
aeds = await AEDState.get_all_aeds()
else:
aeds = await AEDState.get_aeds_by_country_code(country_code)

response.headers['Content-Disposition'] = 'attachment'
response.headers['Content-Type'] = 'application/geo+json; charset=utf-8'
return {
'type': 'FeatureCollection',
'features': [
{
'type': 'Feature',
'geometry': mapping(aed.position),
'properties': {
'@osm_type': 'node',
'@osm_id': aed.id,
'@osm_version': aed.version,
**aed.tags,
},
}
for aed in aeds
],
}
return ORJSONResponse(
{
'type': 'FeatureCollection',
'features': [
{
'type': 'Feature',
'geometry': mapping(aed.position),
'properties': {
'@osm_type': 'node',
'@osm_id': aed.id,
'@osm_version': aed.version,
**aed.tags,
},
}
for aed in aeds
],
},
headers={
'Content-Disposition': 'attachment',
'Content-Type': 'application/geo+json; charset=utf-8',
},
)
39 changes: 21 additions & 18 deletions api/v1/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import quote_plus

from fastapi import APIRouter, HTTPException
from fastapi.responses import ORJSONResponse
from pytz import timezone
from shapely import get_coordinates
from tzfpy import get_tz
Expand Down Expand Up @@ -88,21 +89,23 @@ async def get_node(node_id: int):
'@timezone_offset': timezone_offset,
}

return {
'version': 0.6,
'copyright': 'OpenStreetMap and contributors',
'attribution': 'https://www.openstreetmap.org/copyright',
'license': 'https://opendatacommons.org/licenses/odbl/1-0/',
'elements': [
{
**photo_dict,
**timezone_dict,
'type': 'node',
'id': aed.id,
'lat': y,
'lon': x,
'tags': aed.tags,
'version': aed.version,
}
],
}
return ORJSONResponse(
{
'version': 0.6,
'copyright': 'OpenStreetMap and contributors',
'attribution': 'https://www.openstreetmap.org/copyright',
'license': 'https://opendatacommons.org/licenses/odbl/1-0/',
'elements': [
{
**photo_dict,
**timezone_dict,
'type': 'node',
'id': aed.id,
'lat': y,
'lon': x,
'tags': aed.tags,
'version': aed.version,
}
],
}
)
10 changes: 5 additions & 5 deletions api/v1/photos.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def _fetch_image(url: str) -> tuple[bytes, str]:

@router.get('/view/{id}.webp')
@configure_cache(timedelta(days=365), stale=timedelta(days=365))
async def view(id: str) -> FileResponse:
async def view(id: str):
info = await PhotoState.get_photo_by_id(id)

if info is None:
Expand All @@ -62,15 +62,15 @@ async def view(id: str) -> FileResponse:

@router.get('/proxy/direct/{url_encoded:path}')
@configure_cache(timedelta(days=7), stale=timedelta(days=7))
async def proxy_direct(url_encoded: str) -> FileResponse:
async def proxy_direct(url_encoded: str):
url = unquote_plus(url_encoded)
file, content_type = await _fetch_image(url)
return Response(file, media_type=content_type)


@router.get('/proxy/wikimedia-commons/{path_encoded:path}')
@configure_cache(timedelta(days=7), stale=timedelta(days=7))
async def proxy_wikimedia_commons(path_encoded: str) -> FileResponse:
async def proxy_wikimedia_commons(path_encoded: str):
meta_url = get_wikimedia_commons_url(unquote_plus(path_encoded))

async with get_http_client() as http:
Expand Down Expand Up @@ -145,12 +145,12 @@ async def upload(


@router.post('/report')
async def report(id: Annotated[str, Form()]) -> bool:
async def report(id: Annotated[str, Form()]):
return await PhotoReportState.report_by_photo_id(id)


@router.get('/report/rss.xml')
async def report_rss(request: Request) -> Response:
async def report_rss(request: Request):
fg = FeedGenerator()
fg.title('AED Photo Reports')
fg.description('This feed contains a list of recent AED photo reports')
Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sentry_sdk.integrations.pymongo import PyMongoIntegration

NAME = 'openaedmap-backend'
VERSION = '2.7.5'
VERSION = '2.7.6'
CREATED_BY = f'{NAME} {VERSION}'
WEBSITE = 'https://openaedmap.org'

Expand Down
41 changes: 33 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import importlib
import logging
import pathlib
from contextlib import asynccontextmanager
from datetime import timedelta

from anyio import create_task_group
from fastapi import FastAPI
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware

import api.v1 as v1
from config import DEFAULT_CACHE_MAX_AGE, DEFAULT_CACHE_STALE, startup_setup
from middlewares.cache_middleware import CacheMiddleware
from middlewares.profiler_middleware import profiler_middleware
from middlewares.version_middleware import version_middleware
from middlewares.profiler_middleware import ProfilerMiddleware
from middlewares.version_middleware import VersionMiddleware
from orjson_response import CustomORJSONResponse
from states.aed_state import AEDState
from states.country_state import CountryState
Expand Down Expand Up @@ -39,9 +40,8 @@ async def lifespan(_):


app = FastAPI(lifespan=lifespan, default_response_class=CustomORJSONResponse)
app.include_router(v1.router)
app.add_middleware(BaseHTTPMiddleware, dispatch=profiler_middleware)
app.add_middleware(BaseHTTPMiddleware, dispatch=version_middleware)
app.add_middleware(ProfilerMiddleware)
app.add_middleware(VersionMiddleware)
app.add_middleware(
CacheMiddleware,
max_age=DEFAULT_CACHE_MAX_AGE,
Expand All @@ -54,3 +54,28 @@ async def lifespan(_):
allow_methods=['GET'],
max_age=int(timedelta(days=1).total_seconds()),
)


def _make_router(path: str, prefix: str) -> APIRouter:
"""
Create a router from all modules in the given path.
"""
router = APIRouter(prefix=prefix)
counter = 0

for p in pathlib.Path(path).glob('*.py'):
module_name = p.with_suffix('').as_posix().replace('/', '.')
module = importlib.import_module(module_name)
router_attr = getattr(module, 'router', None)

if router_attr is not None:
router.include_router(router_attr)
counter += 1
else:
logging.warning('Missing router in %s', module_name)

logging.info('Loaded %d routers from %s as %r', counter, path, prefix)
return router


app.include_router(_make_router('api/v1', '/api/v1'))
24 changes: 13 additions & 11 deletions middlewares/profiler_middleware.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from fastapi import Request
from pyinstrument import Profiler
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import HTMLResponse


# https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi
async def profiler_middleware(request: Request, call_next):
profiling = request.query_params.get('profile', False)
if profiling:
profiler = Profiler()
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output_html())
else:
return await call_next(request)
class ProfilerMiddleware(BaseHTTPMiddleware):
# https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi
async def dispatch(self, request: Request, call_next):
profiling = request.query_params.get('profile', False)
if profiling:
profiler = Profiler()
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output_html())
else:
return await call_next(request)
10 changes: 6 additions & 4 deletions middlewares/version_middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from config import VERSION


async def version_middleware(request: Request, call_next):
response = await call_next(request)
response.headers['X-Version'] = VERSION
return response
class VersionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
response.headers['X-Version'] = VERSION
return response

0 comments on commit 08be876

Please sign in to comment.