-
Notifications
You must be signed in to change notification settings - Fork 14
/
main.py
451 lines (372 loc) · 17.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
import asyncio
import base64
import collections
import contextlib
import copy
import datetime
import enum
import io
import json
import logging
import os
import re
import time
from os.path import dirname, join
from typing import List, Optional, Union
import asyncpg
import discord
import humanize
import numpy as np
from aiogithub import GitHub
from discord.ext import commands
from dotenv import load_dotenv
from utils.buttons import PersistentRespondView
from utils.context_managers import UserLock
from utils.decorators import event_check, in_executor, wait_ready
from utils.ipc import StellaClient, StellaAPI, StellaFile
from utils.prefix_ai import DerivativeNeuralNetwork, PrefixNeuralNetwork
from utils.useful import ListCall, StellaContext, count_source_lines, print_exception
# Load bot_settings.env, located in the same directory as this file, into the process environment.
dotenv_path = join(dirname(__file__), 'bot_settings.env')
load_dotenv(dotenv_path)
# NOTE(review): imported only after load_dotenv and never referenced by name in this file,
# so it is presumably imported for its side effects (possibly reading env vars) — confirm.
import utils.library_override
# Registry for coroutines decorated with ``@to_call.append``; StellaBot.after_db runs them
# through ``to_call.call(self)`` once the database pool is available.
to_call = ListCall()
logging.basicConfig(level=logging.INFO)
class StellaBot(commands.Bot):
    """The bot client: ``commands.Bot`` plus bot-specific state and overrides.

    Bot-specific options (database credentials, IPC settings, prefixes, owner ids,
    prefix-network settings, ...) are popped from ``kwargs`` in ``__init__``; whatever
    remains is forwarded to ``commands.Bot``.
    """

    def __init__(self, **kwargs):
        self.tester = kwargs.pop("tester", False)
        self.help_src = kwargs.pop("help_src", None)
        self.db = kwargs.pop("db", None)
        self.user_db = kwargs.pop("user_db", None)
        self.pass_db = kwargs.pop("pass_db", None)
        self.color = kwargs.pop("color", None)
        self.websocket_IP = kwargs.pop("websocket_ip")
        self.stella_api = StellaAPI(self)
        self.ipc_key = kwargs.pop("ipc_key")
        self.ipc_port = kwargs.pop("ipc_port")
        self.ipc_client = StellaClient(host=self.websocket_IP, secret_key=self.ipc_key, port=self.ipc_port)
        self.git_token = kwargs.pop("git_token")
        self.error_channel_id = kwargs.pop("error_channel")
        self.bot_guild_id = kwargs.pop("bot_guild")
        # Filled in later: ``git`` in setup_hook, ``pool_pg`` and ``uptime`` in main.
        self.git = None
        self.pool_pg = None
        self.uptime = None
        self.global_variable = None
        self.all_bot_prefixes = {}
        self.pending_bots = set()
        self.confirmed_bots = set()
        self.token = kwargs.pop("token", None)
        self.blacklist = set()
        # Prefix cache used by get_prefix, keyed by guild id (author id when there is no guild).
        self.existing_prefix = {}
        self.cached_context = collections.deque(maxlen=100)
        # message id -> task running that message's command (see invoke / running_command).
        self.command_running = {}
        # user id -> lock checked by check_user_lock before a command runs.
        self.user_lock = {}
        self.button_click_cached = {}
        self._default_prefix = kwargs.pop("default_prefix")
        self._tester_prefix = kwargs.pop("tester_prefix")
        self.cooldown_user_click = commands.CooldownMapping.from_cooldown(8, 10, commands.BucketType.user)
        # main bot owner is kept separate
        owner_ids = kwargs.pop("owner_ids")
        self._stella_id, *_ = owner_ids
        # Pop the prefix-network options *before* forwarding kwargs, so these bot-specific
        # settings are not passed on to commands.Bot / discord.Client as options.
        prefix_weights = kwargs.pop("prefix_weights")
        prefix_derivative = kwargs.pop("prefix_derivative")
        super().__init__(
            self.get_prefix,
            owner_ids=set(owner_ids),
            strip_after_prefix=True,
            **kwargs,
        )
        self.prefix_neural_network = PrefixNeuralNetwork.from_weight(*prefix_weights.values())
        self.derivative_prefix_neural = DerivativeNeuralNetwork(prefix_derivative)

    @in_executor()
    def get_prefixes_dataset(self, data: List[List[Union[int, str]]]) -> np.ndarray:
        """Score prefix rows with the prefix neural network.

        Column 1 of each row is read as a usage amount and column 2 as an epoch time
        (column 0 is presumably the prefix itself — confirm against the caller). Both
        columns are scaled into 0-1 by dividing by their column maximum, fed to
        ``self.prefix_neural_network.fit`` and the output (times 200) is appended to the
        rows as an extra column.

        Decorated with ``in_executor``, presumably so the NumPy work runs off the event loop.
        """
        inputs = np.array(data)
        amounts = inputs[:, 1].astype(np.int32)
        # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype it aliased.
        epoch_times = inputs[:, 2].astype(np.float64)

        def normalize(column: np.ndarray) -> np.ndarray:
            # Divide by the column max; an all-zero column stays zero instead of becoming NaN.
            peak = column.max()
            if not peak:
                return np.zeros_like(column, dtype=np.float64)
            return column / peak

        normalized = np.dstack((normalize(amounts), normalize(epoch_times)))
        result = self.prefix_neural_network.fit(normalized) * 200
        return np.column_stack((inputs, result.flat[::]))

    async def add_blacklist(self, snowflake_id, reason):
        """Blacklist ``snowflake_id``: store it in the database, cache it, and notify the IPC server."""
        timed = datetime.datetime.utcnow()
        values = (snowflake_id, reason, timed)
        await self.pool_pg.execute("INSERT INTO blacklist VALUES($1, $2, $3)", *values)
        self.blacklist.add(snowflake_id)
        payload = {
            "snowflake_id": snowflake_id,
            "reason": reason,
            # ``timed`` is a naive UTC datetime and ``timestamp()`` would treat it as local
            # time; attach UTC explicitly so the epoch seconds are correct on any host.
            "time": timed.replace(tzinfo=datetime.timezone.utc).timestamp(),
        }
        await self.ipc_client.request("global_blacklist_id", **payload)

    async def remove_blacklist(self, snowflake_id):
        """Remove ``snowflake_id`` from the database blacklist, the cache, and notify the IPC server."""
        await self.pool_pg.execute("DELETE FROM blacklist WHERE snowflake_id=$1", snowflake_id)
        # discard: an id missing from the cache must not abort after the DB row is already gone.
        self.blacklist.discard(snowflake_id)
        await self.ipc_client.request("global_unblacklist_id", snowflake_id=snowflake_id)

    def get_command_signature(self, ctx: StellaContext, command_name: Union[commands.Command, str]) -> str:
        """Return the help-command signature for a command object or a command name.

        Raises:
            Exception: ``command_name`` is a string that does not name a command.
        """
        if isinstance(command_name, str):
            if not (command := self.get_command(command_name)):
                raise Exception("Command does not exist for signature.")
        else:
            command = command_name
        return self.help_command.get_command_signature(command, ctx)

    async def after_db(self) -> None:
        """Runs after the db is connected: calls every coroutine registered on ``to_call``."""
        await to_call.call(self)

    def add_command(self, command: commands.Command) -> None:
        """Register ``command``, applying cooldowns after argument parsing.

        Commands without their own cooldown get a default of 1 use per 5 seconds per user.
        """
        super().add_command(command)
        command.cooldown_after_parsing = True
        if not getattr(command._buckets, "_cooldown", None):
            command._buckets = commands.CooldownMapping.from_cooldown(1, 5, commands.BucketType.user)

    def add_user_lock(self, lock: UserLock):
        """Register ``lock`` under its user's id, replacing any previous entry."""
        self.user_lock[lock.user.id] = lock

    async def check_user_lock(self, user: Union[discord.Member, discord.User]):
        """Refuse to run a command while ``user`` holds a locked entry in ``user_lock``.

        A stored entry that is no longer locked is dropped.

        Raises:
            The lock's own ``error`` for ``UserLock`` instances, otherwise ``commands.CommandError``.
        """
        if lock := self.user_lock.get(user.id):
            if lock.locked():
                if isinstance(lock, UserLock):
                    raise lock.error
                raise commands.CommandError("You can't invoke another command while another command is running.")
            else:
                self.user_lock.pop(user.id, None)

    async def running_command(self, ctx: StellaContext, **flags):
        """Run ``ctx.command`` with user-lock and global checks.

        Flags:
            dispatch (True): dispatch ``command``/``command_completion`` events and the
                command's error handler.
            call_once (True): forwarded to ``can_run``.
            call_check (True): when False the command runs even if the global checks fail.
            redirect_error (False): re-raise a ``CommandError`` after it was handled.

        The entry for this message in ``command_running`` is always removed at the end.
        """
        dispatch = flags.pop("dispatch", True)
        if dispatch:
            self.dispatch('command', ctx)
        try:
            await self.check_user_lock(ctx.author)
            check = await self.can_run(ctx, call_once=flags.pop("call_once", True))
            if check or not flags.pop("call_check", True):
                await ctx.typing()
                await ctx.command.invoke(ctx)
            else:
                raise commands.CheckFailure('The global check once functions failed.')
        except commands.CommandError as exc:
            if dispatch:
                await ctx.command.dispatch_error(ctx, exc)
            if flags.pop("redirect_error", False):
                raise
        else:
            if dispatch:
                self.dispatch('command_completion', ctx)
        finally:
            self.command_running.pop(ctx.message.id, None)

    async def invoke(self, ctx: StellaContext, **flags) -> None:
        """Invoke the command in ``ctx``.

        Flags:
            in_task (True): run the command in a background task stored in
                ``command_running`` under the message id; otherwise await it directly.
            dispatch (True): dispatch events / error handlers.
            redirect_error (False): re-raise errors instead of only dispatching them.
        Any remaining flags are forwarded to ``running_command``.
        """
        dispatch = flags.get("dispatch", True)
        if ctx.command is not None:
            run_in_task = flags.pop("in_task", True)
            if run_in_task:
                command_task = self.loop.create_task(self.running_command(ctx, **flags))
                self.command_running.update({ctx.message.id: command_task})
            else:
                await self.running_command(ctx, **flags)
        elif ctx.invoked_with:
            exc = commands.CommandNotFound('Command "{}" is not found'.format(ctx.invoked_with))
            if dispatch:
                self.dispatch('command_error', ctx, exc)
            if flags.pop("redirect_error", False):
                raise exc

    def sync_is_owner(self, user: discord.User) -> bool:
        """Synchronous owner check against ``owner_ids``."""
        return user.id in self.owner_ids

    @property
    def stella(self) -> Optional[discord.User]:
        """Returns discord.User of the owner"""
        return self.get_user(self._stella_id)

    @property
    def error_channel(self) -> Optional[discord.TextChannel]:
        """Gets the error channel for the bot to log, or None if the guild/channel is not cached."""
        guild = self.get_guild(self.bot_guild_id)
        if guild is None:
            return None
        return guild.get_channel(self.error_channel_id)

    async def setup_hook(self) -> None:
        """Startup hook: create the API token and GitHub client, run post-db setup, schedule after_ready."""
        await self.stella_api.generate_token()
        self.git = GitHub(self.git_token)  # github uses aiohttp in init, need to put in async context
        await self.after_db()
        # Keep a reference so the background task cannot be garbage-collected while pending.
        self._after_ready_task = self.loop.create_task(self.after_ready())

    async def after_ready(self):
        """Once ready: register the persistent view (not in tester mode) and connect to the IPC server."""
        await self.wait_until_ready()
        if not self.tester:
            self.add_view(PersistentRespondView(self))
        await self.greet_server()

    async def greet_server(self):
        """Subscribe to the IPC server and report how long the last restart took, if recorded."""
        # NOTE(review): presumably registers the bot's user id with the IPC client — confirm.
        self.ipc_client(self.user.id)
        try:
            await self.ipc_client.subscribe()
        except Exception as e:
            print_exception("Failure to connect to server.", e)
        else:
            if data := await self.ipc_client.request("get_restart_data"):
                if (channel := self.get_channel(data["channel_id"])) and isinstance(channel, discord.abc.Messageable):
                    message = await channel.fetch_message(data["message_id"])
                    message_time = discord.utils.utcnow() - message.created_at
                    time_taken = humanize.precisedelta(message_time)
                    await message.edit(content=f"Restart lasted {time_taken}")
            print("Server connected.")

    @to_call.append
    async def loading_cog(self) -> None:
        """Loads every cog in ``cogs/`` (skipping names starting with "_" or "."), then jishaku."""
        exclude = "_", "."
        cogs = [file for file in os.listdir("cogs") if not file.startswith(exclude)]
        for cog in cogs:
            name = cog[:-3] if cog.endswith(".py") else cog
            try:
                await self.load_extension(f"cogs.{name}")
            except Exception as e:
                print_exception('Ignoring exception while loading up {}:'.format(name), e)
            else:
                print(f"cog {name} is loaded")
        await self.load_extension("jishaku")

    @to_call.append
    async def fill_bots(self) -> None:
        """Fills ``pending_bots`` and ``confirmed_bots`` from their database tables."""
        for attr in "pending", "confirmed":
            record = await self.pool_pg.fetch(f"SELECT bot_id FROM {attr}_bots")
            setattr(self, f"{attr}_bots", {x["bot_id"] for x in record})
        print("Bots list are now filled.")

    @to_call.append
    async def fill_blacklist(self) -> None:
        """Loading up the blacklisted ids from the database into ``blacklist``."""
        records = await self.pool_pg.fetch("SELECT snowflake_id FROM blacklist")
        self.blacklist = {r["snowflake_id"] for r in records}

    async def get_prefix(self, message: discord.Message) -> Union[List[str], str]:
        """Return the command prefix for ``message``.

        In tester mode the tester prefix is always returned. Otherwise the prefix is looked
        up by guild id (author id when the message has no guild): first in the in-memory
        ``existing_prefix`` cache, then in the ``internal_prefix`` table, falling back to the
        default prefix; the result is stored back into the cache.

        The prefix is matched case-insensitively against the start of the message. On a
        match the text as the user typed it is returned (so casing differences still
        invoke commands); otherwise the stored prefix is returned.
        """
        if self.tester:
            return self._tester_prefix
        snowflake_id = message.guild.id if message.guild else message.author.id
        if (prefix := self.existing_prefix.get(snowflake_id)) is None:
            data = await self.pool_pg.fetchrow(
                "SELECT prefix FROM internal_prefix WHERE snowflake_id=$1",
                snowflake_id,
            )
            prefix = self._default_prefix if data is None else data["prefix"]
            self.existing_prefix[snowflake_id] = prefix
        if match := re.match(re.escape(prefix), message.content, flags=re.I):
            return match[0]
        return prefix

    def get_message(self, message_id: int) -> discord.Message:
        """Gets the message from the cache"""
        return self._connection._get_message(message_id)

    async def get_context(self, message: Union[discord.Message, discord.Interaction], *,
                          cls: Optional[commands.Context] = StellaContext) -> Union[StellaContext, commands.Context]:
        """Override get_context to use a custom Context and refresh its view values."""
        context = await super().get_context(message, cls=cls)
        context.view.update_values()
        return context

    async def process_commands(self, message: discord.Message) -> None:
        """Ignore bots; for valid non-Jishaku commands start typing before invoking."""
        if message.author.bot:
            return
        ctx = await self.get_context(message)
        if ctx.valid and getattr(ctx.cog, "qualified_name", None) != "Jishaku":
            await ctx.typing()
        await self.invoke(ctx)

    async def upload_file(self, *, byte: bytes, filename: str, retries: int = 4) -> StellaFile:
        """Upload ``byte`` as ``filename`` through ``stella_api``."""
        return await self.stella_api.upload_file(file=byte, filename=filename, retries=retries)

    async def main(self) -> None:
        """Starts the bot properly: connect the database pool, then run the bot until it closes."""
        try:
            print("Connecting to database...")
            start = time.time()
            pool_pg = await asyncpg.create_pool(
                database=self.db,
                user=self.user_db,
                password=self.pass_db
            )
            print(f"Connected to the database ({time.time() - start})s")
        except Exception as e:
            print_exception("Could not connect to database:", e)
            return

        async with self, pool_pg:
            self.uptime = datetime.datetime.utcnow()
            self.pool_pg = pool_pg
            await self.start(self.token)

    def starter(self):
        """Run ``main`` with ``asyncio.run``; a KeyboardInterrupt ends it quietly."""
        with contextlib.suppress(KeyboardInterrupt):
            asyncio.run(self.main())

    async def close(self) -> None:
        """Close the bot, then ``stella_api`` — the latter even if the former raises."""
        try:
            await super().close()
        finally:
            await self.stella_api.close()
# Gateway intents enabled for the bot.
intent_data = {x: True for x in ('guilds', 'members', 'emojis', 'messages', 'reactions', 'message_content')}
intents = discord.Intents(**intent_data)
# Bot settings (token, database credentials, IPC settings, ...) come from this JSON file.
# JSON text is UTF-8, so decode it explicitly instead of relying on the platform's locale encoding.
# NOTE(review): the path is relative to the current working directory — confirm the bot is
# always started from the project root.
with open("d_json/bot_var.json", encoding="utf-8") as states_bytes:
    states = json.load(states_bytes)
bot_data = {
    "token": states.get("TOKEN"),
    "default_prefix": states.get("DEFAULT_PREFIX", "uwu "),
    "tester_prefix": states.get("TESTER_PREFIX", "?uwu "),
    "bot_guild": states.get("BOT_GUILD"),
    "error_channel": states.get("ERROR_CHANNEL"),
    "color": 0xffcccb,
    "db": states.get("DATABASE"),
    "user_db": states.get("USER"),
    "pass_db": states.get("PASSWORD"),
    "tester": states.get("TEST"),
    "help_src": states.get("HELP_SRC"),
    "ipc_port": states.get("IPC_PORT"),
    "ipc_key": states.get("IPC_KEY"),
    "intents": intents,
    "owner_ids": states.get("OWNER_IDS"),
    "websocket_ip": states.get("WEBSOCKET_IP"),
    "prefix_weights": states.get("PREFIX_WEIGHT"),
    "prefix_derivative": states.get("PREFIX_DERIVATIVE_PATH"),
    "git_token": states.get("GIT_TOKEN"),
    "activity": discord.Activity(type=discord.ActivityType.listening, name="logged to my pc."),
    # NOTE(review): the "{}" placeholder is not formatted here; presumably it is filled in
    # elsewhere (e.g. by the help command) — confirm.
    "description": "{}'s personal bot that is partially for the public. "
                   f"Written with only `{count_source_lines('.'):,}` lines. plz be nice"
}
bot = StellaBot(**bot_data)
@bot.event
async def on_connect() -> None:
    """Print a notice when the ``on_connect`` event fires."""
    print("bot connected")


@bot.event
async def on_ready() -> None:
    """Print a notice when the ``on_ready`` event fires."""
    print("bot is ready")


@bot.event
async def on_disconnect() -> None:
    """Print a notice when the ``on_disconnect`` event fires."""
    print("bot disconnected")
@bot.event
@wait_ready(bot=bot)
@event_check(lambda m: not m.author.bot and not bot.tester or bot.sync_is_owner(m.author))
async def on_message(message: discord.Message) -> None:
    """Route an incoming message to the command system.

    The ``event_check`` predicate passes messages from owners, or from non-bot authors
    while the bot is not in tester mode (presumably the handler is skipped otherwise, and
    ``wait_ready`` defers it until the bot is ready — confirm). Then:

    * a message consisting only of a mention of the bot is answered with the prefix;
    * messages from a blacklisted user or guild are ignored;
    * for an owner's message with attachments that is not itself a valid command,
      ``text/x-python`` attachments are run through ``jsk py`` and ``text/plain``
      attachments are processed as a raw command message;
    * finally the message itself goes through ``bot.process_commands``.
    """
    if re.fullmatch(rf"<@!?{bot.user.id}>", message.content):
        await message.channel.send(f"My prefix is `{await bot.get_prefix(message)}`")
        return
    if message.author.id in bot.blacklist or getattr(message.guild, "id", None) in bot.blacklist:
        return
    if await bot.is_owner(message.author) and message.attachments:
        ctx = await bot.get_context(message)
        if ctx.valid:
            return await bot.invoke(ctx)
        text_command = ["text/plain", "text/x-python"]
        for a in message.attachments:
            # Discord reports text attachments with parameters (e.g. "text/plain; charset=utf-8")
            # and content_type may be None, so compare only the bare, lower-cased media type.
            media_type = (a.content_type or "").split(";", 1)[0].strip().lower()
            # ValueError covers both unsupported types (list.index) and UnicodeDecodeError
            # from decoding the file; such attachments are skipped.
            with contextlib.suppress(ValueError):
                index = text_command.index(media_type)
                attachment = await a.read()
                new_message = copy.copy(message)
                # Yes, i'm extremely lazy to get the command, and call the codeblock converter
                # Instead, i make a new message, and make it a command.
                if index:
                    # text/x-python: wrap the file in a jishaku python codeblock.
                    prefix = await bot.get_prefix(message)
                    new_message.content = f"{prefix}jsk py ```py\n{attachment.decode('utf-8')}```"
                else:
                    # text/plain: the file content is the command text itself.
                    new_message.content = attachment.decode('utf-8')
                await bot.process_commands(new_message)
    await bot.process_commands(message)
@bot.before_invoke
async def on_command_before_invoke(ctx: StellaContext):
    """Before-invoke hook: set ``ctx.running`` to True."""
    ctx.running = True


@bot.after_invoke
async def on_command_after_invoke(ctx: StellaContext):
    """After-invoke hook: set ``ctx.running`` back to False."""
    ctx.running = False


@bot.event
async def on_command(ctx: StellaContext):
    """Append every dispatched command context to ``bot.cached_context``."""
    cache = bot.cached_context
    cache.append(ctx)
# Start the bot only when this file is executed directly. Importing the module
# (e.g. to reach StellaBot) must not launch another asyncio.run of the bot.
if __name__ == "__main__":
    bot.starter()