Skip to content

Commit

Permalink
Add username to RAG Document store entry (#217)
Browse files Browse the repository at this point in the history
Add username to RagDocument Table
  • Loading branch information
bedanley authored Jan 2, 2025
1 parent 73dc1ea commit 5ccb15f
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 23 deletions.
23 changes: 19 additions & 4 deletions lambda/authorizer/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,16 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i
if jwt_data := id_token_is_valid(id_token=id_token, client_id=client_id, authority=authority):
is_admin_user = is_admin(jwt_data, admin_group, jwt_groups_property)
groups = get_property_path(jwt_data, jwt_groups_property)
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=jwt_data["sub"])
allow_policy["context"] = {"username": jwt_data["sub"], "groups": json.dumps(groups or [])}
username = find_jwt_username(jwt_data)
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=username)
allow_policy["context"] = {"username": username, "groups": json.dumps(groups or [])}

if requested_resource.startswith("/models") and not is_admin_user:
# non-admin users can still list models
if event["path"].rstrip("/") != "/models":
username = jwt_data.get("sub", "user")
logger.info(f"Deny access to {username} due to non-admin accessing /models api.")
return deny_policy
if requested_resource.startswith("/configuration") and request_method == "PUT" and not is_admin_user:
username = jwt_data.get("sub", "user")
logger.info(f"Deny access to {username} due to non-admin trying to update configuration.")
return deny_policy
logger.debug(f"Generated policy: {allow_policy}")
Expand Down Expand Up @@ -160,6 +159,22 @@ def get_property_path(data: dict[str, Any], property_path: str) -> Optional[Any]
return current_node


def find_jwt_username(jwt_data: dict[str, str]) -> str:
"""Find the username in the JWT. If the key 'username' doesn't exist, return 'sub', which will be a UUID"""
username = None
if "username" in jwt_data:
username = jwt_data.get("username")
if "cognito:username" in jwt_data:
username = jwt_data.get("cognito:username")
else:
username = jwt_data.get("sub")

if not username:
raise ValueError("No username found in JWT")

return username


@cache
def get_management_tokens() -> list[str]:
"""Return secret management tokens if they exist."""
Expand Down
1 change: 1 addition & 0 deletions lambda/models/domain_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ class RagDocument(BaseModel):
collection_id: str
document_name: str
source: str
username: str
sub_docs: List[str] = Field(default_factory=lambda: [])
ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL)
upload_date: int = Field(default_factory=lambda: int(time.time()))
Expand Down
11 changes: 4 additions & 7 deletions lambda/repository/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from lisapy.langchain import LisaOpenAIEmbeddings
from models.domain_objects import IngestionType, RagDocument
from repository.rag_document_repo import RagDocumentRepository
from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, retry_config
from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, get_username, retry_config
from utilities.exceptions import HTTPException
from utilities.file_processing import process_record
from utilities.validation import validate_model_name, ValidationError
Expand Down Expand Up @@ -361,6 +361,7 @@ def ingest_documents(event: dict, context: dict) -> dict:
chunk_overlap = int(query_string_params["chunkOverlap"]) if "chunkOverlap" in query_string_params else None
logger.info(f"using repository {repository_id}")

username = get_username(event)
repository = find_repository_by_id(repository_id)
ensure_repository_access(event, repository)

Expand Down Expand Up @@ -391,6 +392,7 @@ def ingest_documents(event: dict, context: dict) -> dict:
document_name=document_name,
source=doc_source,
sub_docs=ids,
username=username,
ingestion_type=IngestionType.MANUAL,
)
doc_repo.save(doc_entity)
Expand Down Expand Up @@ -422,7 +424,7 @@ def presigned_url(event: dict, context: dict) -> dict:
key = event["body"]

# Set derived values for conditions and fields
username = event["requestContext"]["authorizer"]["username"]
username = get_username(event)

# Conditions is an array of dictionaries.
# content-length-range restricts the size of the file uploaded
Expand All @@ -442,11 +444,6 @@ def presigned_url(event: dict, context: dict) -> dict:
return {"response": response}


def get_groups(event: Any) -> List[str]:
groups: List[str] = json.loads(event["requestContext"]["authorizer"]["groups"])
return groups


@api_wrapper
def list_docs(event: dict, context: dict) -> List[RagDocument]:
"""List all documents for a given repository/collection.
Expand Down
4 changes: 3 additions & 1 deletion lambda/repository/pipeline_ingest_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import boto3
from repository.lambda_functions import RagDocumentRepository
from utilities.common_functions import retry_config
from utilities.common_functions import get_username, retry_config
from utilities.file_processing import process_record
from utilities.validation import validate_chunk_params, validate_model_name, validate_repository_type, ValidationError
from utilities.vector_store import get_vector_store_client
Expand Down Expand Up @@ -133,6 +133,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic

# Process documents in batches
all_ids = []
username = get_username(event)
batches = batch_texts(texts, metadatas)
total_batches = len(batches)

Expand All @@ -158,6 +159,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
document_name=key,
source=docs[0][0].metadata.get("source"),
sub_docs=all_ids,
username=username,
ingestion_type=IngestionType.AUTO,
)
doc_repo.save(doc_entity)
Expand Down
3 changes: 3 additions & 0 deletions lambda/repository/state_machine/pipeline_ingest_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from models.domain_objects import IngestionType, RagDocument
from models.vectorstore import VectorStore
from repository.lambda_functions import RagDocumentRepository
from utilities.common_functions import get_username

doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])

Expand All @@ -47,6 +48,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
embedding_model = os.environ["EMBEDDING_MODEL"]
collection_name = os.environ["COLLECTION_NAME"]
repository_id = os.environ["REPOSITORY_ID"]
username = get_username(event)

# Initialize document processor and vectorstore
doc_processor = DocumentProcessor()
Expand All @@ -71,6 +73,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic
document_name=key,
source=source,
sub_docs=ids,
username=username,
ingestion_type=IngestionType.AUTO,
)
doc_repo.save(doc_entity)
Expand Down
18 changes: 9 additions & 9 deletions lambda/session/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import boto3
import create_env_variables # noqa: F401
from botocore.exceptions import ClientError
from utilities.common_functions import api_wrapper, retry_config
from utilities.common_functions import api_wrapper, get_session_id, get_username, retry_config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,7 +89,7 @@ def _delete_user_session(session_id: str, user_id: str) -> Dict[str, bool]:
@api_wrapper
def list_sessions(event: dict, context: dict) -> List[Dict[str, Any]]:
"""List sessions by user ID from DynamoDB."""
user_id = event["requestContext"]["authorizer"]["username"]
user_id = get_username(event)

logger.info(f"Listing sessions for user {user_id}")
return _get_all_user_sessions(user_id)
Expand All @@ -98,8 +98,8 @@ def list_sessions(event: dict, context: dict) -> List[Dict[str, Any]]:
@api_wrapper
def get_session(event: dict, context: dict) -> Dict[str, Any]:
"""Get session from DynamoDB."""
user_id = event["requestContext"]["authorizer"]["username"]
session_id = event["pathParameters"]["sessionId"]
user_id = get_username(event)
session_id = get_session_id(event)

logger.info(f"Fetching session with ID {session_id} for user {user_id}")
response = {}
Expand All @@ -116,8 +116,8 @@ def get_session(event: dict, context: dict) -> Dict[str, Any]:
@api_wrapper
def delete_session(event: dict, context: dict) -> dict:
"""Delete session from DynamoDB."""
user_id = event["requestContext"]["authorizer"]["username"]
session_id = event["pathParameters"]["sessionId"]
user_id = get_username(event)
session_id = get_session_id(event)

logger.info(f"Deleting session with ID {session_id} for user {user_id}")
return _delete_user_session(session_id, user_id)
Expand All @@ -126,7 +126,7 @@ def delete_session(event: dict, context: dict) -> dict:
@api_wrapper
def delete_user_sessions(event: dict, context: dict) -> Dict[str, bool]:
"""Delete sessions by user ID from DyanmoDB."""
user_id = event["requestContext"]["authorizer"]["username"]
user_id = get_username(event)

logger.info(f"Deleting all sessions for user {user_id}")
sessions = _get_all_user_sessions(user_id)
Expand All @@ -140,8 +140,8 @@ def delete_user_sessions(event: dict, context: dict) -> Dict[str, bool]:
@api_wrapper
def put_session(event: dict, context: dict) -> None:
"""Append the message to the record in DynamoDB."""
user_id = event["requestContext"]["authorizer"]["username"]
session_id = event["pathParameters"]["sessionId"]
user_id = get_username(event)
session_id = get_session_id(event)
# from https://stackoverflow.com/a/71446846
body = json.loads(event["body"], parse_float=Decimal)
messages = body["messages"]
Expand Down
28 changes: 26 additions & 2 deletions lambda/utilities/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile
from contextvars import ContextVar
from functools import cache
from typing import Any, Callable, Dict, TypeVar, Union
from typing import Any, Callable, Dict, List, TypeVar, Union

import boto3
from botocore.config import Config
Expand Down Expand Up @@ -135,7 +135,7 @@ def api_wrapper(f: F) -> F:
Parameters
----------
f : F
The functiont to be wrapped.
The function to be wrapped.
Returns
-------
Expand Down Expand Up @@ -344,3 +344,27 @@ def get_rest_api_container_endpoint() -> str:
lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"])
lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"]
return f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve"


def get_username(event: dict) -> str:
"""Get username from event."""
username: str = event.get("requestContext", {}).get("authorizer", {}).get("username", "system")
return username


def get_session_id(event: dict) -> str:
"""Get session_id from event."""
session_id: str = event.get("pathParameters", {}).get("sessionId")
return session_id


def get_groups(event: Any) -> List[str]:
"""Get user groups from event."""
groups: List[str] = event.get("requestContext", {}).get("authorizer", {}).get("groups", [])
return groups


def get_principal_id(event: Any) -> str:
"""Get principal from event."""
principal: str = event.get("requestContext", {}).get("authorizer", {}).get("principal", "")
return principal

0 comments on commit 5ccb15f

Please sign in to comment.