Skip to content

Commit

Permalink
supporting multiple ids and filters together
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak Aggarwal <[email protected]>
  • Loading branch information
sarthakaggarwal97 committed Dec 11, 2024
1 parent 7630eba commit 87108ec
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 39 deletions.
118 changes: 94 additions & 24 deletions src/networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,39 @@
#include <math.h>
#include <ctype.h>
#include <stdatomic.h>
#include "intset.h"

/**
* This struct is used to encapsulate filtering criteria for operations on clients
* such as identifying specific clients to kill or retrieve. Each field in the struct
* represents a filter that can be applied based on specific attributes of a client.
*/
typedef struct {
/** A set of client IDs to filter. If NULL, no ID filtering is applied. */
intset *ids;
/** Maximum age (in seconds) of a client connection for filtering.
* Connections younger than this value will not match.
* A value of 0 means no age filtering. */
long long max_age;
/** Address/port of the client. If NULL, no address filtering is applied. */
char *addr;
/** Remote address/port of the client. If NULL, no address filtering is applied. */
char *laddr;
/** Filtering clients by authentication user. If NULL, no user-based filtering is applied. */
user *user;
/** Client type to filter. If set to -1, no type filtering is applied. */
int type;
/**< Boolean flag to determine if the current client (`me`) should be filtered. 1 means "skip me", 0 means otherwise. */
int skipme;
} clientFilter;

static void setProtocolError(const char *errstr, client *c);
static void pauseClientsByClient(mstime_t end, int isPauseClientAll);
int postponeClientRead(client *c);
char *getClientSockname(client *c);
int parseClientFilters(client *c, int i, clientFilter *filter);
bool clientMatchesFilter(client *client, clientFilter client_filter);
sds getAllFilteredClientsInfoString(clientFilter *client_filter, int hide_user_data);

int ProcessingEventsWhileBlocked = 0; /* See processEventsWhileBlocked(). */
__thread sds thread_shared_qb = NULL;
Expand Down Expand Up @@ -3426,7 +3454,7 @@ sds getAllClientsInfoString(int type, int hide_user_data) {
listNode *ln;
listIter li;
client *client;
sds o = sdsnewlen(SDS_NOINIT, 200 * listLength(server.clients));
sds o = sdsnewlen(SDS_NOINIT, 500);
sdsclear(o);
listRewind(server.clients, &li);
while ((ln = listNext(&li)) != NULL) {
Expand All @@ -3447,13 +3475,7 @@ sds getAllFilteredClientsInfoString(clientFilter *client_filter, int hide_user_d
listRewind(server.clients, &li);
while ((ln = listNext(&li)) != NULL) {
client = listNodeValue(ln);
if (client_filter->addr && strcmp(getClientPeerId(client), client_filter->addr) != 0) continue;
if (client_filter->laddr && strcmp(getClientSockname(client), client_filter->laddr) != 0) continue;
if (client_filter->type != -1 && getClientType(client) != client_filter->type) continue;
if (client_filter->id != 0 && client->id != client_filter->id) continue;
if (client_filter->user && client->user != client_filter->user) continue;
if (client_filter->skipme) continue;
if (client_filter->max_age != 0 && (long long)(commandTimeSnapshot() / 1000 - client->ctime) < client_filter->max_age) continue;
if(!clientMatchesFilter(client, *client_filter)) continue;
o = catClientInfoString(o, client, hide_user_data);
o = sdscatlen(o, "\n", 1);
}
Expand Down Expand Up @@ -3583,13 +3605,27 @@ int parseClientFilters(client *c, int i, clientFilter *filter) {
while (i < c->argc) {
int moreargs = c->argc > i + 1;

if (!strcasecmp(c->argv[i]->ptr, "id") && moreargs) {
long tmp;
if (!strcasecmp(c->argv[i]->ptr, "id")) {
if (filter->ids == NULL) {
filter->ids = intsetNew(); // Initialize the intset for IDs
}
i++; // Move to the first ID after "ID"

if (getRangeLongFromObjectOrReply(c, c->argv[i + 1], 1, LONG_MAX, &tmp,
"client-id should be greater than 0") != C_OK)
return C_ERR;
filter->id = tmp;
// Process all IDs until a non-numeric argument or end of args
while (i < c->argc) {
long long id;
if (!string2ll(c->argv[i]->ptr, sdslen(c->argv[i]->ptr), &id)) {
break; // Stop processing IDs if a non-numeric argument is encountered
}
if (id < 1) {
addReplyError(c, "client-id should be greater than 0");
return C_ERR;
}

uint8_t added;
filter->ids = intsetAdd(filter->ids, id, &added);
i++; // Move to the next argument
}
} else if (!strcasecmp(c->argv[i]->ptr, "maxage") && moreargs) {
long long tmp;

Expand All @@ -3602,22 +3638,27 @@ int parseClientFilters(client *c, int i, clientFilter *filter) {
}

filter->max_age = tmp;
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "type") && moreargs) {
filter->type = getClientTypeByName(c->argv[i + 1]->ptr);
if (filter->type == -1) {
addReplyErrorFormat(c, "Unknown client type '%s'", (char *)c->argv[i + 1]->ptr);
return C_ERR;
}
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "addr") && moreargs) {
filter->addr = c->argv[i + 1]->ptr;
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "laddr") && moreargs) {
filter->laddr = c->argv[i + 1]->ptr;
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "user") && moreargs) {
filter->user = ACLGetUserByName(c->argv[i + 1]->ptr, sdslen(c->argv[i + 1]->ptr));
if (filter->user == NULL) {
addReplyErrorFormat(c, "No such user '%s'", (char *)c->argv[i + 1]->ptr);
return C_ERR;
}
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "skipme") && moreargs) {
if (!strcasecmp(c->argv[i + 1]->ptr, "yes")) {
filter->skipme = 1;
Expand All @@ -3627,15 +3668,29 @@ int parseClientFilters(client *c, int i, clientFilter *filter) {
addReplyErrorObject(c, shared.syntaxerr);
return C_ERR;
}
i += 2;
} else {
addReplyErrorObject(c, shared.syntaxerr);
return C_ERR;
}
i += 2;
}
return C_OK;
}

bool clientMatchesFilter(client *client, clientFilter client_filter) {
// Check each filter condition and return false if the client does not match.
if (client_filter.addr && strcmp(getClientPeerId(client), client_filter.addr) != 0) return false;
if (client_filter.laddr && strcmp(getClientSockname(client), client_filter.laddr) != 0) return false;
if (client_filter.type != -1 && getClientType(client) != client_filter.type) return false;
if (client_filter.ids && !intsetFind(client_filter.ids, client->id)) return false;
if (client_filter.user && client->user != client_filter.user) return false;
if (client_filter.skipme && client == server.current_client) return false; // Skipme check
if (client_filter.max_age != 0 && (long long)(commandTimeSnapshot() / 1000 - client->ctime) < client_filter.max_age) return false;

// If all conditions are satisfied, the client matches the filter.
return true;
}

void clientCommand(client *c) {
listNode *ln;
listIter li;
Expand Down Expand Up @@ -3721,13 +3776,23 @@ void clientCommand(client *c) {
int type = -1;
sds o = NULL;
if (c->argc > 3) {
clientFilter client_filter = {0, 0, NULL, NULL, NULL, -1, 0};
clientFilter client_filter = {.ids = NULL,
.max_age = 0,
.addr = NULL,
.laddr = NULL,
.user = NULL,
.type = -1,
.skipme = 0
};

int i = 2;

if (parseClientFilters(c, i, &client_filter) != C_OK) {
zfree(client_filter.ids);
return;
}
o = getAllFilteredClientsInfoString(&client_filter, 0);
zfree(client_filter.ids);
} else if (c->argc != 2) {
addReplyErrorObject(c, shared.syntaxerr);
return;
Expand Down Expand Up @@ -3767,7 +3832,15 @@ void clientCommand(client *c) {
} else if (!strcasecmp(c->argv[1]->ptr, "kill")) {
/* CLIENT KILL <ip:port>
* CLIENT KILL <option> [value] ... <option> [value] */
clientFilter client_filter = {0, 0, NULL, NULL, NULL, -1, 1};
clientFilter client_filter = {.ids = NULL,
.max_age = 0,
.addr = NULL,
.laddr = NULL,
.user = NULL,
.type = -1,
.skipme = 1
};

int killed = 0, close_this_client = 0;

if (c->argc == 3) {
Expand All @@ -3779,9 +3852,11 @@ void clientCommand(client *c) {

/* New style syntax: parse options. */
if (parseClientFilters(c, i, &client_filter) != C_OK) {
zfree(client_filter.ids); // Free the intset on error
return;
}
} else {
zfree(client_filter.ids); // Free the intset on error
addReplyErrorObject(c, shared.syntaxerr);
return;
}
Expand All @@ -3790,13 +3865,7 @@ void clientCommand(client *c) {
listRewind(server.clients, &li);
while ((ln = listNext(&li)) != NULL) {
client *client = listNodeValue(ln);
if (client_filter.addr && strcmp(getClientPeerId(client), client_filter.addr) != 0) continue;
if (client_filter.laddr && strcmp(getClientSockname(client), client_filter.laddr) != 0) continue;
if (client_filter.type != -1 && getClientType(client) != client_filter.type) continue;
if (client_filter.id != 0 && client->id != client_filter.id) continue;
if (client_filter.user && client->user != client_filter.user) continue;
if (c == client && client_filter.skipme) continue;
if (client_filter.max_age != 0 && (long long)(commandTimeSnapshot() / 1000 - client->ctime) < client_filter.max_age) continue;
if (!clientMatchesFilter(client, client_filter)) continue;

/* Kill it. */
if (c == client) {
Expand All @@ -3820,6 +3889,7 @@ void clientCommand(client *c) {
/* If this client has to be closed, flag it as CLOSE_AFTER_REPLY
* only after we queued the reply to its output buffers. */
if (close_this_client) c->flag.close_after_reply = 1;
zfree(client_filter.ids);
} else if (!strcasecmp(c->argv[1]->ptr, "unblock") && (c->argc == 3 || c->argc == 4)) {
/* CLIENT UNBLOCK <id> [timeout|error] */
long long id;
Expand Down
12 changes: 0 additions & 12 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -1386,16 +1386,6 @@ typedef struct client {
net_output_bytes_curr_cmd; /* Total network output bytes sent to this client, by the current command. */
} client;

typedef struct {
uint64_t id;
long long max_age;
char *addr;
char *laddr;
user *user;
int type;
int skipme;
} clientFilter;

/* When a command generates a lot of discrete elements to the client output buffer, it is much faster to
* skip certain types of initialization. This type is used to indicate a client that has been initialized
* and can be used with addWritePreparedReply* functions. A client can be cast into this type with
Expand Down Expand Up @@ -2878,7 +2868,6 @@ int isClientConnIpV6(client *c);
sds catClientInfoString(sds s, client *client, int hide_user_data);
sds catClientInfoShortString(sds s, client *client, int hide_user_data);
sds getAllClientsInfoString(int type, int hide_user_data);
sds getAllFilteredClientsInfoString(clientFilter *client_filter, int hide_user_data);
int clientSetName(client *c, robj *name, const char **err);
void rewriteClientCommandVector(client *c, int argc, ...);
void rewriteClientCommandArgument(client *c, int i, robj *newval);
Expand Down Expand Up @@ -3957,7 +3946,6 @@ void dumpCommand(client *c);
void objectCommand(client *c);
void memoryCommand(client *c);
void clientCommand(client *c);
int parseClientFilters(client *c, int i, clientFilter *filter);
void helloCommand(client *c);
void clientSetinfoCommand(client *c);
void evalCommand(client *c);
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/introspection.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ start_server {tags {"introspection"}} {
}

test {CLIENT LIST with illegal arguments} {
assert_error "ERR syntax error*" {r client list id 10 wrong_arg}
assert_error "ERR syntax error" {r client list id 10 wrong_arg}

assert_error "ERR *greater than 0*" {r client list id str}
assert_error "ERR syntax error" {r client list id str}
assert_error "ERR *greater than 0*" {r client list id -1}
assert_error "ERR *greater than 0*" {r client list id 0}

Expand Down Expand Up @@ -188,7 +188,7 @@ start_server {tags {"introspection"}} {
assert_error "ERR wrong number of arguments for 'client|kill' command" {r client kill}
assert_error "ERR syntax error*" {r client kill id 10 wrong_arg}

assert_error "ERR *greater than 0*" {r client kill id str}
assert_error "ERR syntax error*" {r client kill id str}
assert_error "ERR *greater than 0*" {r client kill id -1}
assert_error "ERR *greater than 0*" {r client kill id 0}

Expand Down

0 comments on commit 87108ec

Please sign in to comment.