Skip to content

Commit

Permalink
clean up functions and pub sub
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobprall committed Dec 25, 2024
1 parent b394d62 commit 4f3ba4f
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 114 deletions.
59 changes: 49 additions & 10 deletions src/packages/functions/FunctionsClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@ import { FUNCTIONS_ROOT_PATH } from '../constants'
import { getAPIUrl } from '../utils'
import { Fetch, resolveFetch } from '../utils/fetch'

/**
* FunctionInvokeOptions
* @param args - The arguments to pass to the function.
* @param headers - The headers to pass to the function.
*/
interface FunctionInvokeOptions {
args: any[]
headers?: Record<string, string>
}

/**
* FunctionsClient
* @param invoke - Invoke a function.
* @param setAuth - Set the authentication token.
*/
export class FunctionsClient {
protected url: string
protected fetch: Fetch
Expand All @@ -20,22 +35,51 @@ export class FunctionsClient {
this.fetch = resolveFetch(options.customFetch)
this.headers = options.headers ? { ...DEFAULT_HEADERS, ...options.headers } : { ...DEFAULT_HEADERS }
}
// auth token is the full connection string with apikey
// TODO: check authorization and api key setup in Gateway
setAuth(token: string) {
this.headers.Authorization = `Bearer ${token}`
}

async invoke(functionId: string, args: any[]) {
// add argument handling
async invoke(functionId: string, options: FunctionInvokeOptions) {
const { headers, args } = options
let body;
let _headers: Record<string, string> = {}
if (args &&
((headers && !Object.prototype.hasOwnProperty.call(headers, 'Content-Type')) || !headers)
) {
if (
(typeof Blob !== 'undefined' && args instanceof Blob) ||
args instanceof ArrayBuffer
) {
// will work for File as File inherits Blob
// also works for ArrayBuffer as it is the same underlying structure as a Blob
_headers['Content-Type'] = 'application/octet-stream'
body = args
} else if (typeof args === 'string') {
// plain string
_headers['Content-Type'] = 'text/plain'
body = args
} else if (typeof FormData !== 'undefined' && args instanceof FormData) {
_headers['Content-Type'] = 'multipart/form-data'
body = args
} else {
// default, assume this is JSON
_headers['Content-Type'] = 'application/json'
body = JSON.stringify(args)
}
}

try {
const response = await this.fetch(`${this.url}/${functionId}`, {
method: 'POST',
body: JSON.stringify(args),
headers: this.headers
headers: { ..._headers, ...this.headers, ...headers }
})

if (!response.ok) {
throw new SQLiteCloudError(`Failed to invoke function: ${response.statusText}`)
}

let responseType = (response.headers.get('Content-Type') ?? 'text/plain').split(';')[0].trim()
let data: any
if (responseType === 'application/json') {
Expand All @@ -47,16 +91,11 @@ export class FunctionsClient {
} else if (responseType === 'multipart/form-data') {
data = await response.formData()
} else {
// default to text
data = await response.text()
}
return { data, error: null }
return { ...data, error: null }
} catch (error) {
return { data: null, error }
}
}
}

/**
*/
48 changes: 39 additions & 9 deletions src/packages/pubsub/PubSubClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,36 @@ import { SQLiteCloudConnection } from '../../drivers/connection'
import SQLiteCloudTlsConnection from '../../drivers/connection-tls'
import { SQLiteCloudConfig } from '../../drivers/types'

/**
* PubSubCallback
* @param error - The error that occurred.
* @param results - The results of the operation.
*/
export type PubSubCallback<T = any> = (error: Error | null, results?: T) => void

/**
* ListenOptions
* @param tableName - The name of the table to listen to.
* @param dbName - The name of the database to listen to.
*/
export interface ListenOptions {
tableName: string
dbName?: string
}

/**
* PubSub
* @param listen - Listen to a channel and start to receive messages to the provided callback.
* @param unlisten - Stop receive messages from a table or channel.
* @param subscribe - Subscribe to a channel.
* @param unsubscribe - Unsubscribe from a channel.
* @param create - Create a channel.
* @param delete - Delete a channel.
* @param notify - Send a message to a channel.
* @param setPubSubOnly - Set the connection to Pub/Sub only.
* @param connected - Check if the connection is open.
* @param close - Close the connection.
*/
export interface PubSub {
listen<T>(options: ListenOptions, callback: PubSubCallback): Promise<T>
unlisten(options: ListenOptions): void
Expand All @@ -29,7 +52,7 @@ export class PubSubClient implements PubSub {
protected _pubSubConnection: SQLiteCloudConnection | null
protected defaultDatabaseName: string
protected config: SQLiteCloudConfig
// instantiate in createConnection?

constructor(config: SQLiteCloudConfig) {
this.config = config
this._pubSubConnection = null
Expand Down Expand Up @@ -71,14 +94,18 @@ export class PubSubClient implements PubSub {
}

/**
* Stop receive messages from a table or channel.
* @param entityType One of TABLE or CHANNEL
* @param entityName Name of the table or the channel
* Unlisten to a table.
* @param options Options for the unlisten operation.
*/
public unlisten(options: ListenOptions): void {
this.pubSubConnection.sql`UNLISTEN ${options.tableName} DATABASE ${options.dbName};`
}

/**
* Subscribe (listen) to a channel.
* @param channelName The name of the channel to subscribe to.
* @param callback Callback to be called when a message is received.
*/
public async subscribe(channelName: string, callback: PubSubCallback): Promise<any> {
const authCommand: string = await this.pubSubConnection.sql`LISTEN ${channelName};`

Expand All @@ -94,6 +121,10 @@ export class PubSubClient implements PubSub {
})
}

/**
* Unsubscribe (unlisten) from a channel.
* @param channelName The name of the channel to unsubscribe from.
*/
public unsubscribe(channelName: string): void {
this.pubSubConnection.sql`UNLISTEN ${channelName};`
}
Expand All @@ -104,24 +135,23 @@ export class PubSubClient implements PubSub {
* @param failIfExists Raise an error if the channel already exists
*/
public async create(channelName: string, failIfExists: boolean = true): Promise<any> {
// type this output
return await this.pubSubConnection.sql(`CREATE CHANNEL ?${failIfExists ? '' : ' IF NOT EXISTS'};`, channelName)
return await this.pubSubConnection.sql(
`CREATE CHANNEL ?${failIfExists ? '' : ' IF NOT EXISTS'};`, channelName
)
}

/**
* Deletes a Pub/Sub channel.
* @param name Channel name
*/
public async delete(channelName: string): Promise<any> {
// type this output
return await this.pubSubConnection.sql(`REMOVE CHANNEL ?;`, channelName)
return await this.pubSubConnection.sql`REMOVE CHANNEL ${channelName};`
}

/**
* Send a message to the channel.
*/
public notify(channelName: string, message: string): Promise<any> {
// type this output
return this.pubSubConnection.sql`NOTIFY ${channelName} ${message};`
}

Expand Down
190 changes: 95 additions & 95 deletions src/packages/vector/SQLiteCloudVectorClient.ts
Original file line number Diff line number Diff line change
@@ -1,95 +1,95 @@
import { Database } from "../../drivers/database";

interface Column {
name: string;
type: string;
partitionKey?: boolean;
primaryKey?: boolean;
}

interface IndexOptions {
tableName: string;
dimensions: number;
columns: Column[];
binaryQuantization?: boolean;
dbName?: string;
}

type UpsertData = [Record<string, any> & { id: string | number }][]

interface QueryOptions {
topK: number,
where?: string[]
}

interface Vector {
init(options: IndexOptions): Promise<VectorClient>
upsert(data: UpsertData): Promise<VectorClient>
query(queryEmbedding: number[], options: QueryOptions): Promise<any>
}

const DEFAULT_EMBEDDING_COLUMN_NAME = 'embedding'

const buildEmbeddingType = (dimensions: number, binaryQuantization: boolean) => {
return `${binaryQuantization ? 'BIT' : 'FLOAT'}[${dimensions}]`
}

const formatInitColumns = (opts: IndexOptions) => {
const { columns, dimensions, binaryQuantization } = opts
return columns.reduce((acc, column) => {
let _type = column.type.toLowerCase();
const { name, primaryKey, partitionKey } = column
if (_type === 'embedding') {
_type = buildEmbeddingType(dimensions, !!binaryQuantization)
}
const formattedColumn = `${name} ${_type} ${primaryKey ? 'PRIMARY KEY' : ''}${partitionKey ? 'PARTITION KEY' : ''}`
return `${acc}, ${formattedColumn}`
}, '')
}

function formatUpsertCommand(data: UpsertData): [any, any] {
throw new Error("Function not implemented.");
}


export class VectorClient implements Vector {
private _db: Database
private _tableName: string
private _columns: Column[]
private _formattedColumns: string

constructor(_db: Database) {
this._db = _db
this._tableName = ''
this._columns = []
this._formattedColumns = ''
}

async init(options: IndexOptions) {
const formattedColumns = formatInitColumns(options)
this._tableName = options.tableName
this._columns = options?.columns || []
this._formattedColumns = formattedColumns
const useDbCommand = options?.dbName ? `USE DATABASE ${options.dbName}; ` : ''
const hasTable = await this._db.sql`${useDbCommand}SELECT 1 FROM ${options.tableName} LIMIT 1;`

if (hasTable.length === 0) { // TODO - VERIFY CHECK HAS TABLE
const query = `CREATE VIRTUAL TABLE ${options.tableName} USING vec0(${formattedColumns})`
await this._db.sql(query)
}
return this
}

async upsert(data: UpsertData) {
const [formattedColumns, formattedValues] = formatUpsertCommand(data)
const query = `INSERT INTO ${this._tableName}(${formattedColumns}) VALUES (${formattedValues})`
return await this._db.sql(query)
}

async query(queryEmbedding: number[], options: QueryOptions) {
const query = `SELECT * FROM ${this._tableName} WHERE ${DEFAULT_EMBEDDING_COLUMN_NAME} match ${JSON.stringify(queryEmbedding)} and k = ${options.topK} and ${(options?.where?.join(' and ') || '')}`
const result = await this._db.sql(query)
return { data: result, error: null }
}

}
// import { Database } from "../../drivers/database";

// interface Column {
// name: string;
// type: string;
// partitionKey?: boolean;
// primaryKey?: boolean;
// }

// interface IndexOptions {
// tableName: string;
// dimensions: number;
// columns: Column[];
// binaryQuantization?: boolean;
// dbName?: string;
// }

// type UpsertData = [Record<string, any> & { id: string | number }][]

// interface QueryOptions {
// topK: number,
// where?: string[]
// }

// interface Vector {
// init(options: IndexOptions): Promise<VectorClient>
// upsert(data: UpsertData): Promise<VectorClient>
// query(queryEmbedding: number[], options: QueryOptions): Promise<any>
// }

// const DEFAULT_EMBEDDING_COLUMN_NAME = 'embedding'

// const buildEmbeddingType = (dimensions: number, binaryQuantization: boolean) => {
// return `${binaryQuantization ? 'BIT' : 'FLOAT'}[${dimensions}]`
// }

// const formatInitColumns = (opts: IndexOptions) => {
// const { columns, dimensions, binaryQuantization } = opts
// return columns.reduce((acc, column) => {
// let _type = column.type.toLowerCase();
// const { name, primaryKey, partitionKey } = column
// if (_type === 'embedding') {
// _type = buildEmbeddingType(dimensions, !!binaryQuantization)
// }
// const formattedColumn = `${name} ${_type} ${primaryKey ? 'PRIMARY KEY' : ''}${partitionKey ? 'PARTITION KEY' : ''}`
// return `${acc}, ${formattedColumn}`
// }, '')
// }

// function formatUpsertCommand(data: UpsertData): [any, any] {
// throw new Error("Function not implemented.");
// }


// export class VectorClient implements Vector {
// private _db: Database
// private _tableName: string
// private _columns: Column[]
// private _formattedColumns: string

// constructor(_db: Database) {
// this._db = _db
// this._tableName = ''
// this._columns = []
// this._formattedColumns = ''
// }

// async init(options: IndexOptions) {
// const formattedColumns = formatInitColumns(options)
// this._tableName = options.tableName
// this._columns = options?.columns || []
// this._formattedColumns = formattedColumns
// const useDbCommand = options?.dbName ? `USE DATABASE ${options.dbName}; ` : ''
// const hasTable = await this._db.sql`${useDbCommand}SELECT 1 FROM ${options.tableName} LIMIT 1;`

// if (hasTable.length === 0) { // TODO - VERIFY CHECK HAS TABLE
// const query = `CREATE VIRTUAL TABLE ${options.tableName} USING vec0(${formattedColumns})`
// await this._db.sql(query)
// }
// return this
// }

// async upsert(data: UpsertData) {
// const [formattedColumns, formattedValues] = formatUpsertCommand(data)
// const query = `INSERT INTO ${this._tableName}(${formattedColumns}) VALUES (${formattedValues})`
// return await this._db.sql(query)
// }

// async query(queryEmbedding: number[], options: QueryOptions) {
// const query = `SELECT * FROM ${this._tableName} WHERE ${DEFAULT_EMBEDDING_COLUMN_NAME} match ${JSON.stringify(queryEmbedding)} and k = ${options.topK} and ${(options?.where?.join(' and ') || '')}`
// const result = await this._db.sql(query)
// return { data: result, error: null }
// }

// }

0 comments on commit 4f3ba4f

Please sign in to comment.