Update FE to comply with OpenAI Spec #17

Merged: 14 commits, Jun 4, 2024 (diff shown from 12 commits)
README.md (1 addition, 1 deletion)

@@ -414,7 +414,7 @@ window.env = {
   CLIENT_ID: '<Your IdP Client Id Here>',
   // Alternatively you can set this to be your REST api elb endpoint
   RESTAPI_URI: 'http://localhost:8080/',
-  RESTAPI_VERSION: 'v1',
+  RESTAPI_VERSION: 'v2',
   SESSION_REST_API_URI: '<API GW session endpoint>'
 }
 ```
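A minimal sketch of how the frontend could compose its API base from these values; the `v2/serve` suffix mirrors the `lisa_api_endpoint + "/v2/serve"` base used in the lambda change below, and the `RestApiEnv` shape and `apiBaseFor` helper are hypothetical, not part of this PR:

```typescript
// Hypothetical helper; the window.env fields come from the README snippet above.
interface RestApiEnv {
  RESTAPI_URI: string; // e.g. 'http://localhost:8080/'
  RESTAPI_VERSION: string; // 'v2' after this change
}

// Joins the REST API root and version into the OpenAI-compatible base,
// e.g. 'http://localhost:8080/v2/serve' (path segment assumed from the lambda change).
function apiBaseFor(env: RestApiEnv): string {
  const root = env.RESTAPI_URI.endsWith('/') ? env.RESTAPI_URI : `${env.RESTAPI_URI}/`;
  return `${root}${env.RESTAPI_VERSION}/serve`;
}
```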
example_config.yaml (1 addition, 1 deletion)

@@ -53,7 +53,7 @@ dev:
   apiGatewayConfig:
     domainName:
   restApiConfig:
-    apiVersion: v1
+    apiVersion: v2
    instanceType: m5.large
    containerConfig:
      image:
lambda/repository/lambda_functions.py (8 additions, 10 deletions)

@@ -22,7 +22,7 @@
 import boto3
 import create_env_variables  # noqa: F401
 from botocore.config import Config
-from lisapy.langchain import Lisa, LisaEmbeddings
+from lisapy.langchain import LisaOpenAIEmbeddings
 from utilities.common_functions import api_wrapper, get_id_token, retry_config
 from utilities.file_processing import process_record
 from utilities.vector_store import get_vector_store_client

@@ -68,18 +68,18 @@ def _get_cert_path() -> str | bool:
     return rest_api_cert_path


-def _get_embeddings(provider: str, model_name: str, id_token: str) -> LisaEmbeddings:
+def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
     global lisa_api_endpoint

     if not lisa_api_endpoint:
         lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"])
         lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"]

-    lisa = Lisa(
-        url=lisa_api_endpoint, verify=_get_cert_path(), timeout=60, headers={"Authorization": f"Bearer {id_token}"}
-    )
+    headers = {"Authorization": f"Bearer {id_token}"}

-    embedding = LisaEmbeddings(provider=provider, model_name=model_name, client=lisa)
+    embedding = LisaOpenAIEmbeddings(
+        lisa_openai_api_base=lisa_api_endpoint + "/v2/serve", model=model_name, headers=headers
+    )
     return embedding

@@ -110,14 +110,13 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]:
     """
     query_string_params = event["queryStringParameters"]
     model_name = query_string_params["modelName"]
-    model_provider = query_string_params["modelProvider"]
     query = query_string_params["query"]
     repository_type = query_string_params["repositoryType"]
     top_k = query_string_params.get("topK", 3)

     id_token = get_id_token(event)

-    embeddings = _get_embeddings(provider=model_provider, model_name=model_name, id_token=id_token)
+    embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
     vs = get_vector_store_client(repository_type, index=model_name, embeddings=embeddings)
     docs = vs.similarity_search(
         query,

@@ -151,7 +150,6 @@ def ingest_documents(event: dict, context: dict) -> dict:
     body = json.loads(event["body"])
     embedding_model = body["embeddingModel"]
     model_name = embedding_model["modelName"]
-    model_provider = embedding_model["provider"]

     query_string_params = event["queryStringParameters"]
     repository_type = query_string_params["repositoryType"]

@@ -169,7 +167,7 @@
         metadatas.append(doc.metadata)

     id_token = get_id_token(event)
-    embeddings = _get_embeddings(provider=model_provider, model_name=model_name, id_token=id_token)
+    embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
     vs = get_vector_store_client(repository_type, index=model_name, embeddings=embeddings)
     ids = vs.add_texts(texts=texts, metadatas=metadatas)
     return {"ids": ids, "count": len(ids)}
lib/rag/layer/requirements.txt (1 addition, 0 deletions)

@@ -2,6 +2,7 @@ opensearch-py
 requests-aws4auth
 PyPDF2==3.0.1
 langchain==0.1.0
+langchain-openai==0.0.6
 pgvector==0.2.5
 psycopg2-binary==2.9.9
 python-docx==1.1.0

Review comment from a contributor on lines 4 to +5: Just a heads up that we may want to look into updating these, since they're a bit old and the OpenAIEmbeddings client is on a deprecation path. Not a blocker for now (especially since I'm the one that committed that), but I will possibly do a pass before official release to ensure we're not too old on requirements versions.
lib/schema.ts (1 addition, 1 deletion)

@@ -580,7 +580,7 @@ const AuthConfigSchema = z.object({
  * @property {LoadBalancerConfig} loadBalancerConfig - Configuration for load balancer settings.
  */
 const FastApiContainerConfigSchema = z.object({
-  apiVersion: z.string(),
+  apiVersion: z.literal('v2'),
   instanceType: z.enum(VALID_INSTANCE_KEYS),
   containerConfig: ContainerConfigSchema,
   autoScalingConfig: AutoScalingConfigSchema,
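Tightening `apiVersion` from `z.string()` to `z.literal('v2')` turns a stale config into a deploy-time validation error instead of a silent runtime mismatch. A small illustration of the behavior, using a reduced stand-in schema rather than the full `FastApiContainerConfigSchema`:

```typescript
import { z } from 'zod';

// Reduced stand-in for FastApiContainerConfigSchema; only apiVersion matters here.
const ApiVersionSchema = z.object({ apiVersion: z.literal('v2') });

console.log(ApiVersionSchema.safeParse({ apiVersion: 'v2' }).success); // true
console.log(ApiVersionSchema.safeParse({ apiVersion: 'v1' }).success); // false: stale v1 configs are rejected
```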
lib/user-interface/react/package-lock.json (3 additions, 3 deletions; generated file, diff not rendered)
lib/user-interface/react/package.json (1 addition, 1 deletion)

@@ -11,9 +11,9 @@
     "format": "prettier --ignore-path .gitignore --write \"**/*.+(tsx|js|ts|json)\""
   },
   "dependencies": {
-    "@langchain/core": "^0.1.22",
     "@microsoft/fetch-event-source": "^2.0.1",
     "langchain": "^0.1.12",
+    "@langchain/core": "^0.1.22",
     "luxon": "^3.4.0",
     "react": "^18.2.0",
     "react-dom": "^18.2.0",
lib/user-interface/react/src/components/adapters/lisa.ts (1 addition, 186 deletions)

@@ -14,187 +14,9 @@
  limitations under the License.
  */

-import { GenerateRequestBody, GenerateResponseBody, GenerateStreamResponseBody, ModelKwargs } from '../types';
 import { sendAuthenticatedRequest } from '../utils';
-import { fetchEventSource } from '@microsoft/fetch-event-source';
 import { Document } from '@langchain/core/documents';
 import { BaseRetriever, BaseRetrieverInput } from '@langchain/core/retrievers';
-import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
-import { BaseLLMCallOptions, BaseLLMParams, LLM } from '@langchain/core/language_models/llms';

-// Custom for whatever model you'll be using
-export class LisaContentHandler {
-  contentType = 'application/json';
-  accepts = 'application/json';
-
-  async transformInput(
-    prompt: string,
-    modelKwargs: ModelKwargs,
-    modelName: string,
-    providerName: string,
-  ): Promise<string> {
-    const payload: GenerateRequestBody = {
-      provider: providerName,
-      modelName: modelName,
-      modelKwargs: modelKwargs,
-      text: prompt,
-    };
-
-    return JSON.stringify(payload);
-  }
-
-  async transformOutput(output: GenerateResponseBody, modelKwargs: ModelKwargs): Promise<string> {
-    if (output.generatedText) {
-      if (output.finishReason === 'stop_sequence') {
-        for (const suffix of modelKwargs.stop_sequences) {
-          if (output.generatedText.endsWith(suffix)) {
-            // Remove the suffix and break the loop
-            output.generatedText = output.generatedText.substring(0, output.generatedText.length - suffix.length);
-            break;
-          }
-        }
-      }
-      return output.generatedText;
-    } else {
-      return '';
-    }
-  }
-}
-
-export interface LisaInput extends BaseLLMParams {
-  /**
-   * The URI of the LISA inference engine
-   */
-  uri: string;
-
-  /**
-   * Key word arguments to pass to the model.
-   */
-  modelKwargs: ModelKwargs;
-
-  /**
-   * Name of model to call generate on
-   */
-  modelName: string;
-
-  /**
-   * Name of model provider
-   */
-  providerName: string;
-
-  /**
-   * The content handler class that provides an input and output transform
-   * functions to handle formats between LLM and the endpoint.
-   */
-  contentHandler: LisaContentHandler;
-  streaming?: boolean;
-  idToken: string;
-}
-
-/**
- * Class for interacting with LISA Serve REST API
- */
-export class Lisa extends LLM<BaseLLMCallOptions> {
-  static lc_name() {
-    return 'LISA';
-  }
-
-  public modelKwargs: ModelKwargs;
-  public modelName: string;
-  public providerName: string;
-  public streaming: boolean;
-  public idToken: string;
-  private uri: string;
-  private contentHandler: LisaContentHandler;
-
-  constructor(fields: LisaInput) {
-    super(fields);
-
-    const contentHandler = fields?.contentHandler;
-    if (!contentHandler) {
-      throw new Error(`Please pass a "contentHandler" field to the constructor`);
-    }
-
-    this.uri = fields.uri;
-    this.contentHandler = fields.contentHandler;
-    this.modelName = fields.modelName;
-    this.providerName = fields.providerName;
-    this.modelKwargs = fields.modelKwargs;
-    this.streaming = fields.streaming ?? false;
-    this.idToken = fields.idToken;
-  }
-
-  _llmType() {
-    return 'lisa';
-  }
-
-  /**
-   * Calls the LISA endpoint and retrieves the result.
-   * @param {string} prompt The input prompt.
-   * @param {this["ParsedCallOptions"]} options Parsed call options.
-   * @param {CallbackManagerForLLMRun} runManager Optional run manager.
-   * @returns {Promise<string>} A promise that resolves to the generated string.
-   */
-  /** @ignore */
-  async _call(
-    prompt: string,
-    options: this['ParsedCallOptions'],
-    runManager?: CallbackManagerForLLMRun,
-  ): Promise<string> {
-    return this.streaming
-      ? await this.streamingCall(prompt, options, runManager)
-      : await this.noStreamingCall(prompt, options);
-  }
-
-  private async streamingCall(
-    prompt: string,
-    options: this['ParsedCallOptions'],
-    runManager?: CallbackManagerForLLMRun,
-  ): Promise<string> {
-    const body = await this.contentHandler.transformInput(prompt, this.modelKwargs, this.modelName, this.providerName);
-    const { contentType, accepts } = this.contentHandler;
-    const tokens: string[] = [];
-    await fetchEventSource(`${this.uri}/generateStream`, {
-      method: 'POST',
-      headers: {
-        'Content-Type': contentType,
-        Accept: accepts,
-        Authorization: `Bearer ${this.idToken}`,
-      },
-      body: body,
-      async onopen(res) {
-        if (res.status >= 400 && res.status < 500 && res.status !== 429) {
-          throw res;
-        }
-      },
-      async onmessage(event) {
-        const parsedData = JSON.parse(event.data) as GenerateStreamResponseBody;
-        if (!parsedData.token.special && parsedData.finishReason != 'stop_sequence') {
-          tokens.push(parsedData.token.text);
-          await runManager?.handleLLMNewToken(parsedData.token.text);
-        }
-      },
-      onerror(err) {
-        throw err;
-      },
-    });
-    return tokens.join('');
-  }
-
-  private async noStreamingCall(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
-    void options;
-    const body = await this.contentHandler.transformInput(prompt, this.modelKwargs, this.modelName, this.providerName);
-    const { contentType, accepts } = this.contentHandler;
-
-    const response = await this.caller.call(() =>
-      sendAuthenticatedRequest(`${this.uri}/generate`, 'POST', this.idToken, body, {
-        Accept: accepts,
-        'Content-Type': contentType,
-      }),
-    );
-    return this.contentHandler.transformOutput(await response.json(), this.modelKwargs);
-  }
-}
-
 export interface LisaRAGRetrieverInput extends BaseRetrieverInput {
   /**

@@ -207,11 +29,6 @@ export interface LisaRAGRetrieverInput extends BaseRetrieverInput {
    */
   modelName: string;

-  /**
-   * Name of model provider
-   */
-  providerName: string;
-
   /**
    * Authentication token to use when communicating with RAG API
    */

@@ -238,7 +55,6 @@ export class LisaRAGRetriever extends BaseRetriever {

   private uri: string;
   public modelName: string;
-  public providerName: string;
   public idToken: string;
   public repositoryId: string;
   public repositoryType: string;

@@ -249,7 +65,6 @@

     this.uri = fields.uri;
     this.modelName = fields.modelName;
-    this.providerName = fields.providerName;
     this.idToken = fields.idToken;
     this.repositoryId = fields.repositoryId;
     this.repositoryType = fields.repositoryType;

@@ -258,7 +73,7 @@

   async _getRelevantDocuments(query: string): Promise<Document[]> {
     const resp = await sendAuthenticatedRequest(
-      `repository/${this.repositoryId}/similaritySearch?query=${query}&modelName=${this.modelName}&modelProvider=${this.providerName}&repositoryType=${this.repositoryType}&topK=${this.topK}`,
+      `repository/${this.repositoryId}/similaritySearch?query=${query}&modelName=${this.modelName}&repositoryType=${this.repositoryType}&topK=${this.topK}`,
       'GET',
       this.idToken,
     );
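With the provider field gone, constructing the retriever needs one less input. A usage sketch against the updated class; all values are placeholders, and `topK` is assumed to be optional since the class reads `this.topK` but its declaration is outside this diff:

```typescript
// Usage sketch; every value below is a placeholder, not a project default.
const retriever = new LisaRAGRetriever({
  uri: 'https://example.com/rag/', // RAG API base (placeholder)
  modelName: 'my-embedding-model', // hypothetical embedding model name
  idToken: '<oidc-id-token>', // bearer token from the auth flow
  repositoryId: 'my-repo', // hypothetical repository id
  repositoryType: 'opensearch', // hypothetical repository type
});

// The request now omits modelProvider; results are keyed on modelName alone.
const docs = await retriever.getRelevantDocuments('What is LISA?');
console.log(docs.map((d) => d.pageContent));
```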