Skip to content

Commit

Permalink
Update create model validation; Ensure Base image is set for LISA hos…
Browse files Browse the repository at this point in the history
…ted models;
  • Loading branch information
estohlmann authored Dec 5, 2024
1 parent 241631b commit 23fb773
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ repos:
args:
- --max-line-length=120
- --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends
- --ignore=B008,E203 # Ignore error for function calls in argument defaults
- --ignore=B008,E203, W503 # Ignore error for function calls in argument defaults
exclude: ^(__init__.py$|.*\/__init__.py$)


Expand Down
36 changes: 36 additions & 0 deletions lambda/models/handler/create_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse:
if table_item:
raise ModelAlreadyExistsError(f"Model '{model_id}' already exists. Please select another name.")

self.validate(create_request)

self._stepfunctions.start_execution(
stateMachineArn=os.environ["CREATE_SFN_ARN"], input=create_request.model_dump_json()
)
Expand All @@ -46,3 +48,37 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse:
}
)
return CreateModelResponse(model=lisa_model)

@staticmethod
def validate(create_request: CreateModelRequest) -> None:
# The below check ensures that the model is LISA hosted
if (
create_request.containerConfig is not None
and create_request.autoScalingConfig is not None
and create_request.loadBalancerConfig is not None
):
if create_request.containerConfig.image.baseImage is None:
raise ValueError("Base image must be provided for LISA hosted model.")

# Validate values relative to current ASG. All conflicting request values have been validated as part of the
# AutoScalingInstanceConfig model validations, so those are not duplicated here.
if create_request.autoScalingConfig is not None:
# Min capacity can't be greater than the deployed ASG's max capacity
if (
create_request.autoScalingConfig.minCapacity is not None
and create_request.autoScalingConfig.maxCapacity is not None
and create_request.autoScalingConfig.minCapacity > create_request.autoScalingConfig.maxCapacity
):
raise ValueError(
f"Min capacity cannot exceed ASG max of {create_request.autoScalingConfig.maxCapacity}."
)

# Max capacity can't be less than the deployed ASG's min capacity
if (
create_request.autoScalingConfig.maxCapacity is not None
and create_request.autoScalingConfig.minCapacity is not None
and create_request.autoScalingConfig.maxCapacity < create_request.autoScalingConfig.minCapacity
):
raise ValueError(
f"Max capacity cannot be less than ASG min of {create_request.autoScalingConfig.minCapacity}."
)
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
}
}

const requiredFields = [['modelId', 'modelName'], [], [], [], []];
const requiredFields = [['modelId', 'modelName'], ['containerConfig.image.baseImage'], [], [], []];

useEffect(() => {
const parsedValue = _.mergeWith({}, initialForm, props.selectedItems[0], (a: IModelRequest, b: IModelRequest) => b === null ? a : undefined);
Expand Down Expand Up @@ -318,8 +318,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
case 'next':
case 'skip':
{
touchFields(requiredFields[state.activeStepIndex]);
if (isValid) {
if (touchFields(requiredFields[state.activeStepIndex]) && isValid) {
setState({
...state,
activeStepIndex: event.detail.requestedStepIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,16 @@ export const ModelRequestSchema = z.object({
});
}
}

const baseImageValidator = z.string().min(1, {message: 'Required for LISA hosted models.'});
const baseImageResult = baseImageValidator.safeParse(value.containerConfig.image.baseImage);
if (baseImageResult.success === false) {
for (const error of baseImageResult.error.errors) {
context.addIssue({
...error,
path: ['containerConfig', 'image', 'baseImage']
});
}
}
}
});
13 changes: 10 additions & 3 deletions lib/user-interface/react/src/shared/validation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ export type SetFieldsFunction = (
) => void;


export type TouchFieldsFunction = (fields: string[], method?: ValidationTouchActionMethod) => void;
export type TouchFieldsFunction = (fields: string[], method?: ValidationTouchActionMethod) => boolean;


/**
Expand Down Expand Up @@ -268,7 +268,7 @@ export const useValidationReducer = <F, S extends ValidationReducerBaseState<F>>
return {
state,
errors,
isValid: parseResult.success,
isValid: Object.keys(errors).length === 0,
setState: (newState: Partial<S>, method: ValidationStateActionMethod = ModifyMethod.Default) => {
setState({
type: ValidationReducerActionTypes.STATE,
Expand All @@ -289,12 +289,19 @@ export const useValidationReducer = <F, S extends ValidationReducerBaseState<F>>
touchFields: (
fields: string[],
method: ValidationTouchActionMethod = ModifyMethod.Default
) => {
): boolean => {
setState({
type: ValidationReducerActionTypes.TOUCH,
method,
fields,
} as ValidationTouchAction);
const parseResult = formSchema.safeParse({...state.form, ...{touched: fields}});
if (!parseResult.success) {
errors = issuesToErrors(parseResult.error.issues, fields.reduce((acc, key) => {
acc[key] = true; return acc;
}, {}));
}
return Object.keys(errors).length === 0;
},
};
};
Expand Down

0 comments on commit 23fb773

Please sign in to comment.