Skip to content

Commit

Permalink
Merge pull request #62 from openstreetmap-polska/dev
Browse files Browse the repository at this point in the history
Release to main
  • Loading branch information
Zaczero authored Feb 21, 2024
2 parents 4d80aa2 + a581e99 commit a7f7b6e
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 55 deletions.
12 changes: 0 additions & 12 deletions api/v1/__init__.py

This file was deleted.

20 changes: 12 additions & 8 deletions api/v1/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
from typing import Annotated

from anyio import create_task_group
from fastapi import APIRouter, Path, Response
from fastapi import APIRouter, Path
from sentry_sdk import start_span
from shapely.geometry import mapping

from middlewares.cache_middleware import configure_cache
from middlewares.skip_serialization import skip_serialization
from models.country import Country
from states.aed_state import AEDState
from states.country_state import CountryState
from utils import simple_point_mapping

router = APIRouter(prefix='/countries')


@router.get('/names')
@configure_cache(timedelta(hours=1), stale=timedelta(days=7))
@skip_serialization()
async def get_names(language: str | None = None):
countries = await CountryState.get_all_countries()
country_count_map: dict[str, int] = {}
Expand Down Expand Up @@ -55,23 +58,24 @@ def limit_country_names(names: dict[str, str]):

@router.get('/{country_code}.geojson')
@configure_cache(timedelta(hours=1), stale=timedelta(seconds=0))
async def get_geojson(
response: Response,
country_code: Annotated[str, Path(min_length=2, max_length=5)],
):
@skip_serialization(
{
'Content-Disposition': 'attachment',
'Content-Type': 'application/geo+json; charset=utf-8',
}
)
async def get_geojson(country_code: Annotated[str, Path(min_length=2, max_length=5)]):
if country_code == 'WORLD':
aeds = await AEDState.get_all_aeds()
else:
aeds = await AEDState.get_aeds_by_country_code(country_code)

response.headers['Content-Disposition'] = 'attachment'
response.headers['Content-Type'] = 'application/geo+json; charset=utf-8'
return {
'type': 'FeatureCollection',
'features': [
{
'type': 'Feature',
'geometry': mapping(aed.position),
'geometry': simple_point_mapping(aed.position),
'properties': {
'@osm_type': 'node',
'@osm_id': aed.id,
Expand Down
2 changes: 2 additions & 0 deletions api/v1/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tzfpy import get_tz

from middlewares.cache_middleware import configure_cache
from middlewares.skip_serialization import skip_serialization
from states.aed_state import AEDState
from states.photo_state import PhotoState
from utils import get_wikimedia_commons_url
Expand Down Expand Up @@ -72,6 +73,7 @@ async def _get_image_data(tags: dict[str, str]) -> dict:

@router.get('/node/{node_id}')
@configure_cache(timedelta(minutes=1), stale=timedelta(minutes=5))
@skip_serialization()
async def get_node(node_id: int):
aed = await AEDState.get_aed_by_id(node_id)

Expand Down
10 changes: 5 additions & 5 deletions api/v1/photos.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def _fetch_image(url: str) -> tuple[bytes, str]:

@router.get('/view/{id}.webp')
@configure_cache(timedelta(days=365), stale=timedelta(days=365))
async def view(id: str) -> FileResponse:
async def view(id: str):
info = await PhotoState.get_photo_by_id(id)

if info is None:
Expand All @@ -62,15 +62,15 @@ async def view(id: str) -> FileResponse:

@router.get('/proxy/direct/{url_encoded:path}')
@configure_cache(timedelta(days=7), stale=timedelta(days=7))
async def proxy_direct(url_encoded: str) -> FileResponse:
async def proxy_direct(url_encoded: str):
url = unquote_plus(url_encoded)
file, content_type = await _fetch_image(url)
return Response(file, media_type=content_type)


@router.get('/proxy/wikimedia-commons/{path_encoded:path}')
@configure_cache(timedelta(days=7), stale=timedelta(days=7))
async def proxy_wikimedia_commons(path_encoded: str) -> FileResponse:
async def proxy_wikimedia_commons(path_encoded: str):
meta_url = get_wikimedia_commons_url(unquote_plus(path_encoded))

async with get_http_client() as http:
Expand Down Expand Up @@ -145,12 +145,12 @@ async def upload(


@router.post('/report')
async def report(id: Annotated[str, Form()]) -> bool:
async def report(id: Annotated[str, Form()]):
return await PhotoReportState.report_by_photo_id(id)


@router.get('/report/rss.xml')
async def report_rss(request: Request) -> Response:
async def report_rss(request: Request):
fg = FeedGenerator()
fg.title('AED Photo Reports')
fg.description('This feed contains a list of recent AED photo reports')
Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sentry_sdk.integrations.pymongo import PyMongoIntegration

NAME = 'openaedmap-backend'
VERSION = '2.7.5'
VERSION = '2.7.6'
CREATED_BY = f'{NAME} {VERSION}'
WEBSITE = 'https://openaedmap.org'

Expand Down
41 changes: 33 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import importlib
import logging
import pathlib
from contextlib import asynccontextmanager
from datetime import timedelta

from anyio import create_task_group
from fastapi import FastAPI
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware

import api.v1 as v1
from config import DEFAULT_CACHE_MAX_AGE, DEFAULT_CACHE_STALE, startup_setup
from middlewares.cache_middleware import CacheMiddleware
from middlewares.profiler_middleware import profiler_middleware
from middlewares.version_middleware import version_middleware
from middlewares.profiler_middleware import ProfilerMiddleware
from middlewares.version_middleware import VersionMiddleware
from orjson_response import CustomORJSONResponse
from states.aed_state import AEDState
from states.country_state import CountryState
Expand Down Expand Up @@ -39,9 +40,8 @@ async def lifespan(_):


app = FastAPI(lifespan=lifespan, default_response_class=CustomORJSONResponse)
app.include_router(v1.router)
app.add_middleware(BaseHTTPMiddleware, dispatch=profiler_middleware)
app.add_middleware(BaseHTTPMiddleware, dispatch=version_middleware)
app.add_middleware(ProfilerMiddleware)
app.add_middleware(VersionMiddleware)
app.add_middleware(
CacheMiddleware,
max_age=DEFAULT_CACHE_MAX_AGE,
Expand All @@ -54,3 +54,28 @@ async def lifespan(_):
allow_methods=['GET'],
max_age=int(timedelta(days=1).total_seconds()),
)


def _make_router(path: str, prefix: str) -> APIRouter:
"""
Create a router from all modules in the given path.
"""
router = APIRouter(prefix=prefix)
counter = 0

for p in pathlib.Path(path).glob('*.py'):
module_name = p.with_suffix('').as_posix().replace('/', '.')
module = importlib.import_module(module_name)
router_attr = getattr(module, 'router', None)

if router_attr is not None:
router.include_router(router_attr)
counter += 1
else:
logging.warning('Missing router in %s', module_name)

logging.info('Loaded %d routers from %s as %r', counter, path, prefix)
return router


app.include_router(_make_router('api/v1', '/api/v1'))
24 changes: 13 additions & 11 deletions middlewares/profiler_middleware.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from fastapi import Request
from pyinstrument import Profiler
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import HTMLResponse


# https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi
async def profiler_middleware(request: Request, call_next):
profiling = request.query_params.get('profile', False)
if profiling:
profiler = Profiler()
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output_html())
else:
return await call_next(request)
class ProfilerMiddleware(BaseHTTPMiddleware):
# https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi
async def dispatch(self, request: Request, call_next):
profiling = request.query_params.get('profile', False)
if profiling:
profiler = Profiler()
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output_html())
else:
return await call_next(request)
20 changes: 20 additions & 0 deletions middlewares/skip_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import functools
from collections.abc import Mapping

from fastapi import Response

from orjson_response import CustomORJSONResponse


def skip_serialization(headers: Mapping[str, str] | None = None):
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
raw_response = await func(*args, **kwargs)
if isinstance(raw_response, Response):
return raw_response
return CustomORJSONResponse(raw_response, headers=headers)

return wrapper

return decorator
10 changes: 6 additions & 4 deletions middlewares/version_middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from config import VERSION


async def version_middleware(request: Request, call_next):
response = await call_next(request)
response.headers['X-Version'] = VERSION
return response
class VersionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
response.headers['X-Version'] = VERSION
return response
10 changes: 4 additions & 6 deletions states/country_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from osm_countries import get_osm_countries
from state_utils import get_state_doc, set_state_doc
from transaction import Transaction
from utils import retry_exponential
from utils import retry_exponential, simple_point_mapping
from validators.geometry import geometry_validator


Expand Down Expand Up @@ -164,11 +164,9 @@ async def get_countries_within(cls, bbox_or_pos: BBox | Point) -> Sequence[Count
{
'geometry': {
'$geoIntersects': {
'$geometry': mapping(
bbox_or_pos.to_polygon(nodes_per_edge=8)
if isinstance(bbox_or_pos, BBox) #
else bbox_or_pos
)
'$geometry': mapping(bbox_or_pos.to_polygon(nodes_per_edge=8))
if isinstance(bbox_or_pos, BBox)
else simple_point_mapping(bbox_or_pos)
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import anyio
import httpx
from shapely import Point, get_coordinates

from config import USER_AGENT

Expand Down Expand Up @@ -56,3 +57,8 @@ def abbreviate(num: int) -> str:

def get_wikimedia_commons_url(path: str) -> str:
return f'https://commons.wikimedia.org/wiki/{path}'


def simple_point_mapping(point: Point) -> dict:
x, y = get_coordinates(point)[0]
return {'type': 'Point', 'coordinates': (x, y)}

0 comments on commit a7f7b6e

Please sign in to comment.