diff --git a/src/bot/commands/chat.ts b/src/bot/commands/chat.ts index 3654188..5ca80e1 100644 --- a/src/bot/commands/chat.ts +++ b/src/bot/commands/chat.ts @@ -114,17 +114,26 @@ async function buildInfo( const history = conversation?.history ?? { messages: [], }; + const data: { + messages: { role: string; content: string }[]; + max_tokens?: number; + temperature?: number; + model?: string; + } = { + max_tokens: premium ? 500 : 300, + messages: [ + ...history.messages, + { + role: "user", + content: prompt, + }, + ], + }; + if (modelName === "gemini") { + data.model = "gemini-pro"; + } try { - const event = await model.run(bot.api, { - max_tokens: premium ? 500 : 300, - messages: [ - ...history.messages, - { - role: "user", - content: prompt, - }, - ], - }); + const event = await model.run(bot.api, data); if (conversation) { await addMessageToConversation(conversation, { role: "user", @@ -158,8 +167,9 @@ async function buildInfo( // if last update was more than 1 second ago lastUpdate = Date.now(); await edit({ - content: `${data.result}<${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${loadingIndicator.emoji.id - }>`, + content: `${data.result}<${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${ + loadingIndicator.emoji.id + }>`, }); } } else { @@ -172,7 +182,7 @@ async function buildInfo( } // if last update was less than 1 second ago, wait 1 second if (lastUpdate + 1000 > Date.now()) await delay(1000); - await chargePlan(data.cost, env, "image", modelName); + await chargePlan(data.cost, env, "chat", modelName); await edit({ content: `${data.result}`, diff --git a/src/bot/commands/imagine.ts b/src/bot/commands/imagine.ts index 344ae41..c68d1ed 100644 --- a/src/bot/commands/imagine.ts +++ b/src/bot/commands/imagine.ts @@ -175,8 +175,9 @@ export default createCommand({ embeds: [ { color: config.brand.color, - title: `Waiting in queue <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${loadingIndicator.emoji.id - }>`, + title: `Waiting in queue <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${ + loadingIndicator.emoji.id + }>`, }, ], }); @@ -186,8 +187,9 @@ export default createCommand({ embeds: [ { color: config.brand.color, - title: `Generating <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${loadingIndicator.emoji.id - }>`, + title: `Generating <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${ + loadingIndicator.emoji.id + }>`, }, ], }); diff --git a/src/bot/commands/reset.ts b/src/bot/commands/reset.ts index 7699d08..5ede8b2 100644 --- a/src/bot/commands/reset.ts +++ b/src/bot/commands/reset.ts @@ -3,6 +3,7 @@ import config from "../../config.js"; import { NoCooldown, buttonInfo, createCommand } from "../config/setup.js"; import { gatewayConfig } from "../index.js"; import { resetConversation } from "../utils/conversations.js"; +import { getDefaultValues, getSettingsValue } from "../utils/settings.js"; export default createCommand({ body: { @@ -11,7 +12,7 @@ export default createCommand({ }, cooldown: NoCooldown, isPrivate: true, - interaction: async ({ interaction }) => { + interaction: async ({ interaction, env }) => { await interaction.edit({ embeds: [ { @@ -21,7 +22,12 @@ export default createCommand({ }, ], }); - await resetConversation(interaction.user.id.toString(), "openchat"); + const user = env.user; + let setting = (await getSettingsValue(user, "chat:model")) as string; + if (!setting) { + setting = (await getDefaultValues("chat:model")) as string; + } + await resetConversation(interaction.user.id.toString(), setting); await interaction.edit({ embeds: [ { diff --git a/src/bot/models/index.ts b/src/bot/models/index.ts index f07d94a..9780483 100644 --- a/src/bot/models/index.ts +++ b/src/bot/models/index.ts @@ -7,6 +7,7 @@ import { sdxl, OpenJourneyDiffussion, Deliberate, majicMIXR } from "./stablehord import kandinsky from "./kandinsky.js"; import { Zephyr } from "./text/pawan.js"; import fastSdxl from "./fast-sdxl.js"; +import google from "./text/google.js"; type Prettify = { [K in keyof T]: T[K]; @@ -28,6 +29,7 @@ export type ChatModel = Prettify< emoji: { name: string; id: string }; maxTokens: 2048 | 4096 | 8192; premium?: boolean; + model?: string; } >; @@ -90,6 +92,7 @@ export const CHAT_MODELS: (GPTModel | AnthropicModel | OpenChatModel)[] = [ Claude_instant, openchat, Zephyr, + google, ]; export type GenericParam = Parameters[0]; diff --git a/src/bot/models/text/google.ts b/src/bot/models/text/google.ts new file mode 100644 index 0000000..e409081 --- /dev/null +++ b/src/bot/models/text/google.ts @@ -0,0 +1,14 @@ +import { OpenChatModel } from "../index.js"; + +export default { + id: "gemini", + name: "Gemini", + description: "Last Google Large Language Model", + emoji: { name: "gemini", id: "1185280770877698048" }, + maxTokens: 4096, + run: async (api, data) => { + return await api.text.google({ + ...data, + }); + }, +} as OpenChatModel; diff --git a/src/bot/utils/conversations.ts b/src/bot/utils/conversations.ts index 050bd35..79fc264 100644 --- a/src/bot/utils/conversations.ts +++ b/src/bot/utils/conversations.ts @@ -1,17 +1,14 @@ +import { delay } from "@discordeno/utils"; import { Conversation, ConversationHistory, ConversationMessage } from "../../types/models/conversations.js"; import { get, insert, update } from "./db.js"; export async function getConversation(userId: string, modelName: string) { const conversation = (await get({ collection: "conversations", - filter: { - id: `${userId}-${modelName}`, - user: userId, - model: modelName, - }, - })) as Conversation[]; - if (conversation.length === 0) return null; - return conversation[0]; + id: `${userId}-${modelName}`, + })) as Conversation; + if (!conversation) return null; + return conversation; } export async function addMessageToConversation(conversation: Conversation, message: ConversationMessage) { @@ -32,21 +29,24 @@ export async function addMessageToConversation(conversation: Conversation, messa } export async function newConversation(message: ConversationMessage, userId: string, modelName: string) { + const newConversation = { + history: { + datasetId: "", + messages: [message], + }, + last_update: Date.now(), + model: modelName, + user: userId, + id: `${userId}-${modelName}`, + } as Conversation; await insert( "conversations", { - history: { - datasetId: "", - messages: [message], - }, - last_update: Date.now(), - model: modelName, - user: userId, + ...newConversation, }, `${userId}-${modelName}`, ); - const conversation = await getConversation(userId, modelName); - return conversation; + return newConversation; } export async function resetConversation(userId: string, modelName: string) { diff --git a/src/bot/utils/premium.ts b/src/bot/utils/premium.ts index c21c288..d76613b 100644 --- a/src/bot/utils/premium.ts +++ b/src/bot/utils/premium.ts @@ -69,8 +69,9 @@ export async function generatePremiumEmbed(premiumInfo: { for (const expense of user.plan.expenses) { if (expensesFields.length >= 10) break; expensesFields.push({ - name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${expense.data.model - }\` - $${expense.used.toFixed(5)}`, + name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${ + expense.data.model + }\` - $${expense.used.toFixed(5)}`, value: ``, }); } @@ -78,8 +79,9 @@ export async function generatePremiumEmbed(premiumInfo: { for (const expense of guild.plan.expenses) { if (expensesFields.length >= 10) break; expensesFields.push({ - name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${expense.data.model - }\` - $${expense.used.toFixed(5)}`, + name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${ + expense.data.model + }\` - $${expense.used.toFixed(5)}`, value: ``, }); } @@ -96,8 +98,9 @@ export async function generatePremiumEmbed(premiumInfo: { for (const charge of user.plan.history) { if (chargesFields.length >= 10) break; chargesFields.push({ - name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${charge.gateway ? `- using \`${charge.gateway}\`` : "" - }`, + name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${ + charge.gateway ? `- using \`${charge.gateway}\`` : "" + }`, value: `$${charge.amount.toFixed(2)} - `, }); } @@ -105,8 +108,9 @@ export async function generatePremiumEmbed(premiumInfo: { for (const charge of guild.plan.history) { if (chargesFields.length >= 10) break; chargesFields.push({ - name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${charge.gateway ? `- using \`${charge.gateway}\`` : "" - }`, + name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${ + charge.gateway ? `- using \`${charge.gateway}\`` : "" + }`, value: `$${charge.amount.toFixed(2)} - `, }); } @@ -122,11 +126,13 @@ export async function generatePremiumEmbed(premiumInfo: { // LAST EMBED let description = ""; if (premiumInfo.premiumSelection.location === "user" && user.plan) { - description = `**$${user.plan?.used.toFixed(2)}**\`${generateProgressBar(user.plan.total, user.plan.used)}\`**$${user.plan?.total - }**`; + description = `**$${user.plan?.used.toFixed(2)}**\`${generateProgressBar(user.plan.total, user.plan.used)}\`**$${ + user.plan?.total + }**`; } else if (premiumInfo.premiumSelection.location === "guild" && guild?.plan) { - description = `**$${guild.plan?.used.toFixed(2)}**\`${generateProgressBar(guild.plan.total, guild.plan.used)}\`**$${guild.plan?.total - }**`; + description = `**$${guild.plan?.used.toFixed(2)}**\`${generateProgressBar(guild.plan.total, guild.plan.used)}\`**$${ + guild.plan?.total + }**`; } embeds.push({ title: "Your pay-as-you-go plan 📊", diff --git a/src/bot/utils/settings.ts b/src/bot/utils/settings.ts index 8cc7e60..8247dca 100644 --- a/src/bot/utils/settings.ts +++ b/src/bot/utils/settings.ts @@ -151,7 +151,7 @@ export function getDefaultValues(settingId: string) { case "general:loadingIndicator": return 3; // default loading indicator case "chat:model": - return "claude-instant"; + return "gemini"; case "chat:tone": return "neutral"; case "chat:partialMessages":