1
0
mirror of https://github.com/spacebarchat/server.git synced 2024-11-11 05:02:37 +01:00

fix rate limit

This commit is contained in:
Flam3rboy 2021-08-29 16:58:23 +02:00
parent 38162b1c20
commit 7674149085
2 changed files with 52 additions and 48 deletions

View File

@ -4,7 +4,7 @@ import { FieldError } from "../util/instanceOf";
// TODO: update with new body/typorm validation // TODO: update with new body/typorm validation
export function ErrorHandler(error: Error, req: Request, res: Response, next: NextFunction) { export function ErrorHandler(error: Error, req: Request, res: Response, next: NextFunction) {
if (!error) next(); if (!error) return next();
try { try {
let code = 400; let code = 400;
@ -18,7 +18,6 @@ export function ErrorHandler(error: Error, req: Request, res: Response, next: Ne
message = error.message; message = error.message;
errors = error.errors; errors = error.errors;
} else { } else {
console.error(error);
if (req.server?.options?.production) { if (req.server?.options?.production) {
message = "Internal Server Error"; message = "Internal Server Error";
} }
@ -27,7 +26,7 @@ export function ErrorHandler(error: Error, req: Request, res: Response, next: Ne
if (httpcode > 511) httpcode = 400; if (httpcode > 511) httpcode = 400;
console.error(`[Error] ${code} ${req.url} ${message}`, errors || error); console.error(`[Error] ${code} ${req.url}`, errors || error, "body:", req.body);
res.status(httpcode).json({ code: code, message, errors }); res.status(httpcode).json({ code: code, message, errors });
} catch (error) { } catch (error) {

View File

@ -1,6 +1,6 @@
// @ts-nocheck import { Config, listenEvent, emitEvent, RateLimit } from "@fosscord/util";
import { db, Bucket, Config, listenEvent, emitEvent } from "@fosscord/util";
import { NextFunction, Request, Response, Router } from "express"; import { NextFunction, Request, Response, Router } from "express";
import { LessThan } from "typeorm";
import { getIpAdress } from "../util/ipAddress"; import { getIpAdress } from "../util/ipAddress";
import { API_PREFIX_TRAILING_SLASH } from "./Authentication"; import { API_PREFIX_TRAILING_SLASH } from "./Authentication";
@ -18,10 +18,10 @@ TODO: different for methods (GET/POST)
*/ */
var Cache = new Map<string, Bucket>(); var Cache = new Map<string, RateLimit>();
const EventRateLimit = "ratelimit"; const EventRateLimit = "RATELIMIT";
export default function RateLimit(opts: { export default function rateLimit(opts: {
bucket?: string; bucket?: string;
window: number; window: number;
count: number; count: number;
@ -36,15 +36,15 @@ export default function RateLimit(opts: {
}): any { }): any {
return async (req: Request, res: Response, next: NextFunction): Promise<any> => { return async (req: Request, res: Response, next: NextFunction): Promise<any> => {
const bucket_id = opts.bucket || req.originalUrl.replace(API_PREFIX_TRAILING_SLASH, ""); const bucket_id = opts.bucket || req.originalUrl.replace(API_PREFIX_TRAILING_SLASH, "");
var user_id = getIpAdress(req); var executor_id = getIpAdress(req);
if (!opts.onlyIp && req.user_id) user_id = req.user_id; if (!opts.onlyIp && req.user_id) executor_id = req.user_id;
var max_hits = opts.count; var max_hits = opts.count;
if (opts.bot && req.user_bot) max_hits = opts.bot; if (opts.bot && req.user_bot) max_hits = opts.bot;
if (opts.GET && ["GET", "OPTIONS", "HEAD"].includes(req.method)) max_hits = opts.GET; if (opts.GET && ["GET", "OPTIONS", "HEAD"].includes(req.method)) max_hits = opts.GET;
else if (opts.MODIFY && ["POST", "DELETE", "PATCH", "PUT"].includes(req.method)) max_hits = opts.MODIFY; else if (opts.MODIFY && ["POST", "DELETE", "PATCH", "PUT"].includes(req.method)) max_hits = opts.MODIFY;
const offender = Cache.get(user_id + bucket_id) as Bucket | null; const offender = Cache.get(executor_id + bucket_id);
if (offender && offender.blocked) { if (offender && offender.blocked) {
const reset = offender.expires_at.getTime(); const reset = offender.expires_at.getTime();
@ -72,12 +72,12 @@ export default function RateLimit(opts: {
offender.expires_at = new Date(Date.now() + opts.window * 1000); offender.expires_at = new Date(Date.now() + opts.window * 1000);
offender.blocked = false; offender.blocked = false;
// mongodb ttl didn't update yet -> manually update/delete // mongodb ttl didn't update yet -> manually update/delete
db.collection("ratelimits").update({ id: bucket_id, user_id }, { $set: offender }); RateLimit.delete({ id: bucket_id, executor_id });
Cache.delete(user_id + bucket_id); Cache.delete(executor_id + bucket_id);
} }
} }
next(); next();
const hitRouteOpts = { bucket_id, user_id, max_hits, window: opts.window }; const hitRouteOpts = { bucket_id, executor_id, max_hits, window: opts.window };
if (opts.error || opts.success) { if (opts.error || opts.success) {
res.once("finish", () => { res.once("finish", () => {
@ -97,69 +97,74 @@ export default function RateLimit(opts: {
export async function initRateLimits(app: Router) { export async function initRateLimits(app: Router) {
const { routes, global, ip, error } = Config.get().limits.rate; const { routes, global, ip, error } = Config.get().limits.rate;
await listenEvent(EventRateLimit, (event) => { await listenEvent(EventRateLimit, (event) => {
Cache.set(event.channel_id, event.data); Cache.set(event.channel_id as string, event.data);
event.acknowledge?.(); event.acknowledge?.();
}); });
await RateLimit.delete({ expires_at: LessThan(new Date()) }); // clean up if not already deleted
const limits = await RateLimit.find({ blocked: true });
limits.forEach((limit) => {
Cache.set(limit.executor_id, limit);
});
setInterval(() => { setInterval(() => {
Cache.forEach((x, key) => { Cache.forEach((x, key) => {
if (Date.now() > x.expires_at) Cache.delete(key); if (new Date() > x.expires_at) {
Cache.delete(key);
RateLimit.delete({ executor_id: key });
}
}); });
}, 1000 * 60 * 10); }, 1000 * 60 * 10);
app.use( app.use(
RateLimit({ rateLimit({
bucket: "global", bucket: "global",
onlyIp: true, onlyIp: true,
...ip ...ip
}) })
); );
app.use(RateLimit({ bucket: "global", ...global })); app.use(rateLimit({ bucket: "global", ...global }));
app.use( app.use(
RateLimit({ rateLimit({
bucket: "error", bucket: "error",
error: true, error: true,
onlyIp: true, onlyIp: true,
...error ...error
}) })
); );
app.use("/guilds/:id", RateLimit(routes.guild)); app.use("/guilds/:id", rateLimit(routes.guild));
app.use("/webhooks/:id", RateLimit(routes.webhook)); app.use("/webhooks/:id", rateLimit(routes.webhook));
app.use("/channels/:id", RateLimit(routes.channel)); app.use("/channels/:id", rateLimit(routes.channel));
app.use("/auth/login", RateLimit(routes.auth.login)); app.use("/auth/login", rateLimit(routes.auth.login));
app.use("/auth/register", RateLimit({ onlyIp: true, success: true, ...routes.auth.register })); app.use("/auth/register", rateLimit({ onlyIp: true, success: true, ...routes.auth.register }));
} }
async function hitRoute(opts: { user_id: string; bucket_id: string; max_hits: number; window: number }) { async function hitRoute(opts: { executor_id: string; bucket_id: string; max_hits: number; window: number }) {
const filter = { id: opts.bucket_id, user_id: opts.user_id }; var ratelimit = await RateLimit.findOne({ id: opts.bucket_id, executor_id: opts.executor_id });
const { value } = await db.collection("ratelimits").findOneOrFailAndUpdate( if (!ratelimit) {
filter, ratelimit = new RateLimit({
{ id: opts.bucket_id,
$setOnInsert: { executor_id: opts.executor_id,
id: opts.bucket_id, expires_at: new Date(Date.now() + opts.window * 1000),
user_id: opts.user_id, hits: 0,
expires_at: new Date(Date.now() + opts.window * 1000) blocked: false
}, });
$inc: { }
hits: 1
} ratelimit.hits++;
// Conditionally update blocked doesn't work
}, const updateBlock = !ratelimit.blocked && ratelimit.hits >= opts.max_hits;
{ upsert: true, returnDocument: "before" }
);
if (!value) return;
const updateBlock = !value.blocked && value.hits >= opts.max_hits;
if (updateBlock) { if (updateBlock) {
value.blocked = true; ratelimit.blocked = true;
Cache.set(opts.user_id + opts.bucket_id, value); Cache.set(opts.executor_id + opts.bucket_id, ratelimit);
await emitEvent({ await emitEvent({
channel_id: EventRateLimit, channel_id: EventRateLimit,
event: EventRateLimit, event: EventRateLimit,
data: value data: ratelimit
}); });
await db.collection("ratelimits").update(filter, { $set: { blocked: true } });
} else { } else {
Cache.delete(opts.user_id); Cache.delete(opts.executor_id);
} }
await ratelimit.save();
} }