Skip to content

Commit

Permalink
supporting more filters for Client List and Client Kill
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak Aggarwal <[email protected]>
  • Loading branch information
sarthakaggarwal97 committed Dec 20, 2024
1 parent 52fd611 commit 6819a68
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 10 deletions.
224 changes: 217 additions & 7 deletions src/networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ typedef struct {
int type;
/* Boolean flag to determine if the current client (`me`) should be filtered. 1 means "skip me", 0 means otherwise. */
int skipme;
/* Client name to filter. If NULL, no name filtering is applied. */
char *name;
/* Minimum idle time (in seconds) of a client connection for filtering.
* Connections with idle time more than this value will match.
* A value of 0 means no idle time filtering. */
long long min_idle;
/* Client flags for filtering. If NULL, no filtering is applied. */
char *flags;
/* Client pattern for filtering. If NULL, no filtering is applied. */
robj *pattern;
/* Client channel for filtering. If NULL, no filtering is applied. */
robj *channel;
/* Client shard channel for filtering. If NULL, no filtering is applied. */
robj *shard_channel;
} clientFilter;

static void clientCommandHelp(client *c);
Expand Down Expand Up @@ -91,6 +105,11 @@ char *getClientSockname(client *c);
static int parseClientFiltersOrReply(client *c, int i, clientFilter *filter);
static int clientMatchesFilter(client *client, clientFilter client_filter);
sds getAllFilteredClientsInfoString(clientFilter *client_filter, int hide_user_data);
static int clientMatchesFlagFilter(client *c, const char *flag_filter);
static int clientSubscribedToChannel(client *client, robj *channel);
static int clientSubscribedToShardChannel(client *client, robj *channel);
static int clientSubscribedToPattern(client *client, robj *pattern);
static void freeClientFilter(clientFilter *filter);

int ProcessingEventsWhileBlocked = 0; /* See processEventsWhileBlocked(). */
__thread sds thread_shared_qb = NULL;
Expand Down Expand Up @@ -3687,6 +3706,34 @@ static int parseClientFiltersOrReply(client *c, int i, clientFilter *filter) {
return C_ERR;
}
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "minidle") && moreargs) {
long long tmp;

if (getLongLongFromObjectOrReply(c, c->argv[i + 1], &tmp,
"minidle is not an integer or out of range") != C_OK)
return C_ERR;
if (tmp <= 0) {
addReplyError(c, "minidle should be greater than 0");
return C_ERR;
}

filter->min_idle = tmp;
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "flags") && moreargs) {
filter->flags = c->argv[i + 1]->ptr;
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "name") && moreargs) {
filter->name = c->argv[i + 1]->ptr;
i += 2;
} else if (!strcasecmp(c->argv[i]->ptr, "pattern") && moreargs) {
filter->pattern = createObject(OBJ_STRING, sdsnew(c->argv[i + 1]->ptr));
i += 2;
}else if (!strcasecmp(c->argv[i]->ptr, "channel") && moreargs) {
filter->channel = createObject(OBJ_STRING, sdsnew(c->argv[i + 1]->ptr));
i += 2;
}else if (!strcasecmp(c->argv[i]->ptr, "shardchannel") && moreargs) {
filter->shard_channel = createObject(OBJ_STRING, sdsnew(c->argv[i + 1]->ptr));
i += 2;
} else {
addReplyErrorObject(c, shared.syntaxerr);
return C_ERR;
Expand All @@ -3704,11 +3751,126 @@ static int clientMatchesFilter(client *client, clientFilter client_filter) {
if (client_filter.user && client->user != client_filter.user) return 0;
if (client_filter.skipme && client == server.current_client) return 0; // Skipme check
if (client_filter.max_age != 0 && (long long)(commandTimeSnapshot() / 1000 - client->ctime) < client_filter.max_age) return 0;
if (client_filter.min_idle != 0 && (long long)(commandTimeSnapshot() / 1000 - client->last_interaction) < client_filter.min_idle) return 0;
if (client_filter.flags && clientMatchesFlagFilter(client, client_filter.flags) == 0) return 0;
if (client_filter.name) {
if (!client->name || !client->name->ptr || strcmp(client->name->ptr, client_filter.name) != 0) {
return 0;
}
}
if (client_filter.pattern && !clientSubscribedToPattern(client, client_filter.pattern)) return 0;
if (client_filter.channel && !clientSubscribedToChannel(client, client_filter.channel)) return 0;
if (client_filter.shard_channel && !clientSubscribedToShardChannel(client, client_filter.shard_channel)) return 0;

// If all conditions are satisfied, the client matches the filter.
return 1;
}

/* Function to check if the client has the required flags as per the filter string */
static int clientMatchesFlagFilter(client *c, const char *flag_filter) {
// Iterate through the provided flag filter string
for (int i = 0; flag_filter[i] != '\0'; i++) {
const char flag = flag_filter[i];

// Check each flag
switch (flag) {
case 'O':
if (!(c->flag.replica && c->flag.monitor)) return 0;
break;
case 'S': // Replica flag
if (!c->flag.replica) return 0;
break;
case 'M': // Primary flag
if (!c->flag.primary) return 0;
break;
case 'P': // PubSub flag
if (!c->flag.pubsub) return 0;
break;
case 'x': // Multi flag
if (!c->flag.multi) return 0;
break;
case 'b': // Blocked flag
if (!c->flag.blocked) return 0;
break;
case 't': // Tracking flag
if (!c->flag.tracking) return 0;
break;
case 'R': // Invalid Client flag
if (!c->flag.tracking_broken_redir) return 0;
break;
case 'B': // Tracking Bcast flag
if (!c->flag.tracking_bcast) return 0;
break;
case 'd': // Dirty CAS flag
if (!c->flag.dirty_cas) return 0;
break;
case 'c': // Close after reply flag
if (!c->flag.close_after_reply) return 0;
break;
case 'u': // Unblocked flag
if (!c->flag.unblocked) return 0;
break;
case 'A': // Close ASAP flag
if (!c->flag.close_asap) return 0;
break;
case 'U': // Unix socket flag
if (!c->flag.unix_socket) return 0;
break;
case 'r': // Readonly flag
if (!c->flag.readonly) return 0;
break;
case 'e': // No evict flag
if (!c->flag.no_evict) return 0;
break;
case 'T': // No touch flag
if (!c->flag.no_touch) return 0;
break;
case 'I': // Import source flag
if (!c->flag.import_source) return 0;
break;
case 'N': // Check for no flags
if (!c->flag.replica && !c->flag.primary && !c->flag.pubsub &&
!c->flag.multi && !c->flag.blocked && !c->flag.tracking &&
!c->flag.tracking_broken_redir && !c->flag.tracking_bcast &&
!c->flag.dirty_cas && !c->flag.close_after_reply &&
!c->flag.unblocked && !c->flag.close_asap &&
!c->flag.unix_socket && !c->flag.readonly &&
!c->flag.no_evict && !c->flag.no_touch &&
!c->flag.import_source) {
return 1; // Matches 'N'
}
break;
default:
// Invalid flag, return false
return 0;
}
}
// If the loop completes, the client matches the flag filter
return 1;
}

static int clientSubscribedToChannel(client *client, robj *channel) {
if (client == NULL || client->pubsub_channels == NULL) {
return 0;
}
return dictFind(client->pubsub_channels, channel) != NULL;
}

static int clientSubscribedToShardChannel(client *client, robj *channel) {
if (client == NULL || client->pubsubshard_channels == NULL) {
return 0;
}
return dictFind(client->pubsubshard_channels, channel) != NULL;
}

static int clientSubscribedToPattern(client *client, robj *pattern) {
if (client == NULL || client->pubsub_patterns == NULL) {
return 0;
}
return dictFind(client->pubsub_patterns, pattern) != NULL;
}


static void clientCommandHelp(client *c) {
const char *help[] = {
"CACHING (YES|NO)",
Expand All @@ -3730,23 +3892,55 @@ static void clientCommandHelp(client *c) {
"KILL <option> <value> [<option> <value> [...]]",
" Kill connections. Options are:",
" * ADDR (<ip:port>|<unixsocket>:0)",
" Kill connections made from the specified address",
" Kill connections made from the specified address.",
" * LADDR (<ip:port>|<unixsocket>:0)",
" Kill connections made to specified local address",
" Kill connections made to the specified local address.",
" * TYPE (NORMAL|PRIMARY|REPLICA|PUBSUB)",
" Kill connections by type.",
" * USER <username>",
" Kill connections authenticated by <username>.",
" * SKIPME (YES|NO)",
" Skip killing current connection (default: yes).",
" Skip killing the current connection (default: yes).",
" * ID <client-id>",
" Kill connections by client id.",
" Kill connections by client ID.",
" * MAXAGE <maxage>",
" Kill connections older than the specified age.",
" * FLAGS <flags>",
" Kill connections that include the specified flags.",
" * NAME <client-name>",
" Kill connections with the specified name.",
" * PATTERN <pattern>",
" Kill connections subscribed to a matching pattern.",
" * CHANNEL <channel>",
" Kill connections subscribed to a matching channel.",
" * SHARD-CHANNEL <shard-channel>",
" Kill connections subscribed to a matching shard channel.",
"LIST [options ...]",
" Return information about client connections. Options:",
" * TYPE (NORMAL|PRIMARY|REPLICA|PUBSUB)",
" Return clients of specified type.",
" * USER <username>",
" Return clients authenticated by <username>.",
" * ADDR <ip:port>",
" Return clients connected from the specified address.",
" * LADDR <ip:port>",
" Return clients connected to the specified local address.",
" * ID <client-id>",
" Return clients with the specified IDs.",
" * SKIPME (YES|NO)",
" Exclude the current client from the list (default: no).",
" * FLAGS <flags>",
" Return clients with the specified flags.",
" * NAME <client-name>",
" Return clients with the specified name.",
" * MIN-IDLE <min-idle>",
" Return clients with idle time greater than or equal to <min-idle> seconds.",
" * PATTERN <pattern>",
" Return clients subscribed to a matching pattern.",
" * CHANNEL <channel>",
" Return clients subscribed to the specified channel.",
" * SHARD-CHANNEL <shard-channel>",
" Return clients subscribed to the specified shard channel.",
"UNPAUSE",
" Stop the current client pause, resuming traffic.",
"PAUSE <timeout> [WRITE|ALL]",
Expand Down Expand Up @@ -3798,11 +3992,11 @@ static void clientCommandList(client *c) {
int i = 2;

if (parseClientFiltersOrReply(c, i, &filter) != C_OK) {
zfree(filter.ids);
freeClientFilter(&filter);
return;
}
response = getAllFilteredClientsInfoString(&filter, 0);
zfree(filter.ids);
freeClientFilter(&filter);
} else if (c->argc != 2) {
addReplyErrorObject(c, shared.syntaxerr);
return;
Expand Down Expand Up @@ -3909,7 +4103,23 @@ static void clientCommandKill(client *c) {
* only after we queued the reply to its output buffers. */
if (close_this_client) c->flag.close_after_reply = 1;
client_kill_done:
zfree(client_filter.ids);
freeClientFilter(&client_filter);
}

static void freeClientFilter(clientFilter *filter) {
zfree(filter->ids);
if (filter->pattern) {
decrRefCount(filter->pattern);
filter->pattern = NULL;
}
if (filter->shard_channel) {
decrRefCount(filter->shard_channel);
filter->shard_channel = NULL;
}
if (filter->channel) {
decrRefCount(filter->channel);
filter->channel = NULL;
}
}


Expand Down
Loading

0 comments on commit 6819a68

Please sign in to comment.