Skip to content

Commit

Permalink
Feat/private lisa serve (#42)
Browse files Browse the repository at this point in the history
* allow vpc only deployment of LISA-serve

* fix DDB API token table

* fix api deployment stack name

* add new config options to example config

* update ddb table examples

* use https on ALB if cert present

* deploy chat and UI by default

* remove unused schema

* add deployment prefix to rag bucket for uniqueness

* add schema validation checks for auth config

* Update schema.ts
  • Loading branch information
krzim-aws authored Aug 1, 2024
1 parent b7039ca commit 445204c
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 107 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
default_language_version:
node: system
repos:
- repo: local
hooks:
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ An account owner may create a long-lived API Token using the following AWS CLI c
```bash
AWS_REGION="us-east-1" # change to your deployment region
token_string="YOUR_STRING_HERE" # change to a unique string for a user
aws --region $AWS_REGION dynamodb put-item --table-name LISAApiTokenTable \
aws --region $AWS_REGION dynamodb put-item --table-name $DEPLOYMENT_NAME-LISAApiTokenTable \
--item '{"token": {"S": "'${token_string}'"}}'
```

Expand All @@ -371,7 +371,7 @@ in seconds. The following command shows an example of how to do this.
AWS_REGION="us-east-1" # change to your deployment region
token_string="YOUR_STRING_HERE"
token_expiration=$(echo $(date +%s) + 3600 | bc) # token that expires in one hour, 3600 seconds
aws --region $AWS_REGION dynamodb put-item --table-name LISAApiTokenTable \
aws --region $AWS_REGION dynamodb put-item --table-name $DEPLOYMENT_NAME-LISAApiTokenTable \
--item '{
"token": {"S": "'${token_string}'"},
"tokenExpiration": {"N": "'${token_expiration}'"}
Expand Down Expand Up @@ -401,7 +401,7 @@ that key.
AWS_REGION="us-east-1" # change to your deployment region
token_string="YOUR_STRING_HERE"
token_expiration=$(echo $(date +%s) + 600 | bc) # token that expires in 10 minutes from now
aws --region $AWS_REGION dynamodb update-item --table-name LISAApiTokenTable \
aws --region $AWS_REGION dynamodb update-item --table-name $DEPLOYMENT_NAME-LISAApiTokenTable \
--key '{"token": {"S": "'${token_string}'"}}' \
--update-expression 'SET tokenExpiration=:t' \
--expression-attribute-values '{":t": {"N": "'${token_expiration}'"}}'
Expand All @@ -416,7 +416,7 @@ remove a token.
```bash
AWS_REGION="us-east-1" # change to your deployment region
token_string="YOUR_STRING_HERE" # change to the token to remove
aws --region $AWS_REGION dynamodb delete-item --table-name LISAApiTokenTable \
aws --region $AWS_REGION dynamodb delete-item --table-name $DEPLOYMENT_NAME-LISAApiTokenTable \
--key '{"token": {"S": "'${token_string}'"}}'
```

Expand Down Expand Up @@ -473,7 +473,7 @@ export AWS_REGION=<Region where LISA is deployed>
export AUTHORITY=<IdP Endpoint>
export CLIENT_ID=<IdP Client Id>
export REGISTERED_MODELS_PS_NAME=<Models ParameterName>
export TOKEN_TABLE_NAME="LISAApiTokenTable"
export TOKEN_TABLE_NAME="<deployment name>-LISAApiTokenTable"
gunicorn -k uvicorn.workers.UvicornWorker -w 2 -b "0.0.0.0:8080" "src.main:app"
```

Expand Down
3 changes: 3 additions & 0 deletions example_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dev:
accountNumbersEcr:
- 012345678901
deployRag: true
deployChat: true
deployUi: true
lambdaConfig:
pythonRuntime: PYTHON_3_10
logLevel: DEBUG
Expand Down Expand Up @@ -76,6 +78,7 @@ dev:
targetValue: 1000
duration: 60
estimatedInstanceWarmup: 30
internetFacing: true
loadBalancerConfig:
sslCertIamArn: arn:aws:iam::012345678901:server-certificate/lisa-self-signed-dev
healthCheckConfig:
Expand Down
4 changes: 2 additions & 2 deletions lib/api-base/authorizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ export class CustomAuthorizer extends Construct {
memorySize: 128,
layers: [authorizerLambdaLayer, commonLambdaLayer],
environment: {
CLIENT_ID: config.authConfig.clientId,
AUTHORITY: config.authConfig.authority,
CLIENT_ID: config.authConfig!.clientId,
AUTHORITY: config.authConfig!.authority,
},
role: role,
vpc: vpc,
Expand Down
23 changes: 17 additions & 6 deletions lib/api-base/fastApiContainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ interface FastApiContainerProps extends BaseProps {
resourcePath: string;
securityGroup: SecurityGroup;
taskConfig: FastApiContainerConfig;
tokenTable: ITable;
tokenTable: ITable | undefined;
vpc: IVpc;
}

Expand Down Expand Up @@ -81,11 +81,20 @@ export class FastApiContainer extends Construct {
AWS_REGION: config.region,
AWS_REGION_NAME: config.region, // for supporting SageMaker endpoints in LiteLLM
THREADS: Ec2Metadata.get(taskConfig.instanceType).vCpus.toString(),
AUTHORITY: config.authConfig.authority,
CLIENT_ID: config.authConfig.clientId,
TOKEN_TABLE_NAME: tokenTable.tableName,
};

if (config.restApiConfig.internetFacing) {
environment.USE_AUTH = 'true';
environment.AUTHORITY = config.authConfig!.authority;
environment.CLIENT_ID = config.authConfig!.clientId;
} else {
environment.USE_AUTH = 'false';
}

if (tokenTable) {
environment.TOKEN_TABLE_NAME = tokenTable.tableName;
}

const apiCluster = new ECSCluster(scope, `${id}-ECSCluster`, {
config,
ecsConfig: {
Expand All @@ -97,14 +106,16 @@ export class FastApiContainer extends Construct {
environment,
identifier: props.apiName,
instanceType: taskConfig.instanceType,
internetFacing: true,
internetFacing: config.restApiConfig.internetFacing,
loadBalancerConfig: taskConfig.loadBalancerConfig,
},
securityGroup,
vpc,
});
tokenTable.grantReadData(apiCluster.taskRole);

if (tokenTable) {
tokenTable.grantReadData(apiCluster.taskRole);
}
this.endpoint = apiCluster.endpointUrl;

// Update
Expand Down
9 changes: 7 additions & 2 deletions lib/networking/vpc/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@ export class Vpc extends Construct {
// All HTTP VPC traffic -> ECS model ALB
ecsModelAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');

// All HTTPS IPV4 traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.anyIpv4(), Port.tcp(443), 'Allow any traffic on port 443');
if (config.restApiConfig.loadBalancerConfig.sslCertIamArn) {
// All HTTPS IPV4 traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.anyIpv4(), Port.tcp(443), 'Allow any traffic on port 443');
} else {
// All HTTP VPC traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');
}

// Update
this.vpc = vpc;
Expand Down
6 changes: 3 additions & 3 deletions lib/rag/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export class LisaRagStack extends Stack {
StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/common`),
);

const bucketName = `lisaragdocs-${config.accountNumber}`;
const bucketName = `${config.deploymentName}-lisaragdocs-${config.accountNumber}`;
const bucket = new Bucket(this, createCdkId(['LISA', 'RAG', config.deploymentName, config.deploymentStage]), {
bucketName,
cors: [
Expand All @@ -95,8 +95,8 @@ export class LisaRagStack extends Stack {
const baseEnvironment: Record<string, string> = {
REGISTERED_MODELS_PS_NAME: modelsPs.parameterName,
BUCKET_NAME: bucketName,
CHUNK_SIZE: config.ragFileProcessingConfig.chunkSize.toString(),
CHUNK_OVERLAP: config.ragFileProcessingConfig.chunkOverlap.toString(),
CHUNK_SIZE: config.ragFileProcessingConfig!.chunkSize.toString(),
CHUNK_OVERLAP: config.ragFileProcessingConfig!.chunkOverlap.toString(),
LISA_API_URL_PS_NAME: endpointUrl.parameterName,
REST_API_VERSION: config.restApiConfig.apiVersion,
};
Expand Down
39 changes: 34 additions & 5 deletions lib/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -578,13 +578,15 @@ const AuthConfigSchema = z.object({
* @property {ContainerConfig} containerConfig - Configuration for the container.
* @property {AutoScalingConfigSchema} autoScalingConfig - Configuration for auto scaling settings.
* @property {LoadBalancerConfig} loadBalancerConfig - Configuration for load balancer settings.
* @property {boolean} [internetFacing=true] - Whether or not the REST API ALB will be configured as internet facing.
*/
const FastApiContainerConfigSchema = z.object({
apiVersion: z.literal('v2'),
instanceType: z.enum(VALID_INSTANCE_KEYS),
containerConfig: ContainerConfigSchema,
autoScalingConfig: AutoScalingConfigSchema,
loadBalancerConfig: LoadBalancerConfigSchema,
internetFacing: z.boolean().default(true),
});

/**
Expand Down Expand Up @@ -794,6 +796,9 @@ const LiteLLMConfig = z.object({
* @property {string} s3BucketModels - S3 bucket for models.
* @property {string} mountS3DebUrl - URL for S3-mounted Debian package.
* @property {string[]} [accountNumbersEcr=null] - List of AWS account numbers for ECR repositories.
* @property {boolean} [deployRag=false] - Whether to deploy RAG stacks.
* @property {boolean} [deployChat=true] - Whether to deploy chat stacks.
* @property {boolean} [deployUi=true] - Whether to deploy UI stacks.
* @property {string} logLevel - Log level for application.
* @property {AuthConfigSchema} authConfig - Authorization configuration.
* @property {FastApiContainerConfigSchema} restApiConfig - REST API configuration.
Expand Down Expand Up @@ -840,21 +845,23 @@ const RawConfigSchema = z
})
.optional(),
deployRag: z.boolean().optional().default(false),
deployChat: z.boolean().optional().default(true),
deployUi: z.boolean().optional().default(true),
logLevel: z.union([z.literal('DEBUG'), z.literal('INFO'), z.literal('WARNING'), z.literal('ERROR')]),
lambdaConfig: lambdaConfigSchema,
lambdaSourcePath: z.string().optional().default('./lambda'),
authConfig: AuthConfigSchema,
authConfig: AuthConfigSchema.optional(),
pypiConfig: PypiConfigSchema.optional().default({
indexUrl: '',
trustedHost: '',
}),
condaUrl: z.string().optional().default(''),
certificateAuthorityBundle: z.string().optional().default(''),
ragRepositories: z.array(RagRepositoryConfigSchema),
ragFileProcessingConfig: RagFileProcessingConfigSchema,
ragRepositories: z.array(RagRepositoryConfigSchema).default([]),
ragFileProcessingConfig: RagFileProcessingConfigSchema.optional(),
restApiConfig: FastApiContainerConfigSchema,
ecsModels: z.array(EcsModelConfigSchema),
apiGatewayConfig: ApiGatewayConfigSchema,
apiGatewayConfig: ApiGatewayConfigSchema.optional(),
nvmeHostMountPath: z.string().default('/nvme'),
nvmeContainerMountPath: z.string().default('/nvme'),
tags: z
Expand Down Expand Up @@ -895,7 +902,29 @@ const RawConfigSchema = z
})
.refine((config) => (config.pypiConfig.indexUrl && config.region.includes('iso')) || !config.region.includes('iso'), {
message: 'Must set PypiConfig if in an iso region',
});
})
.refine(
(config) => {
return !config.deployUi || config.deployChat;
},
{
message: 'Chat stack is needed for UI stack. You must set deployChat to true if deployUi is true.',
},
)
.refine(
(config) => {
return (
!(config.deployChat || config.deployRag || config.deployUi || config.restApiConfig.internetFacing) ||
config.authConfig
);
},
{
message:
'An auth config must be provided when deploying the chat, RAG, or UI stacks or when deploying an internet ' +
'facing ALB. Check that `deployChat`, `deployRag`, `deployUi`, and `restApiConfig.internetFacing` are all ' +
'false or that an `authConfig` is provided.',
},
);

/**
* Apply transformations to the raw application configuration schema.
Expand Down
25 changes: 14 additions & 11 deletions lib/serve/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,20 @@ export class LisaServeApplicationStack extends Stack {

const { config, vpc } = props;

// Create DynamoDB Table for enabling API token usage
const tokenTable = new Table(this, 'TokenTable', {
tableName: 'LISAApiTokenTable',
partitionKey: {
name: 'token',
type: AttributeType.STRING,
},
billingMode: BillingMode.PAY_PER_REQUEST,
encryption: TableEncryption.AWS_MANAGED,
removalPolicy: config.removalPolicy,
});
let tokenTable;
if (config.restApiConfig.internetFacing) {
// Create DynamoDB Table for enabling API token usage
tokenTable = new Table(this, 'TokenTable', {
tableName: `${config.deploymentName}-LISAApiTokenTable`,
partitionKey: {
name: 'token',
type: AttributeType.STRING,
},
billingMode: BillingMode.PAY_PER_REQUEST,
encryption: TableEncryption.AWS_MANAGED,
removalPolicy: config.removalPolicy,
});
}

// Create REST API
const restApi = new FastApiContainer(this, 'RestApi', {
Expand Down
23 changes: 14 additions & 9 deletions lib/serve/rest-api/src/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Model information routes."""

import logging
import os

from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
Expand All @@ -25,18 +26,22 @@

logger = logging.getLogger(__name__)

security = OIDCHTTPBearer()
router = APIRouter()

router.include_router(models.router, prefix="/v1", tags=["models"], dependencies=[Depends(security)], deprecated=True)
router.include_router(
embeddings.router, prefix="/v1", tags=["embeddings"], dependencies=[Depends(security)], deprecated=True
)
router.include_router(
generation.router, prefix="/v1", tags=["generation"], dependencies=[Depends(security)], deprecated=True
)
if os.getenv("USE_AUTH", "true").lower() == "false":
dependencies = []
logger.info("Auth disabled")
else:
security = OIDCHTTPBearer()
dependencies = [Depends(security)]
logger.info("Auth enabled")


router.include_router(models.router, prefix="/v1", tags=["models"], dependencies=dependencies, deprecated=True)
router.include_router(embeddings.router, prefix="/v1", tags=["embeddings"], dependencies=dependencies, deprecated=True)
router.include_router(generation.router, prefix="/v1", tags=["generation"], dependencies=dependencies, deprecated=True)
router.include_router(
litellm_passthrough.router, prefix="/v2/serve", tags=["litellm_passthrough"], dependencies=[Depends(security)]
litellm_passthrough.router, prefix="/v2/serve", tags=["litellm_passthrough"], dependencies=dependencies
)


Expand Down
Loading

0 comments on commit 445204c

Please sign in to comment.