diff --git a/lib/user-interface/index.ts b/lib/user-interface/index.ts index b7d191c5..896230b0 100644 --- a/lib/user-interface/index.ts +++ b/lib/user-interface/index.ts @@ -161,6 +161,14 @@ export class UserInterfaceStack extends Stack { }, ); + const modelsList = config.ecsModels.map((modelConfig) => { + return { + model: modelConfig.modelId, + streaming: modelConfig.streaming, + modelType: modelConfig.modelType, + }; + }); + // Website bucket deployment // Copy auth and LISA-Serve info to UI deployment bucket const appEnvConfig = { @@ -179,6 +187,7 @@ export class UserInterfaceStack extends Stack { fontColor: config.systemBanner?.fontColor, }, API_BASE_URL: config.apiGatewayConfig?.domainName ? '/' : `/${config.deploymentStage}/`, + MODELS: modelsList, }; const appEnvSource = Source.data('env.js', `window.env = ${JSON.stringify(appEnvConfig)}`); diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index dfbe89bb..33037fd0 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -136,7 +136,10 @@ export default function Chat({ sessionId }) { useEffect(() => { if (selectedModelOption) { const model = models.filter((model) => model.id === selectedModelOption.value)[0]; - setModelCanStream(true); + if (!model.streaming && streamingEnabled) { + setStreamingEnabled(false); + } + setModelCanStream(model.streaming); setSelectedModel(model); } }, [selectedModelOption, streamingEnabled]); @@ -463,8 +466,7 @@ export default function Chat({ sessionId }) { const describeTextGenModels = useCallback(async () => { setIsLoadingModels(true); - const resp = await describeModels(auth.user?.id_token); - setModels(resp.data); + setModels(describeModels('textgen')); setIsLoadingModels(false); // eslint-disable-next-line react-hooks/exhaustive-deps }, []); diff --git a/lib/user-interface/react/src/components/chatbot/RagOptions.tsx b/lib/user-interface/react/src/components/chatbot/RagOptions.tsx index 074693fb..0a72bcf3 100644 --- a/lib/user-interface/react/src/components/chatbot/RagOptions.tsx +++ b/lib/user-interface/react/src/components/chatbot/RagOptions.tsx @@ -46,11 +46,8 @@ export default function RagControls({ auth, isRunning, setUseRag, setRagConfig } useEffect(() => { setIsLoadingEmbeddingModels(true); setIsLoadingRepositories(true); - - describeModels(auth.user?.id_token).then((resp) => { - setEmbeddingModels(resp.data); - setIsLoadingEmbeddingModels(false); - }); + setEmbeddingModels(describeModels('embedding')); + setIsLoadingEmbeddingModels(false); listRagRepositories(auth.user?.id_token).then((repositories) => { setRepositoryOptions( diff --git a/lib/user-interface/react/src/components/types.tsx b/lib/user-interface/react/src/components/types.tsx index ab91b866..1f7fe031 100644 --- a/lib/user-interface/react/src/components/types.tsx +++ b/lib/user-interface/react/src/components/types.tsx @@ -109,19 +109,11 @@ export interface Repository { /** * Interface for model */ -export interface Model { +export type Model = { id: string; - object: string; - created: number; - owned_by: string; -} - -/** - * Interface for the response body received when describing a model - */ -export interface DescribeModelsResponseBody { - data: Model[]; -} + modelType: ModelTypes; + streaming?: boolean; +}; /** * Interface for creating a session request body; composed of LisaChatMessageFields diff --git a/lib/user-interface/react/src/components/utils.ts b/lib/user-interface/react/src/components/utils.ts index 078f159b..7b2fe680 100644 --- a/lib/user-interface/react/src/components/utils.ts +++ b/lib/user-interface/react/src/components/utils.ts @@ -16,11 +16,11 @@ import { LisaChatSession, - DescribeModelsResponseBody, LisaChatMessageFields, PutSessionRequestBody, LisaChatMessage, Repository, + ModelTypes, Model, } from './types'; @@ -167,12 +167,15 @@ export const deleteUserSessions = async (idToken: string) => { /** * Describes all models of a given type which are available to a user - * @param idToken the user's ID token from authenticating + * @param modelType model type we are requesting * @returns */ -export const describeModels = async (idToken: string): Promise => { - const resp = await sendAuthenticatedRequest(`${RESTAPI_URI}/${RESTAPI_VERSION}/serve/models`, 'GET', idToken); - return await resp.json(); +export const describeModels = (modelType: ModelTypes): Model[] => { + return window.env.MODELS?.filter((m) => m.modelType === modelType).map((m) => ({ + id: m.model, + streaming: m.streaming, + modelType: m.modelType, + })); }; /** diff --git a/lib/user-interface/react/src/main.tsx b/lib/user-interface/react/src/main.tsx index 6dbbcfa9..019e8e56 100644 --- a/lib/user-interface/react/src/main.tsx +++ b/lib/user-interface/react/src/main.tsx @@ -20,6 +20,7 @@ import './index.css'; import AppConfigured from './components/app-configured'; import '@cloudscape-design/global-styles/index.css'; +import { ModelTypes } from './components/types'; declare global { // eslint-disable-next-line @typescript-eslint/consistent-type-definitions @@ -36,6 +37,13 @@ declare global { backgroundColor: string; fontColor: string; }; + MODELS: [ + { + model: string; + streaming: boolean | null; + modelType: ModelTypes; + }, + ]; }; } }