diff --git a/packages/xrpc-server/src/server.ts b/packages/xrpc-server/src/server.ts index d3ca5d7ae19..b0bc8b36a34 100644 --- a/packages/xrpc-server/src/server.ts +++ b/packages/xrpc-server/src/server.ts @@ -67,7 +67,9 @@ export class Server { middleware: Record<'json' | 'text', RequestHandler> globalRateLimiters: RateLimiterI[] sharedRateLimiters: Record - routeRateLimiterFns: Record + // these two are treated separately because we do expensive schema validation after req ratelimits + basicRouteRateLimiterFns: Record // limits based on IP + paramRouteRateLimiterFns: Record // limits based on req context constructor(lexicons?: LexiconDoc[], opts?: Options) { if (lexicons) { @@ -86,7 +88,8 @@ export class Server { } this.globalRateLimiters = [] this.sharedRateLimiters = {} - this.routeRateLimiterFns = {} + this.basicRouteRateLimiterFns = {} + this.paramRouteRateLimiterFns = {} if (opts?.rateLimits?.global) { for (const limit of opts.rateLimits.global) { const rateLimiter = opts.rateLimits.creator({ @@ -250,12 +253,31 @@ export class Server { validateOutput(nsid, def, output, this.lex) const assertValidXrpcParams = (params: unknown) => this.lex.assertValidXrpcParams(nsid, params) - const rlFns = this.routeRateLimiterFns[nsid] ?? [] - const consumeRateLimit = (reqCtx: XRPCReqContext) => - consumeMany(reqCtx, rlFns) + const basicRlFns = this.basicRouteRateLimiterFns[nsid] ?? [] + const consumeBasicRateLimit = ( + req: express.Request, + res: express.Response, + ) => + consumeMany( + { req, res, auth: undefined, input: undefined, params: {} }, + basicRlFns, + ) + const paramRlFns = this.paramRouteRateLimiterFns[nsid] ?? [] + const consumeParamRateLimit = (reqCtx: XRPCReqContext) => + consumeMany(reqCtx, paramRlFns) return async function (req, res, next) { try { + const locals: RequestLocals = req[kRequestLocals] + + // handle req rate limits that don't use validated params + if (basicRlFns.length) { + const result = await consumeBasicRateLimit(req, res) + if (result instanceof RateLimitExceededError) { + return next(result) + } + } + // validate request let params = decodeQueryParams(def, req.query) try { @@ -265,8 +287,6 @@ export class Server { } const input = validateReqInput(req) - const locals: RequestLocals = req[kRequestLocals] - const reqCtx: XRPCReqContext = { params, input, @@ -276,9 +296,11 @@ export class Server { } // handle rate limits - const result = await consumeRateLimit(reqCtx) - if (result instanceof RateLimitExceededError) { - return next(result) + if (paramRlFns.length) { + const result = await consumeParamRateLimit(reqCtx) + if (result instanceof RateLimitExceededError) { + return next(result) + } } // run the handler @@ -422,19 +444,21 @@ export class Server { } private setupRouteRateLimits(nsid: string, config: XRPCHandlerConfig) { - this.routeRateLimiterFns[nsid] = [] + this.basicRouteRateLimiterFns[nsid] = [] + this.paramRouteRateLimiterFns[nsid] = [] for (const limit of this.globalRateLimiters) { const consumeFn = async (ctx: XRPCReqContext) => { return limit.consume(ctx) } - this.routeRateLimiterFns[nsid].push(consumeFn) + this.basicRouteRateLimiterFns[nsid].push(consumeFn) } if (config.rateLimit) { const limits = Array.isArray(config.rateLimit) ? config.rateLimit : [config.rateLimit] - this.routeRateLimiterFns[nsid] = [] + this.basicRouteRateLimiterFns[nsid] ??= [] + this.paramRouteRateLimiterFns[nsid] ??= [] for (let i = 0; i < limits.length; i++) { const limit = limits[i] const { calcKey, calcPoints } = limit @@ -446,7 +470,11 @@ export class Server { calcKey, calcPoints, }) - this.routeRateLimiterFns[nsid].push(consumeFn) + if (calcKey === undefined && calcPoints === undefined) { + this.basicRouteRateLimiterFns[nsid].push(consumeFn) + } else { + this.paramRouteRateLimiterFns[nsid].push(consumeFn) + } } } else { const { durationMs, points } = limit @@ -464,7 +492,11 @@ export class Server { calcKey, calcPoints, }) - this.routeRateLimiterFns[nsid].push(consumeFn) + if (calcKey === undefined && calcPoints === undefined) { + this.basicRouteRateLimiterFns[nsid].push(consumeFn) + } else { + this.paramRouteRateLimiterFns[nsid].push(consumeFn) + } } } }