Skip to content

Commit

Permalink
fix: EntitlementIterator behaviour and type-hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
yoggys committed Aug 19, 2024
1 parent 8a09b89 commit 71676ba
Showing 1 changed file with 52 additions and 18 deletions.
70 changes: 52 additions & 18 deletions discord/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .types.audit_log import AuditLog as AuditLogPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload
from .types.monetization import Entitlement as EntitlementPayload
from .types.threads import Thread as ThreadPayload
from .types.user import PartialUser as PartialUserPayload
from .user import User
Expand Down Expand Up @@ -988,11 +989,21 @@ def __init__(
self.guild_id = guild_id
self.exclude_ended = exclude_ended

self._filter = None

if self.before and self.after:
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
self._filter = lambda e: int(e["id"]) > self.after.id
elif self.after:
self._retrieve_entitlements = self._retrieve_entitlements_after_strategy
else:
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy

self.state = state
self.get_entitlements = state.http.list_entitlements
self.entitlements = asyncio.Queue()

async def next(self) -> BanEntry:
async def next(self) -> Entitlement:
if self.entitlements.empty():
await self.fill_entitlements()

Expand All @@ -1014,30 +1025,53 @@ async def fill_entitlements(self):
if not self._get_retrieve():
return

data = await self._retrieve_entitlements(self.retrieve)

if self._filter:
data = list(filter(self._filter, data))

if len(data) < 100:
self.limit = 0 # terminate loop

for element in data:
await self.entitlements.put(Entitlement(data=element, state=self.state))

async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]:
"""Retrieve entitlements and update next parameters."""
raise NotImplementedError

async def _retrieve_entitlements_before_strategy(self, retrieve: int) -> list[EntitlementPayload]:
"""Retrieve entitlements using before parameter."""
before = self.before.id if self.before else None
after = self.after.id if self.after else None
data = await self.get_entitlements(
self.state.application_id,
before=before,
after=after,
limit=self.retrieve,
limit=retrieve,
user_id=self.user_id,
guild_id=self.guild_id,
sku_ids=self.sku_ids,
exclude_ended=self.exclude_ended,
)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]["id"]))
return data

if not data:
# no data, terminate
return

if self.limit:
self.limit -= self.retrieve

if len(data) < 100:
self.limit = 0 # terminate loop

self.after = Object(id=int(data[-1]["id"]))

for element in reversed(data):
await self.entitlements.put(Entitlement(data=element, state=self.state))
async def _retrieve_entitlements_after_strategy(self, retrieve: int) -> list[EntitlementPayload]:
"""Retrieve entitlements using after parameter."""
after = self.after.id if self.after else None
data = await self.get_entitlements(
self.state.application_id,
after=after,
limit=retrieve,
user_id=self.user_id,
guild_id=self.guild_id,
sku_ids=self.sku_ids,
exclude_ended=self.exclude_ended,
)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[-1]["id"]))
return data

0 comments on commit 71676ba

Please sign in to comment.