-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
92 lines (75 loc) · 3.02 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging
import os

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint

from db import get_exercises
from chains import enrich_chain, routing_chain, workout_plan_chain, summary_chain, exercise_info_chain
from validators import validate_workout_plan, validate_exercise_info
# Hard-coded profile passed as ``user_profile`` to the workout-plan and
# summary chains in process_query. Key order is preserved deliberately in
# case the chains render the dict as text.
MOCK_USER_PROFILE = dict(
    name="John Doe",
    age=36,
    height="180 cm",
    weight="75 kg",
    other_info="new to exercise",
)
app = FastAPI()

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,https://admin.example.com").
# When the variable is unset this falls back to "*" (any origin), which is
# convenient for development. An explicitly empty value allows no
# cross-origin requests.
# NOTE: with "*" and allow_credentials=True, Starlette echoes the caller's
# Origin for requests that carry cookies, so set explicit origins before
# deploying anything that relies on cookies or auth.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
@app.post("/api/process_query")
async def process_query(request: Request):
    """Answer a free-text fitness question through the LLM chain pipeline.

    Expects a JSON object body such as ``{"query": "..."}``. The steps are:

    1. ``enrich_chain`` rewrites the query.
    2. ``routing_chain`` classifies the rewritten query.
    3. Depending on the route, the query is handled by ``workout_plan_chain``
       (followed by ``summary_chain``) or by ``exercise_info_chain``.
    4. The chain result is checked with the matching validator before it is
       returned.

    Returns:
        A JSONResponse containing one of:

        - the validated chain output, on success;
        - ``{"type": "not_supported"}`` when the route names neither handler;
        - ``{"type": "error"}`` when the body is not a JSON object with a
          non-blank string ``"query"``, when validation fails, or when any
          step raises.
    """
    # Report malformed JSON or a non-object body with the same payload as
    # every other failure, instead of letting it escape as an unhandled 500.
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse(content={"type": "error"})
    query = data.get("query") if isinstance(data, dict) else None
    if not isinstance(query, str) or not query.strip():
        return JSONResponse(content={"type": "error"})
    print(f"User query: {query}")

    try:
        # The chains are invoked synchronously. Run them in the threadpool so
        # a slow LLM call does not stall the event loop for other requests.
        enriched_query = await run_in_threadpool(enrich_chain.invoke, {"query": query})
        # Pulled into a local so the f-string below needs no nested quotes,
        # which are a SyntaxError before Python 3.12.
        enriched_text = enriched_query["enriched_query"]
        print(f"Enriched query: {enriched_text}")

        route = await run_in_threadpool(routing_chain.invoke, {"question": enriched_text})

        # NOTE(review): get_exercises() still runs on the event-loop thread.
        # Confirm the db module is thread-safe before moving it off-loop.
        exercises_list = list(get_exercises())

        # NOTE(review): the chains below receive the full enrich_chain output
        # (a dict) as "query", whereas routing received only its
        # "enriched_query" text. Confirm this is intended.
        if "WorkoutPlan" in route:
            workout_plan_json = await run_in_threadpool(
                workout_plan_chain.invoke,
                {
                    "user_profile": MOCK_USER_PROFILE,
                    "query": enriched_query,
                    "exercises": exercises_list,
                },
            )
            if not validate_workout_plan(workout_plan_json):
                return JSONResponse(content={"type": "error"})
            summary = await run_in_threadpool(
                summary_chain.invoke,
                {"user_profile": MOCK_USER_PROFILE, "workout_plan": workout_plan_json},
            )
            workout_plan_json["summary"] = summary
            return JSONResponse(content=workout_plan_json)

        if "ExerciseInfo" in route:
            exercise_info_json = await run_in_threadpool(
                exercise_info_chain.invoke,
                {"query": enriched_query, "exercises": exercises_list},
            )
            if not validate_exercise_info(exercise_info_json):
                return JSONResponse(content={"type": "error"})
            return JSONResponse(content=exercise_info_json)

        return JSONResponse(content={"type": "not_supported"})
    except Exception:
        # Keep the client-facing error payload but record the traceback.
        # Catching Exception (not a bare except) lets asyncio cancellation,
        # which is a BaseException, propagate.
        logging.getLogger(__name__).exception("Failed to process query %r", query)
        return JSONResponse(content={"type": "error"})
@app.get("/api/exercises")
async def get_all_exercises():
    """List every exercise returned by ``get_exercises``.

    Returns:
        A JSONResponse whose body is ``{"exercises": <get_exercises() result>}``.
    """
    payload = {"exercises": get_exercises()}
    return JSONResponse(content=payload)
# Serve files from the "static" directory at the site root. html=True makes
# StaticFiles serve index.html for directory paths. The directory is resolved
# relative to the current working directory. The mount comes after the /api
# routes declared earlier in this module, so those routes are matched first.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


def _serve() -> None:
    """Run the app under uvicorn on all interfaces, port 8000."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _serve()