Skip to content

Commit

Permalink
Moving ALB behind APIGW to allow service code to call apis without auth
Browse files Browse the repository at this point in the history
Co-authored-by: Dustin Sweigart <[email protected]>
  • Loading branch information
estohlmann and dustins authored Sep 5, 2024
1 parent 283dd8f commit ba8569e
Show file tree
Hide file tree
Showing 11 changed files with 251 additions and 234 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
hooks:
- id: codespell
entry: codespell
args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*', "-L=xdescribe"]
args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*,*models/*', "-L=xdescribe"]
pass_filenames: false

- repo: https://github.com/pycqa/isort
Expand Down
61 changes: 47 additions & 14 deletions lib/api-base/ecsCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
*/

// ECS Cluster Construct.
import { Duration, RemovalPolicy } from 'aws-cdk-lib';
import { BlockDeviceVolume, GroupMetrics, Monitoring } from 'aws-cdk-lib/aws-autoscaling';
import { Metric, Stats } from 'aws-cdk-lib/aws-cloudwatch';
import { InstanceType, SecurityGroup, IVpc } from 'aws-cdk-lib/aws-ec2';
import { InstanceType, IVpc, SecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { Repository } from 'aws-cdk-lib/aws-ecr';
import {
AmiHardwareType,
Expand All @@ -37,14 +37,20 @@ import {
Protocol,
Volume,
} from 'aws-cdk-lib/aws-ecs';
import { ApplicationLoadBalancer, BaseApplicationListenerProps } from 'aws-cdk-lib/aws-elasticloadbalancingv2';
import {
ApplicationLoadBalancer,
ApplicationProtocol,
BaseApplicationListenerProps,
NetworkLoadBalancer,
NetworkTargetGroup
} from 'aws-cdk-lib/aws-elasticloadbalancingv2';
import { IRole, ManagedPolicy, Role } from 'aws-cdk-lib/aws-iam';
import { StringParameter } from 'aws-cdk-lib/aws-ssm';
import { Construct } from 'constructs';

import { createCdkId } from '../core/utils';
import { BaseProps, Ec2Metadata, EcsSourceType } from '../schema';
import { ECSConfig } from '../schema';
import { BaseProps, Ec2Metadata, ECSConfig, EcsSourceType } from '../schema';
import { AlbTarget } from 'aws-cdk-lib/aws-elasticloadbalancingv2-targets';

/**
* Properties for the ECSCluster Construct.
Expand All @@ -57,6 +63,7 @@ type ECSClusterProps = {
ecsConfig: ECSConfig;
securityGroup: SecurityGroup;
vpc: IVpc;
addNlb?: boolean;
} & BaseProps;

/**
Expand All @@ -72,6 +79,10 @@ export class ECSCluster extends Construct {
/** Endpoint URL of application load balancer for the cluster. */
public readonly endpointUrl: string;

public readonly alb: ApplicationLoadBalancer;

public readonly nlb: NetworkLoadBalancer;

/**
* @param {Construct} scope - The parent or owner of the construct.
* @param {string} id - The unique identifier for the construct within its scope.
Expand Down Expand Up @@ -259,25 +270,47 @@ export class ECSCluster extends Construct {
service.node.addDependency(autoScalingGroup);

// Create application load balancer
const loadBalancer = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), {
this.alb = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), {
deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY,
internetFacing: ecsConfig.internetFacing,
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier, 'ALB'], 32, 2),
dropInvalidHeaderFields: true,
securityGroup,
vpc,
});

if (props.addNlb) {
this.nlb = new NetworkLoadBalancer(this, createCdkId([ecsConfig.identifier, 'NLB']), {
deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY,
crossZoneEnabled: true,
internetFacing: ecsConfig.internetFacing,
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier, 'NLB'], 32, 2),
securityGroups: [securityGroup],
vpc,
});

const nlbListener = this.nlb.addListener('Listener', { port: 80 });

const albTargetGroup = new NetworkTargetGroup(this, 'ALB-Target-Group', {
port: 80,
vpc: vpc,
targets: [new AlbTarget(this.alb, 80)],
healthCheck: {
path: '/health'
}
});

nlbListener.addTargetGroups('ALB-Target-Group', albTargetGroup);
}

// Add listener
const listenerProps: BaseApplicationListenerProps = {
port: ecsConfig.loadBalancerConfig.sslCertIamArn ? 443 : 80,
port: 80,
open: ecsConfig.internetFacing,
certificates: ecsConfig.loadBalancerConfig.sslCertIamArn
? [{ certificateArn: ecsConfig.loadBalancerConfig.sslCertIamArn }]
: undefined,
protocol: ApplicationProtocol.HTTP
};

const listener = loadBalancer.addListener(
const listener = this.alb.addListener(
createCdkId([ecsConfig.identifier, 'ApplicationListener']),
listenerProps,
);
Expand Down Expand Up @@ -305,7 +338,7 @@ export class ECSCluster extends Construct {
namespace: 'AWS/ApplicationELB',
dimensionsMap: {
TargetGroup: targetGroup.targetGroupFullName,
LoadBalancer: loadBalancer.loadBalancerFullName,
LoadBalancer: this.alb.loadBalancerFullName,
},
statistic: Stats.SAMPLE_COUNT,
period: Duration.seconds(ecsConfig.autoScalingConfig.metricConfig.duration),
Expand All @@ -321,7 +354,7 @@ export class ECSCluster extends Construct {
const domain =
ecsConfig.loadBalancerConfig.domainName !== null
? ecsConfig.loadBalancerConfig.domainName
: loadBalancer.loadBalancerDnsName;
: this.alb.loadBalancerDnsName;
const endpoint = `${protocol}://${domain}`;
this.endpointUrl = endpoint;

Expand Down
67 changes: 56 additions & 11 deletions lib/api-base/fastApiContainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ import { dump as yamlDump } from 'js-yaml';

import { ECSCluster } from './ecsCluster';
import { BaseProps, Ec2Metadata, EcsSourceType, FastApiContainerConfig } from '../schema';
import {
ConnectionType,
Cors,
IAuthorizer,
Integration,
IntegrationType,
RestApi,
VpcLink
} from 'aws-cdk-lib/aws-apigateway';

// This is the amount of memory to buffer (or subtract off) from the total instance memory, if we don't include this,
// the container can have a hard time finding available RAM resources to start and the tasks will fail deployment
Expand All @@ -36,6 +45,9 @@ const CONTAINER_MEMORY_BUFFER = 1024 * 2;
* @property {SecurityGroup} securityGroups - The security groups of the application.
*/
type FastApiContainerProps = {
authorizer: IAuthorizer;
restApiId: string;
rootResourceId: string;
apiName: string;
resourcePath: string;
securityGroup: SecurityGroup;
Expand Down Expand Up @@ -82,18 +94,9 @@ export class FastApiContainer extends Construct {
AWS_REGION_NAME: config.region, // for supporting SageMaker endpoints in LiteLLM
THREADS: Ec2Metadata.get(taskConfig.instanceType).vCpus.toString(),
LITELLM_KEY: config.litellmConfig.general_settings.master_key,
USE_AUTH: 'false',
};

if (config.restApiConfig.internetFacing) {
environment.USE_AUTH = 'true';
environment.AUTHORITY = config.authConfig!.authority;
environment.CLIENT_ID = config.authConfig!.clientId;
environment.ADMIN_GROUP = config.authConfig!.adminGroup;
environment.JWT_GROUPS_PROP = config.authConfig!.jwtGroupsProperty;
} else {
environment.USE_AUTH = 'false';
}

if (tokenTable) {
environment.TOKEN_TABLE_NAME = tokenTable.tableName;
}
Expand All @@ -109,16 +112,58 @@ export class FastApiContainer extends Construct {
environment,
identifier: props.apiName,
instanceType: taskConfig.instanceType,
internetFacing: config.restApiConfig.internetFacing,
internetFacing: false,
loadBalancerConfig: taskConfig.loadBalancerConfig,
},
securityGroup,
vpc,
addNlb: true
});

const nlbVpcLink = new VpcLink(this, 'nlb-vpc-link', {
targets: [apiCluster.nlb]
});

// get the rest api
const restApi = RestApi.fromRestApiAttributes(this, 'RestApi', {
restApiId: props.restApiId,
rootResourceId: props.rootResourceId,
});

const integration = new Integration({
type: IntegrationType.HTTP_PROXY,
integrationHttpMethod: 'ANY',
options: {
connectionType: ConnectionType.VPC_LINK,
vpcLink: nlbVpcLink,
requestParameters: {
'integration.request.path.proxy': 'method.request.path.proxy'
},
},
uri: `${apiCluster.endpointUrl}/{proxy}`,
});

// create the proxy
const resource = restApi.root.addResource('llm').addProxy({
defaultIntegration: integration,
anyMethod: true,
defaultMethodOptions: {
authorizer: props.authorizer,
requestParameters: {
'method.request.path.proxy': true
}
}
});

resource.addCorsPreflight({
allowOrigins: Cors.ALL_ORIGINS,
allowHeaders: ['*'],
});

if (tokenTable) {
tokenTable.grantReadData(apiCluster.taskRole);
}

this.endpoint = apiCluster.endpointUrl;

// Update
Expand Down
8 changes: 1 addition & 7 deletions lib/networking/vpc/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,7 @@ export class Vpc extends Construct {
// All HTTP VPC traffic -> ECS model ALB
ecsModelAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');

if (config.restApiConfig.loadBalancerConfig.sslCertIamArn) {
// All HTTPS IPV4 traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.anyIpv4(), Port.tcp(443), 'Allow any traffic on port 443');
} else {
// All HTTP VPC traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');
}
restApiAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');

// Update
this.vpc = vpc;
Expand Down
7 changes: 7 additions & 0 deletions lib/serve/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ import { FastApiContainer } from '../api-base/fastApiContainer';
import { createCdkId } from '../core/utils';
import { Vpc } from '../networking/vpc';
import { BaseProps } from '../schema';
import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway';

const HERE = path.resolve(__dirname);

type CustomLisaStackProps = {
vpc: Vpc;
authorizer: IAuthorizer;
restApiId: string;
rootResourceId: string;
} & BaseProps;
type LisaStackProps = CustomLisaStackProps & StackProps;

Expand Down Expand Up @@ -73,6 +77,9 @@ export class LisaServeApplicationStack extends Stack {

// Create REST API
const restApi = new FastApiContainer(this, 'RestApi', {
authorizer: props.authorizer,
restApiId: props.restApiId,
rootResourceId: props.rootResourceId,
apiName: 'REST',
config: config,
resourcePath: path.join(HERE, 'rest-api'),
Expand Down
22 changes: 13 additions & 9 deletions lib/stages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,28 @@ export class LisaServeApplicationStage extends Stage {
});
stacks.push(coreStack);

const apiBaseStack = new LisaApiBaseStack(this, 'LisaApiBase', {
...baseStackProps,
stackName: createCdkId([config.deploymentName, config.appName, 'API']),
description: `LISA-API: ${config.deploymentName}-${config.deploymentStage}`,
vpc: networkingStack.vpc.vpc,
});
apiBaseStack.addDependency(coreStack);
stacks.push(apiBaseStack);

const serveStack = new LisaServeApplicationStack(this, 'LisaServe', {
...baseStackProps,
authorizer: apiBaseStack.authorizer,
restApiId: apiBaseStack.restApiId,
rootResourceId: apiBaseStack.rootResourceId,
description: `LISA-serve: ${config.deploymentName}-${config.deploymentStage}`,
stackName: createCdkId([config.deploymentName, config.appName, 'serve', config.deploymentStage]),
vpc: networkingStack.vpc,
});
stacks.push(serveStack);

serveStack.addDependency(iamStack);

const apiBaseStack = new LisaApiBaseStack(this, 'LisaApiBase', {
...baseStackProps,
stackName: createCdkId([config.deploymentName, config.appName, 'API']),
description: `LISA-API: ${config.deploymentName}-${config.deploymentStage}`,
vpc: networkingStack.vpc.vpc,
});
apiBaseStack.addDependency(coreStack);
stacks.push(apiBaseStack);
serveStack.addDependency(apiBaseStack);

const apiDeploymentStack = new LisaApiDeploymentStack(this, 'LisaApiDeployment', {
...baseStackProps,
Expand Down
8 changes: 1 addition & 7 deletions lib/user-interface/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ import { ManagedPolicy, Role, ServicePrincipal } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { BlockPublicAccess, Bucket, BucketEncryption } from 'aws-cdk-lib/aws-s3';
import { BucketDeployment, Source } from 'aws-cdk-lib/aws-s3-deployment';
import { StringParameter } from 'aws-cdk-lib/aws-ssm';
import { Construct } from 'constructs';

import { createCdkId } from '../core/utils';
import { BaseProps } from '../schema';

/**
Expand Down Expand Up @@ -187,11 +185,7 @@ export class UserInterfaceStack extends Stack {
ADMIN_GROUP: config.authConfig!.adminGroup,
JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty,
CUSTOM_SCOPES: config.authConfig!.additionalScopes,
RESTAPI_URI: StringParameter.fromStringParameterName(
this,
createCdkId(['LisaRestApiUri', 'StringParameter']),
`${config.deploymentPrefix}/lisaServeRestApiUri`,
).stringValue,
RESTAPI_ID: config.apiGatewayConfig?.domainName ? '/llm' : `/${config.deploymentStage}/llm`,
RESTAPI_VERSION: config.restApiConfig.apiVersion,
RAG_ENABLED: config.deployRag,
SYSTEM_BANNER: {
Expand Down
Loading

0 comments on commit ba8569e

Please sign in to comment.