Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
feat: gemini pro
Browse files Browse the repository at this point in the history
  • Loading branch information
MrlolDev committed Dec 15, 2023
1 parent c519cde commit 8175453
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 49 deletions.
36 changes: 23 additions & 13 deletions src/bot/commands/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,26 @@ async function buildInfo(
const history = conversation?.history ?? {
messages: [],
};
const data: {
messages: { role: string; content: string }[];
max_tokens?: number;
temperature?: number;
model?: string;
} = {
max_tokens: premium ? 500 : 300,
messages: [
...history.messages,
{
role: "user",
content: prompt,
},
],
};
if (modelName === "gemini") {
data.model = "gemini-pro";
}
try {
const event = await model.run(bot.api, {
max_tokens: premium ? 500 : 300,
messages: [
...history.messages,
{
role: "user",
content: prompt,
},
],
});
const event = await model.run(bot.api, data);
if (conversation) {
await addMessageToConversation(conversation, {
role: "user",
Expand Down Expand Up @@ -158,8 +167,9 @@ async function buildInfo(
// if last update was more than 1 second ago
lastUpdate = Date.now();
await edit({
content: `${data.result}<${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${loadingIndicator.emoji.id
}>`,
content: `${data.result}<${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${
loadingIndicator.emoji.id
}>`,
});
}
} else {
Expand All @@ -172,7 +182,7 @@ async function buildInfo(
}
// if last update was less than 1 second ago, wait 1 second
if (lastUpdate + 1000 > Date.now()) await delay(1000);
await chargePlan(data.cost, env, "image", modelName);
await chargePlan(data.cost, env, "chat", modelName);

await edit({
content: `${data.result}`,
Expand Down
10 changes: 6 additions & 4 deletions src/bot/commands/imagine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ export default createCommand({
embeds: [
{
color: config.brand.color,
title: `Waiting in queue <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${loadingIndicator.emoji.id
}>`,
title: `Waiting in queue <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${
loadingIndicator.emoji.id
}>`,
},
],
});
Expand All @@ -186,8 +187,9 @@ export default createCommand({
embeds: [
{
color: config.brand.color,
title: `Generating <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${loadingIndicator.emoji.id
}>`,
title: `Generating <${loadingIndicator.emoji.animated ? "a" : ""}:${loadingIndicator.emoji.name}:${
loadingIndicator.emoji.id
}>`,
},
],
});
Expand Down
10 changes: 8 additions & 2 deletions src/bot/commands/reset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import config from "../../config.js";
import { NoCooldown, buttonInfo, createCommand } from "../config/setup.js";
import { gatewayConfig } from "../index.js";
import { resetConversation } from "../utils/conversations.js";
import { getDefaultValues, getSettingsValue } from "../utils/settings.js";

export default createCommand({
body: {
Expand All @@ -11,7 +12,7 @@ export default createCommand({
},
cooldown: NoCooldown,
isPrivate: true,
interaction: async ({ interaction }) => {
interaction: async ({ interaction, env }) => {
await interaction.edit({
embeds: [
{
Expand All @@ -21,7 +22,12 @@ export default createCommand({
},
],
});
await resetConversation(interaction.user.id.toString(), "openchat");
const user = env.user;
let setting = (await getSettingsValue(user, "chat:model")) as string;
if (!setting) {
setting = (await getDefaultValues("chat:model")) as string;
}
await resetConversation(interaction.user.id.toString(), setting);
await interaction.edit({
embeds: [
{
Expand Down
3 changes: 3 additions & 0 deletions src/bot/models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { sdxl, OpenJourneyDiffussion, Deliberate, majicMIXR } from "./stablehord
import kandinsky from "./kandinsky.js";
import { Zephyr } from "./text/pawan.js";
import fastSdxl from "./fast-sdxl.js";
import google from "./text/google.js";

type Prettify<T> = {
[K in keyof T]: T[K];
Expand All @@ -28,6 +29,7 @@ export type ChatModel = Prettify<
emoji: { name: string; id: string };
maxTokens: 2048 | 4096 | 8192;
premium?: boolean;
model?: string;
}
>;

Expand Down Expand Up @@ -90,6 +92,7 @@ export const CHAT_MODELS: (GPTModel | AnthropicModel | OpenChatModel)[] = [
Claude_instant,
openchat,
Zephyr,
google,
];
export type GenericParam = Parameters<Api["image"]["sh"]>[0];

Expand Down
14 changes: 14 additions & 0 deletions src/bot/models/text/google.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { OpenChatModel } from "../index.js";

export default {
id: "gemini",
name: "Gemini",
description: "Last Google Large Language Model",
emoji: { name: "gemini", id: "1185280770877698048" },
maxTokens: 4096,
run: async (api, data) => {
return await api.text.google({
...data,
});
},
} as OpenChatModel;
34 changes: 17 additions & 17 deletions src/bot/utils/conversations.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import { delay } from "@discordeno/utils";
import { Conversation, ConversationHistory, ConversationMessage } from "../../types/models/conversations.js";
import { get, insert, update } from "./db.js";

export async function getConversation(userId: string, modelName: string) {
const conversation = (await get({
collection: "conversations",
filter: {
id: `${userId}-${modelName}`,
user: userId,
model: modelName,
},
})) as Conversation[];
if (conversation.length === 0) return null;
return conversation[0];
id: `${userId}-${modelName}`,
})) as Conversation;
if (!conversation) return null;
return conversation;
}

export async function addMessageToConversation(conversation: Conversation, message: ConversationMessage) {
Expand All @@ -32,21 +29,24 @@ export async function addMessageToConversation(conversation: Conversation, messa
}

export async function newConversation(message: ConversationMessage, userId: string, modelName: string) {
const newConversation = {
history: {
datasetId: "",
messages: [message],
},
last_update: Date.now(),
model: modelName,
user: userId,
id: `${userId}-${modelName}`,
} as Conversation;
await insert(
"conversations",
{
history: {
datasetId: "",
messages: [message],
},
last_update: Date.now(),
model: modelName,
user: userId,
...newConversation,
},
`${userId}-${modelName}`,
);
const conversation = await getConversation(userId, modelName);
return conversation;
return newConversation;
}

export async function resetConversation(userId: string, modelName: string) {
Expand Down
30 changes: 18 additions & 12 deletions src/bot/utils/premium.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,19 @@ export async function generatePremiumEmbed(premiumInfo: {
for (const expense of user.plan.expenses) {
if (expensesFields.length >= 10) break;
expensesFields.push({
name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${expense.data.model
}\` - $${expense.used.toFixed(5)}`,
name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${
expense.data.model
}\` - $${expense.used.toFixed(5)}`,
value: `<t:${Math.floor(expense.time / 1000)}:R>`,
});
}
} else if (premiumInfo.premiumSelection.location === "guild" && guild?.plan) {
for (const expense of guild.plan.expenses) {
if (expensesFields.length >= 10) break;
expensesFields.push({
name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${expense.data.model
}\` - $${expense.used.toFixed(5)}`,
name: `${expense.type.slice(0, 1).toUpperCase()}${expense.type.slice(1)} - using \`${
expense.data.model
}\` - $${expense.used.toFixed(5)}`,
value: `<t:${Math.floor(expense.time / 1000)}>`,
});
}
Expand All @@ -96,17 +98,19 @@ export async function generatePremiumEmbed(premiumInfo: {
for (const charge of user.plan.history) {
if (chargesFields.length >= 10) break;
chargesFields.push({
name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${charge.gateway ? `- using \`${charge.gateway}\`` : ""
}`,
name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${
charge.gateway ? `- using \`${charge.gateway}\`` : ""
}`,
value: `$${charge.amount.toFixed(2)} - <t:${Math.floor(charge.time / 1000)}>`,
});
}
} else if (premiumInfo.premiumSelection.location === "guild" && guild?.plan) {
for (const charge of guild.plan.history) {
if (chargesFields.length >= 10) break;
chargesFields.push({
name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${charge.gateway ? `- using \`${charge.gateway}\`` : ""
}`,
name: `${charge.type.slice(0, 1).toUpperCase()}${charge.type.slice(1)} ${
charge.gateway ? `- using \`${charge.gateway}\`` : ""
}`,
value: `$${charge.amount.toFixed(2)} - <t:${Math.floor(charge.time / 1000)}>`,
});
}
Expand All @@ -122,11 +126,13 @@ export async function generatePremiumEmbed(premiumInfo: {
// LAST EMBED
let description = "";
if (premiumInfo.premiumSelection.location === "user" && user.plan) {
description = `**$${user.plan?.used.toFixed(2)}**\`${generateProgressBar(user.plan.total, user.plan.used)}\`**$${user.plan?.total
}**`;
description = `**$${user.plan?.used.toFixed(2)}**\`${generateProgressBar(user.plan.total, user.plan.used)}\`**$${
user.plan?.total
}**`;
} else if (premiumInfo.premiumSelection.location === "guild" && guild?.plan) {
description = `**$${guild.plan?.used.toFixed(2)}**\`${generateProgressBar(guild.plan.total, guild.plan.used)}\`**$${guild.plan?.total
}**`;
description = `**$${guild.plan?.used.toFixed(2)}**\`${generateProgressBar(guild.plan.total, guild.plan.used)}\`**$${
guild.plan?.total
}**`;
}
embeds.push({
title: "Your pay-as-you-go plan 📊",
Expand Down
2 changes: 1 addition & 1 deletion src/bot/utils/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ export function getDefaultValues(settingId: string) {
case "general:loadingIndicator":
return 3; // default loading indicator
case "chat:model":
return "claude-instant";
return "gemini";
case "chat:tone":
return "neutral";
case "chat:partialMessages":
Expand Down

0 comments on commit 8175453

Please sign in to comment.