Skip to content

Commit

Permalink
feat: Bearer auth for HTTP API (#746)
Browse files Browse the repository at this point in the history
* feat: Bearer auth for HTTP API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
hguandl and pre-commit-ci[bot] authored Dec 20, 2024
1 parent 47057d2 commit d8d71b2
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
8 changes: 7 additions & 1 deletion tools/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def parse_args():
help="`None` means randomized inference, otherwise deterministic.\n"
"It can't be used for fixing a timbre.",
)
parser.add_argument(
"--api_key",
type=str,
default="YOUR_API_KEY",
help="API key for authentication",
)

return parser.parse_args()

Expand Down Expand Up @@ -173,7 +179,7 @@ def parse_args():
data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
stream=args.streaming,
headers={
"authorization": "Bearer YOUR_API_KEY",
"authorization": f"Bearer {args.api_key}",
"content-type": "application/msgpack",
},
)
Expand Down
32 changes: 30 additions & 2 deletions tools/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@

import pyrootutils
import uvicorn
from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
from kui.asgi import (
Depends,
FactoryClass,
HTTPException,
HttpRoute,
Kui,
OpenAPI,
Routes,
)
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

Expand Down Expand Up @@ -31,7 +41,25 @@ def __init__(self):
("/v1/tts", TTSView),
("/v1/chat", ChatView),
]
self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])

def api_auth(endpoint):
async def verify(token: Annotated[str, Depends(bearer_auth)]):
if token != self.args.api_key:
raise HTTPException(401, None, "Invalid token")
return await endpoint()

async def passthrough():
return await endpoint()

if self.args.api_key is not None:
return verify
else:
return passthrough

self.routes = Routes(
[HttpRoute(path, view) for path, view in self.routes],
http_middlewares=[api_auth],
)

self.openapi = OpenAPI(
{
Expand Down
1 change: 1 addition & 0 deletions tools/server/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def parse_args():
parser.add_argument("--max-text-length", type=int, default=0)
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--api-key", type=str, default=None)

return parser.parse_args()

Expand Down

0 comments on commit d8d71b2

Please sign in to comment.