From f1f9e51cf689fb94fec629b9081f772de4659966 Mon Sep 17 00:00:00 2001 From: Vijay Nirmal Date: Tue, 17 Dec 2024 05:09:13 +0530 Subject: [PATCH] [Compatibility] Added ZINTER, ZINTERCARD, ZINTERSTORE command (#831) * Added ZINTER, ZINTERCARD, ZINTERSTORE command * Format fix * Test fix * Added comments and docs * Fix magic string * Review commands * Fixed review comments --------- Co-authored-by: Tal Zaccai --- libs/resources/RespCommandsDocs.json | 168 ++++++++++ libs/resources/RespCommandsInfo.json | 82 +++++ libs/server/API/GarnetApiObjectCommands.cs | 12 + libs/server/API/GarnetWatchApi.cs | 20 ++ libs/server/API/IGarnetApi.cs | 30 ++ libs/server/Resp/CmdStrings.cs | 7 +- libs/server/Resp/Objects/SortedSetCommands.cs | 277 +++++++++++++++- libs/server/Resp/Parser/RespCommand.cs | 15 + libs/server/Resp/RespServerSession.cs | 3 + libs/server/SessionParseStateExtensions.cs | 23 ++ libs/server/SortedSetAggregateType.cs | 24 ++ .../Session/ObjectStore/SortedSetOps.cs | 197 +++++++++++ .../CommandInfoUpdater/SupportedCommand.cs | 3 + .../RedirectTests/BaseCommand.cs | 87 +++++ .../ClusterSlotVerificationTests.cs | 21 ++ test/Garnet.test/Resp/ACL/RespCommandTests.cs | 45 +++ test/Garnet.test/RespSortedSetTests.cs | 312 ++++++++++++++++++ website/docs/commands/api-compatibility.md | 6 +- website/docs/commands/data-structures.md | 38 ++- 19 files changed, 1364 insertions(+), 6 deletions(-) create mode 100644 libs/server/SortedSetAggregateType.cs diff --git a/libs/resources/RespCommandsDocs.json b/libs/resources/RespCommandsDocs.json index 10fc7b6e57..d9e9fa0b4f 100644 --- a/libs/resources/RespCommandsDocs.json +++ b/libs/resources/RespCommandsDocs.json @@ -5817,6 +5817,174 @@ } ] }, + { + "Command": "ZINTER", + "Name": "ZINTER", + "Summary": "Returns the intersect of multiple sorted sets.", + "Group": "SortedSet", + "Complexity": "O(N*K)\u002BO(M*log(M)) worst case with N being the smallest input sorted set, K being the number of input sorted sets and M being the number of elements in the resulting sorted set.", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "NUMKEYS", + "DisplayText": "numkeys", + "Type": "Integer" + }, + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "ArgumentFlags": "Multiple", + "KeySpecIndex": 0 + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "WEIGHT", + "DisplayText": "weight", + "Type": "Integer", + "Token": "WEIGHTS", + "ArgumentFlags": "Optional, Multiple" + }, + { + "TypeDiscriminator": "RespCommandContainerArgument", + "Name": "AGGREGATE", + "Type": "OneOf", + "Token": "AGGREGATE", + "ArgumentFlags": "Optional", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "SUM", + "DisplayText": "sum", + "Type": "PureToken", + "Token": "SUM" + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "MIN", + "DisplayText": "min", + "Type": "PureToken", + "Token": "MIN" + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "MAX", + "DisplayText": "max", + "Type": "PureToken", + "Token": "MAX" + } + ] + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "WITHSCORES", + "DisplayText": "withscores", + "Type": "PureToken", + "Token": "WITHSCORES", + "ArgumentFlags": "Optional" + } + ] + }, + { + "Command": "ZINTERCARD", + "Name": "ZINTERCARD", + "Summary": "Returns the number of members of the intersect of multiple sorted sets.", + "Group": "SortedSet", + "Complexity": "O(N*K) worst case with N being the smallest input sorted set, K being the number of input sorted sets.", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "NUMKEYS", + "DisplayText": "numkeys", + "Type": "Integer" + }, + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "ArgumentFlags": "Multiple", + "KeySpecIndex": 0 + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "LIMIT", + "DisplayText": "limit", + "Type": "Integer", + "Token": "LIMIT", + "ArgumentFlags": "Optional" + } + ] + }, + { + "Command": "ZINTERSTORE", + "Name": "ZINTERSTORE", + "Summary": "Stores the intersect of multiple sorted sets in a key.", + "Group": "SortedSet", + "Complexity": "O(N*K)\u002BO(M*log(M)) worst case with N being the smallest input sorted set, K being the number of input sorted sets and M being the number of elements in the resulting sorted set.", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "DESTINATION", + "DisplayText": "destination", + "Type": "Key", + "KeySpecIndex": 0 + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "NUMKEYS", + "DisplayText": "numkeys", + "Type": "Integer" + }, + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "ArgumentFlags": "Multiple", + "KeySpecIndex": 1 + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "WEIGHT", + "DisplayText": "weight", + "Type": "Integer", + "Token": "WEIGHTS", + "ArgumentFlags": "Optional, Multiple" + }, + { + "TypeDiscriminator": "RespCommandContainerArgument", + "Name": "AGGREGATE", + "Type": "OneOf", + "Token": "AGGREGATE", + "ArgumentFlags": "Optional", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "SUM", + "DisplayText": "sum", + "Type": "PureToken", + "Token": "SUM" + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "MIN", + "DisplayText": "min", + "Type": "PureToken", + "Token": "MIN" + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "MAX", + "DisplayText": "max", + "Type": "PureToken", + "Token": "MAX" + } + ] + } + ] + }, { "Command": "ZLEXCOUNT", "Name": "ZLEXCOUNT", diff --git a/libs/resources/RespCommandsInfo.json b/libs/resources/RespCommandsInfo.json index 4c06e3e447..15307b1e9d 100644 --- a/libs/resources/RespCommandsInfo.json +++ b/libs/resources/RespCommandsInfo.json @@ -4284,6 +4284,88 @@ } ] }, + { + "Command": "ZINTER", + "Name": "ZINTER", + "Arity": -3, + "Flags": "MovableKeys, ReadOnly", + "AclCategories": "Read, SortedSet, Slow", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysKeyNum", + "KeyNumIdx": 0, + "FirstKey": 1, + "KeyStep": 1 + }, + "Flags": "RO, Access" + } + ] + }, + { + "Command": "ZINTERCARD", + "Name": "ZINTERCARD", + "Arity": -3, + "Flags": "MovableKeys, ReadOnly", + "AclCategories": "Read, SortedSet, Slow", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysKeyNum", + "KeyNumIdx": 0, + "FirstKey": 1, + "KeyStep": 1 + }, + "Flags": "RO, Access" + } + ] + }, + { + "Command": "ZINTERSTORE", + "Name": "ZINTERSTORE", + "Arity": -4, + "Flags": "DenyOom, MovableKeys, Write", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "SortedSet, Slow, Write", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "OW, Update" + }, + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 2 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysKeyNum", + "KeyNumIdx": 0, + "FirstKey": 1, + "KeyStep": 1 + }, + "Flags": "RO, Access" + } + ] + }, { "Command": "ZLEXCOUNT", "Name": "ZLEXCOUNT", diff --git a/libs/server/API/GarnetApiObjectCommands.cs b/libs/server/API/GarnetApiObjectCommands.cs index 670db210ee..8fd44473dd 100644 --- a/libs/server/API/GarnetApiObjectCommands.cs +++ b/libs/server/API/GarnetApiObjectCommands.cs @@ -146,6 +146,18 @@ public GarnetStatus SortedSetDifferenceStore(ArgSlice destinationKey, ReadOnlySp public GarnetStatus SortedSetScan(ArgSlice key, long cursor, string match, int count, out ArgSlice[] items) => storageSession.ObjectScan(GarnetObjectType.SortedSet, key, cursor, match, count, out items, ref objectContext); + /// + public GarnetStatus SortedSetIntersect(ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out Dictionary pairs) + => storageSession.SortedSetIntersect(keys, weights, aggregateType, out pairs); + + /// + public GarnetStatus SortedSetIntersectLength(ReadOnlySpan keys, int? limit, out int count) + => storageSession.SortedSetIntersectLength(keys, limit, out count); + + /// + public GarnetStatus SortedSetIntersectStore(ArgSlice destinationKey, ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out int count) + => storageSession.SortedSetIntersectStore(destinationKey, keys, weights, aggregateType, out count); + #endregion #region Geospatial commands diff --git a/libs/server/API/GarnetWatchApi.cs b/libs/server/API/GarnetWatchApi.cs index 87f506721f..1b359c3d46 100644 --- a/libs/server/API/GarnetWatchApi.cs +++ b/libs/server/API/GarnetWatchApi.cs @@ -198,6 +198,26 @@ public GarnetStatus SortedSetScan(ArgSlice key, long cursor, string match, int c return garnetApi.SortedSetScan(key, cursor, match, count, out items); } + /// + public GarnetStatus SortedSetIntersect(ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out Dictionary pairs) + { + foreach (var key in keys) + { + garnetApi.WATCH(key, StoreType.Object); + } + return garnetApi.SortedSetIntersect(keys, weights, aggregateType, out pairs); + } + + /// + public GarnetStatus SortedSetIntersectLength(ReadOnlySpan keys, int? limit, out int count) + { + foreach (var key in keys) + { + garnetApi.WATCH(key, StoreType.Object); + } + return garnetApi.SortedSetIntersectLength(keys, limit, out count); + } + #endregion #region List Methods diff --git a/libs/server/API/IGarnetApi.cs b/libs/server/API/IGarnetApi.cs index 15d9c2255a..0ae7353622 100644 --- a/libs/server/API/IGarnetApi.cs +++ b/libs/server/API/IGarnetApi.cs @@ -521,6 +521,17 @@ public interface IGarnetApi : IGarnetReadApi, IGarnetAdvancedApi /// GarnetStatus GeoSearchStore(ArgSlice key, ArgSlice destinationKey, ref ObjectInput input, ref SpanByteAndMemory output); + /// + /// Intersects multiple sorted sets and stores the result in the destination key. + /// + /// The key where the result will be stored. + /// The keys of the sorted sets to intersect. + /// The weights to apply to each sorted set during the intersection. + /// The type of aggregation to use for the intersection. + /// The number of elements in the resulting sorted set. + /// A indicating the status of the operation. + GarnetStatus SortedSetIntersectStore(ArgSlice destinationKey, ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out int count); + #endregion #region Set Methods @@ -1295,6 +1306,25 @@ public interface IGarnetReadApi /// GarnetStatus SortedSetScan(ArgSlice key, long cursor, string match, int count, out ArgSlice[] items); + /// + /// Intersects multiple sorted sets and returns the result. + /// + /// The keys of the sorted sets to intersect. + /// The weights to apply to each sorted set. + /// The type of aggregation to perform. + /// The resulting dictionary of intersected elements and their scores. + /// A indicating the status of the operation. + GarnetStatus SortedSetIntersect(ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out Dictionary pairs); + + /// + /// Computes the intersection of multiple sorted sets and counts the elements. + /// + /// Input sorted set keys + /// Optional max count limit + /// The count of elements in the intersection + /// Operation status + GarnetStatus SortedSetIntersectLength(ReadOnlySpan keys, int? limit, out int count); + #endregion #region Geospatial Methods diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index aad5e5e43f..82634932f6 100644 --- a/libs/server/Resp/CmdStrings.cs +++ b/libs/server/Resp/CmdStrings.cs @@ -119,9 +119,12 @@ static partial class CmdStrings public static ReadOnlySpan LEFT => "LEFT"u8; public static ReadOnlySpan BYLEX => "BYLEX"u8; public static ReadOnlySpan REV => "REV"u8; - public static ReadOnlySpan LIMIT => "LIMIT"u8; + public static ReadOnlySpan WEIGHTS => "WEIGHTS"u8; + public static ReadOnlySpan AGGREGATE => "AGGREGATE"u8; + public static ReadOnlySpan SUM => "SUM"u8; public static ReadOnlySpan MIN => "MIN"u8; public static ReadOnlySpan MAX => "MAX"u8; + public static ReadOnlySpan LIMIT => "LIMIT"u8; /// /// Response strings @@ -224,7 +227,9 @@ static partial class CmdStrings "ERR Invalid number of parameters to stored proc {0}, expected {1}, actual {2}"; public const string GenericSyntaxErrorOption = "ERR Syntax error in {0} option '{1}'"; public const string GenericParamShouldBeGreaterThanZero = "ERR {0} should be greater than 0"; + public const string GenericErrNotAFloat = "ERR {0} value is not a valid float"; public const string GenericErrCantBeNegative = "ERR {0} can't be negative"; + public const string GenericErrAtLeastOneKey = "ERR at least 1 input key is needed for '{0}' command"; public const string GenericErrShouldBeGreaterThanZero = "ERR {0} should be greater than 0"; public const string GenericUnknownClientType = "ERR Unknown client type '{0}'"; public const string GenericErrDuplicateFilter = "ERR Filter '{0}' defined multiple times"; diff --git a/libs/server/Resp/Objects/SortedSetCommands.cs b/libs/server/Resp/Objects/SortedSetCommands.cs index 6289f23d9e..6cdd92d5b0 100644 --- a/libs/server/Resp/Objects/SortedSetCommands.cs +++ b/libs/server/Resp/Objects/SortedSetCommands.cs @@ -933,7 +933,7 @@ private unsafe bool SortedSetDifference(ref TGarnetApi storageApi) { var withScores = parseState.GetArgSliceByRef(parseState.Count - 1).ReadOnlySpan; - if (!withScores.SequenceEqual(CmdStrings.WITHSCORES)) + if (!withScores.EqualsUpperCaseSpanIgnoringCase(CmdStrings.WITHSCORES)) { while (!RespWriteUtils.WriteError(CmdStrings.RESP_SYNTAX_ERROR, ref dcurr, dend)) SendAndReset(); @@ -1028,5 +1028,280 @@ private unsafe bool SortedSetDifferenceStore(ref TGarnetApi storageA return true; } + + /// + /// Computes an intersection operation between multiple sorted sets + /// and returns the result to the client. + /// + /// + /// + private unsafe bool SortedSetIntersect(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (parseState.Count < 2) + { + return AbortWithWrongNumberOfArguments(nameof(RespCommand.ZINTER)); + } + + // Number of keys + if (!parseState.TryGetInt(0, out var nKeys)) + { + return AbortWithErrorMessage(CmdStrings.RESP_ERR_GENERIC_VALUE_IS_NOT_INTEGER); + } + + if (nKeys < 1) + { + return AbortWithErrorMessage(Encoding.ASCII.GetBytes(string.Format(CmdStrings.GenericErrAtLeastOneKey, nameof(RespCommand.ZINTER)))); + } + + if (parseState.Count < nKeys + 1) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + + var includeWithScores = false; + double[] weights = null; + var aggregateType = SortedSetAggregateType.Sum; + var currentArg = nKeys + 1; + + // Read all the keys + var keys = parseState.Parameters.Slice(1, nKeys); + + // Parse optional arguments + while (currentArg < parseState.Count) + { + var arg = parseState.GetArgSliceByRef(currentArg).ReadOnlySpan; + + if (arg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.WITHSCORES)) + { + includeWithScores = true; + currentArg++; + } + else if (arg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.WEIGHTS)) + { + currentArg++; + if (currentArg + nKeys > parseState.Count) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + + weights = new double[nKeys]; + for (var i = 0; i < nKeys; i++) + { + if (!parseState.TryGetDouble(currentArg + i, out weights[i])) + { + return AbortWithErrorMessage(Encoding.ASCII.GetBytes(string.Format(CmdStrings.GenericErrNotAFloat, "weight"))); + } + } + currentArg += nKeys; + } + else if (arg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.AGGREGATE)) + { + if (++currentArg >= parseState.Count || !parseState.TryGetSortedSetAggregateType(currentArg++, out aggregateType)) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + } + else + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + } + + var status = storageApi.SortedSetIntersect(keys, weights, aggregateType, out var result); + + switch (status) + { + case GarnetStatus.WRONGTYPE: + while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_WRONG_TYPE, ref dcurr, dend)) + SendAndReset(); + break; + default: + if (result == null || result.Count == 0) + { + while (!RespWriteUtils.WriteEmptyArray(ref dcurr, dend)) + SendAndReset(); + break; + } + + // write the size of the array reply + while (!RespWriteUtils.WriteArrayLength(includeWithScores ? result.Count * 2 : result.Count, ref dcurr, dend)) + SendAndReset(); + + foreach (var (element, score) in result) + { + while (!RespWriteUtils.WriteBulkString(element, ref dcurr, dend)) + SendAndReset(); + + if (includeWithScores) + { + while (!RespWriteUtils.TryWriteDoubleBulkString(score, ref dcurr, dend)) + SendAndReset(); + } + } + break; + } + + return true; + } + + /// + /// Returns the cardinality of the intersection between multiple sorted sets. + /// + /// + /// + private unsafe bool SortedSetIntersectLength(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (parseState.Count < 2) + { + return AbortWithWrongNumberOfArguments(nameof(RespCommand.ZINTERCARD)); + } + + // Number of keys + if (!parseState.TryGetInt(0, out var nKeys)) + { + return AbortWithErrorMessage(CmdStrings.RESP_ERR_GENERIC_VALUE_IS_NOT_INTEGER); + } + + if (nKeys < 1) + { + return AbortWithErrorMessage(Encoding.ASCII.GetBytes(string.Format(CmdStrings.GenericErrAtLeastOneKey, nameof(RespCommand.ZINTERCARD)))); + } + + if (parseState.Count < nKeys + 1) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + + var keys = parseState.Parameters.Slice(1, nKeys); + + // Optional LIMIT argument + int? limit = null; + if (parseState.Count > nKeys + 1) + { + var limitArg = parseState.GetArgSliceByRef(nKeys + 1); + if (!limitArg.ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase(CmdStrings.LIMIT) || parseState.Count != nKeys + 3) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + + if (!parseState.TryGetInt(nKeys + 2, out var limitVal)) + { + return AbortWithErrorMessage(CmdStrings.RESP_ERR_GENERIC_VALUE_IS_NOT_INTEGER); + } + + if (limitVal < 0) + { + return AbortWithErrorMessage(Encoding.ASCII.GetBytes(string.Format(CmdStrings.GenericErrCantBeNegative, "LIMIT"))); + } + + limit = limitVal; + } + + var status = storageApi.SortedSetIntersectLength(keys, limit, out var count); + + switch (status) + { + case GarnetStatus.WRONGTYPE: + while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_WRONG_TYPE, ref dcurr, dend)) + SendAndReset(); + break; + default: + while (!RespWriteUtils.WriteInteger(count, ref dcurr, dend)) + SendAndReset(); + break; + } + + return true; + } + + /// + /// Computes an intersection operation between multiple sorted sets and store + /// the result in the destination key. + /// The total number of input keys is specified. + /// + /// + /// + private unsafe bool SortedSetIntersectStore(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (parseState.Count < 3) + { + return AbortWithWrongNumberOfArguments(nameof(RespCommand.ZINTERSTORE)); + } + + // Number of keys + if (!parseState.TryGetInt(1, out var nKeys)) + { + return AbortWithErrorMessage(CmdStrings.RESP_ERR_GENERIC_VALUE_IS_NOT_INTEGER); + } + + var destination = parseState.GetArgSliceByRef(0); + var keys = parseState.Parameters.Slice(2, nKeys); + + double[] weights = null; + var aggregateType = SortedSetAggregateType.Sum; + var currentArg = nKeys + 2; + + // Parse optional arguments + while (currentArg < parseState.Count) + { + var arg = parseState.GetArgSliceByRef(currentArg).ReadOnlySpan; + + if (arg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.WEIGHTS)) + { + currentArg++; + if (currentArg + nKeys > parseState.Count) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + + weights = new double[nKeys]; + for (var i = 0; i < nKeys; i++) + { + if (!parseState.TryGetDouble(currentArg + i, out weights[i])) + { + return AbortWithErrorMessage(Encoding.ASCII.GetBytes(string.Format(CmdStrings.GenericErrNotAFloat, "weight"))); + } + } + currentArg += nKeys; + } + else if (arg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.AGGREGATE)) + { + currentArg++; + if (currentArg >= parseState.Count) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + + if (!parseState.TryGetSortedSetAggregateType(currentArg, out aggregateType)) + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + currentArg++; + } + else + { + return AbortWithErrorMessage(CmdStrings.RESP_SYNTAX_ERROR); + } + } + + var status = storageApi.SortedSetIntersectStore(destination, keys, weights, aggregateType, out var count); + + switch (status) + { + case GarnetStatus.WRONGTYPE: + while (!RespWriteUtils.WriteError(CmdStrings.RESP_ERR_WRONG_TYPE, ref dcurr, dend)) + SendAndReset(); + break; + default: + while (!RespWriteUtils.WriteInteger(count, ref dcurr, dend)) + SendAndReset(); + break; + } + + return true; + } } } \ No newline at end of file diff --git a/libs/server/Resp/Parser/RespCommand.cs b/libs/server/Resp/Parser/RespCommand.cs index 5cf5077592..abd425c901 100644 --- a/libs/server/Resp/Parser/RespCommand.cs +++ b/libs/server/Resp/Parser/RespCommand.cs @@ -75,6 +75,8 @@ public enum RespCommand : ushort ZCARD, ZCOUNT, ZDIFF, + ZINTER, + ZINTERCARD, ZLEXCOUNT, ZMSCORE, ZRANDMEMBER, @@ -162,6 +164,7 @@ public enum RespCommand : ushort ZDIFFSTORE, ZINCRBY, ZMPOP, + ZINTERSTORE, ZPOPMAX, ZPOPMIN, ZRANGESTORE, @@ -1220,6 +1223,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.ZSCORE; } + if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("ZINTER\r\n"u8)) + { + return RespCommand.ZINTER; + } break; } @@ -1429,6 +1436,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.BRPOPLPUSH; } + else if (*(ulong*)(ptr + 1) == MemoryMarshal.Read("10\r\nZINT"u8) && *(ulong*)(ptr + 9) == MemoryMarshal.Read("ERCARD\r\n"u8)) + { + return RespCommand.ZINTERCARD; + } break; case 11: @@ -1468,6 +1479,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.ZRANGESTORE; } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("1\r\nZINTE"u8) && *(ulong*)(ptr + 10) == MemoryMarshal.Read("RSTORE\r\n"u8)) + { + return RespCommand.ZINTERSTORE; + } break; case 12: diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index a5b00f020c..50b781d122 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -623,6 +623,9 @@ private bool ProcessArrayCommands(RespCommand cmd, ref TGarnetApi st RespCommand.ZREVRANGEBYLEX => SortedSetRange(cmd, ref storageApi), RespCommand.ZREVRANGEBYSCORE => SortedSetRange(cmd, ref storageApi), RespCommand.ZSCAN => ObjectScan(GarnetObjectType.SortedSet, ref storageApi), + RespCommand.ZINTER => SortedSetIntersect(ref storageApi), + RespCommand.ZINTERCARD => SortedSetIntersectLength(ref storageApi), + RespCommand.ZINTERSTORE => SortedSetIntersectStore(ref storageApi), //SortedSet for Geo Commands RespCommand.GEOADD => GeoAdd(ref storageApi), RespCommand.GEOHASH => GeoCommands(cmd, ref storageApi), diff --git a/libs/server/SessionParseStateExtensions.cs b/libs/server/SessionParseStateExtensions.cs index 6a3ac203f1..8ac469ae72 100644 --- a/libs/server/SessionParseStateExtensions.cs +++ b/libs/server/SessionParseStateExtensions.cs @@ -213,5 +213,28 @@ internal static bool TryGetExpireOption(this SessionParseState parseState, int i return true; } + + /// + /// Parse sorted set aggregate type from parse state at specified index + /// + /// The parse state + /// The argument index + /// Parsed value + /// True if value parsed successfully + internal static bool TryGetSortedSetAggregateType(this SessionParseState parseState, int idx, out SortedSetAggregateType value) + { + value = default; + var sbArg = parseState.GetArgSliceByRef(idx).ReadOnlySpan; + + if (sbArg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.SUM)) + value = SortedSetAggregateType.Sum; + else if (sbArg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.MIN)) + value = SortedSetAggregateType.Min; + else if (sbArg.EqualsUpperCaseSpanIgnoringCase(CmdStrings.MAX)) + value = SortedSetAggregateType.Max; + else return false; + + return true; + } } } \ No newline at end of file diff --git a/libs/server/SortedSetAggregateType.cs b/libs/server/SortedSetAggregateType.cs new file mode 100644 index 0000000000..b0410f0c27 --- /dev/null +++ b/libs/server/SortedSetAggregateType.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Garnet.server +{ + /// + /// Specifies the type of aggregation to be used in sorted set operations. + /// + public enum SortedSetAggregateType : byte + { + /// + /// Sum the values. + /// + Sum, + /// + /// Use the minimum value. + /// + Min, + /// + /// Use the maximum value. + /// + Max + } +} \ No newline at end of file diff --git a/libs/server/Storage/Session/ObjectStore/SortedSetOps.cs b/libs/server/Storage/Session/ObjectStore/SortedSetOps.cs index 72b7896be9..7f542e3bf0 100644 --- a/libs/server/Storage/Session/ObjectStore/SortedSetOps.cs +++ b/libs/server/Storage/Session/ObjectStore/SortedSetOps.cs @@ -1088,5 +1088,202 @@ public unsafe GarnetStatus SortedSetMPop(ReadOnlySpan keys, int count, txnManager.Commit(true); } } + + /// + /// Computes the cardinality of the intersection of multiple sorted sets. + /// + public GarnetStatus SortedSetIntersectLength(ReadOnlySpan keys, int? limit, out int count) + { + count = 0; + + var status = SortedSetIntersect(keys, null, SortedSetAggregateType.Sum, out var pairs); + if (status == GarnetStatus.OK && pairs != null) + { + count = limit.HasValue ? Math.Min(pairs.Count, limit.Value) : pairs.Count; + } + + return status; + } + + /// + /// Computes the intersection of multiple sorted sets and stores the resulting sorted set at destinationKey. + /// + public GarnetStatus SortedSetIntersectStore(ArgSlice destinationKey, ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out int count) + { + count = default; + + if (keys.Length == 0) + return GarnetStatus.OK; + + var createTransaction = false; + + if (txnManager.state != TxnState.Running) + { + Debug.Assert(txnManager.state == TxnState.None); + createTransaction = true; + txnManager.SaveKeyEntryToLock(destinationKey, true, LockType.Exclusive); + foreach (var item in keys) + txnManager.SaveKeyEntryToLock(item, true, LockType.Shared); + _ = txnManager.Run(true); + } + + var objectContext = txnManager.ObjectStoreLockableContext; + + try + { + var status = SortedSetIntersection(keys, weights, aggregateType, ref objectContext, out var pairs); + + if (status != GarnetStatus.OK) + { + return GarnetStatus.WRONGTYPE; + } + + count = pairs?.Count ?? 0; + + if (count > 0) + { + SortedSetObject newSortedSetObject = new(); + foreach (var (element, score) in pairs) + { + newSortedSetObject.Add(element, score); + } + _ = SET(destinationKey.ToArray(), newSortedSetObject, ref objectContext); + } + else + { + _ = EXPIRE(destinationKey, TimeSpan.Zero, out _, StoreType.Object, ExpireOption.None, + ref lockableContext, ref objectContext); + } + + return status; + } + finally + { + if (createTransaction) + txnManager.Commit(true); + } + } + + /// + /// Computes the intersection of multiple sorted sets and returns the result with optional weights and aggregate type. + /// + public GarnetStatus SortedSetIntersect(ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, out Dictionary pairs) + { + pairs = default; + + if (keys.Length == 0) + return GarnetStatus.OK; + + var createTransaction = false; + + if (txnManager.state != TxnState.Running) + { + Debug.Assert(txnManager.state == TxnState.None); + createTransaction = true; + foreach (var item in keys) + txnManager.SaveKeyEntryToLock(item, true, LockType.Shared); + txnManager.Run(true); + } + + var objectContext = txnManager.ObjectStoreLockableContext; + + try + { + return SortedSetIntersection(keys, weights, aggregateType, ref objectContext, out pairs); + } + finally + { + if (createTransaction) + txnManager.Commit(true); + } + } + + /// + /// Computes the intersection of multiple sorted sets and returns the result with optional weights and aggregate type. + /// + /// The type of the object context. + /// The keys of the sorted sets to intersect. + /// The weights to apply to each sorted set's scores. If null, no weights are applied. + /// The type of aggregation to use (Sum, Min, Max). + /// The object context. + /// The resulting dictionary of intersected elements and their scores. + /// + private GarnetStatus SortedSetIntersection(ReadOnlySpan keys, double[] weights, SortedSetAggregateType aggregateType, ref TObjectContext objectContext, out Dictionary pairs) + where TObjectContext : ITsavoriteContext + { + pairs = default; + + var statusOp = GET(keys[0].ToArray(), out var firstObj, ref objectContext); + if (statusOp == GarnetStatus.OK) + { + if (firstObj.garnetObject is not SortedSetObject firstSortedSet) + { + return GarnetStatus.WRONGTYPE; + } + + if (keys.Length == 1) + { + pairs = firstSortedSet.Dictionary; + return GarnetStatus.OK; + } + + // Initialize result with first set + if (weights is null) + { + pairs = new Dictionary(firstSortedSet.Dictionary, ByteArrayComparer.Instance); + } + else + { + pairs = new Dictionary(ByteArrayComparer.Instance); + foreach (var kvp in firstSortedSet.Dictionary) + { + pairs[kvp.Key] = kvp.Value * weights[0]; + } + } + + // Intersect with remaining sets + for (var i = 1; i < keys.Length; i++) + { + statusOp = GET(keys[i].ToArray(), out var nextObj, ref objectContext); + if (statusOp != GarnetStatus.OK) + { + pairs = default; + return statusOp; + } + + if (nextObj.garnetObject is not SortedSetObject nextSortedSet) + { + pairs = default; + return GarnetStatus.WRONGTYPE; + } + + foreach (var kvp in pairs) + { + if (!nextSortedSet.Dictionary.TryGetValue(kvp.Key, out var score)) + { + pairs.Remove(kvp.Key); + continue; + } + + var weightedScore = weights is null ? score : score * weights[i]; + pairs[kvp.Key] = aggregateType switch + { + SortedSetAggregateType.Sum => kvp.Value + weightedScore, + SortedSetAggregateType.Min => Math.Min(kvp.Value, weightedScore), + SortedSetAggregateType.Max => Math.Max(kvp.Value, weightedScore), + _ => kvp.Value + weightedScore // Default to SUM + }; + } + + // If intersection becomes empty, we can stop early + if (pairs.Count == 0) + { + break; + } + } + } + + return GarnetStatus.OK; + } } } \ No newline at end of file diff --git a/playground/CommandInfoUpdater/SupportedCommand.cs b/playground/CommandInfoUpdater/SupportedCommand.cs index c40c688f96..b7dfee8e76 100644 --- a/playground/CommandInfoUpdater/SupportedCommand.cs +++ b/playground/CommandInfoUpdater/SupportedCommand.cs @@ -269,6 +269,9 @@ public class SupportedCommand new("ZDIFF", RespCommand.ZDIFF), new("ZDIFFSTORE", RespCommand.ZDIFFSTORE), new("ZINCRBY", RespCommand.ZINCRBY), + new("ZINTER", RespCommand.ZINTER), + new("ZINTERCARD", RespCommand.ZINTERCARD), + new("ZINTERSTORE", RespCommand.ZINTERSTORE), new("ZLEXCOUNT", RespCommand.ZLEXCOUNT), new("ZMSCORE", RespCommand.ZMSCORE), new("ZMPOP", RespCommand.ZMPOP), diff --git a/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs b/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs index 26e41a89ca..e323355f15 100644 --- a/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs +++ b/test/Garnet.test.cluster/RedirectTests/BaseCommand.cs @@ -2049,6 +2049,93 @@ public override ArraySegment[] SetupSingleSlotRequest() } } + internal class ZINTER : BaseCommand + { + public override bool IsArrayCommand => true; + public override bool ArrayResponse => true; + public override string Command => nameof(ZINTER); + + public override string[] GetSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + return ["2", ssk[0], ssk[1]]; + } + + public override string[] GetCrossSlotRequest() + { + var csk = GetCrossSlotKeys; + return ["2", csk[0], csk[1]]; + } + + public override ArraySegment[] SetupSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + var setup = new ArraySegment[3]; + setup[0] = new ArraySegment(["ZADD", ssk[1], "1", "a"]); + setup[1] = new ArraySegment(["ZADD", ssk[2], "2", "b"]); + setup[2] = new ArraySegment(["ZADD", ssk[3], "3", "c"]); + return setup; + } + } + + internal class ZINTERCARD : BaseCommand + { + public override bool IsArrayCommand => true; + public override bool ArrayResponse => false; + public override string Command => nameof(ZINTERCARD); + + public override string[] GetSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + return ["2", ssk[0], ssk[1]]; + } + + public override string[] GetCrossSlotRequest() + { + var csk = GetCrossSlotKeys; + return ["2", csk[0], csk[1]]; + } + + public override ArraySegment[] SetupSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + var setup = new ArraySegment[3]; + setup[0] = new ArraySegment(["ZADD", ssk[1], "1", "a"]); + setup[1] = new ArraySegment(["ZADD", ssk[2], "2", "b"]); + setup[2] = new ArraySegment(["ZADD", ssk[3], "3", "c"]); + return setup; + } + } + + internal class ZINTERSTORE : BaseCommand + { + public override bool IsArrayCommand => true; + public override bool ArrayResponse => false; + public override string Command => nameof(ZINTERSTORE); + + public override string[] GetSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + return [ssk[0], "2", ssk[1], ssk[2]]; + } + + public override string[] GetCrossSlotRequest() + { + var csk = GetCrossSlotKeys; + return [csk[0], "2", csk[1], csk[2]]; + } + + public override ArraySegment[] SetupSingleSlotRequest() + { + var ssk = GetSingleSlotKeys; + var setup = new ArraySegment[3]; + setup[0] = new ArraySegment(["ZADD", ssk[1], "1", "a"]); + setup[1] = new ArraySegment(["ZADD", ssk[2], "2", "b"]); + setup[2] = new ArraySegment(["ZADD", ssk[3], "3", "c"]); + return setup; + } + } + internal class ZRANGESTORE : BaseCommand { public override bool IsArrayCommand => true; diff --git a/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs b/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs index f603176c78..2a82ff4757 100644 --- a/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs +++ b/test/Garnet.test.cluster/RedirectTests/ClusterSlotVerificationTests.cs @@ -111,6 +111,9 @@ public class ClusterSlotVerificationTests new ZRANDMEMBER(), new ZDIFF(), new ZDIFFSTORE(), + new ZINTER(), + new ZINTERCARD(), + new ZINTERSTORE(), new HSET(), new HGET(), new HGETALL(), @@ -295,6 +298,9 @@ public virtual void OneTimeTearDown() [TestCase("ZRANDMEMBER")] [TestCase("ZDIFF")] [TestCase("ZDIFFSTORE")] + [TestCase("ZINTER")] + [TestCase("ZINTERCARD")] + [TestCase("ZINTERSTORE")] [TestCase("HSET")] [TestCase("HGET")] [TestCase("HGETALL")] @@ -443,6 +449,9 @@ void GarnetClientSessionClusterDown(BaseCommand command) [TestCase("ZRANDMEMBER")] [TestCase("ZDIFF")] [TestCase("ZDIFFSTORE")] + [TestCase("ZINTER")] + [TestCase("ZINTERCARD")] + [TestCase("ZINTERSTORE")] [TestCase("HSET")] [TestCase("HGET")] [TestCase("HGETALL")] @@ -600,6 +609,9 @@ void GarnetClientSessionOK(BaseCommand command) [TestCase("ZRANDMEMBER")] [TestCase("ZDIFF")] [TestCase("ZDIFFSTORE")] + [TestCase("ZINTER")] + [TestCase("ZINTERCARD")] + [TestCase("ZINTERSTORE")] [TestCase("HSET")] [TestCase("HGET")] [TestCase("HGETALL")] @@ -748,6 +760,9 @@ void GarnetClientSessionCrossslotTest(BaseCommand command) [TestCase("ZRANDMEMBER")] [TestCase("ZDIFF")] [TestCase("ZDIFFSTORE")] + [TestCase("ZINTER")] + [TestCase("ZINTERCARD")] + [TestCase("ZINTERSTORE")] [TestCase("HSET")] [TestCase("HGET")] [TestCase("HGETALL")] @@ -904,6 +919,9 @@ void GarnetClientSessionMOVEDTest(BaseCommand command) [TestCase("ZRANDMEMBER")] [TestCase("ZDIFF")] [TestCase("ZDIFFSTORE")] + [TestCase("ZINTER")] + [TestCase("ZINTERCARD")] + [TestCase("ZINTERSTORE")] [TestCase("HSET")] [TestCase("HGET")] [TestCase("HGETALL")] @@ -1078,6 +1096,9 @@ void GarnetClientSessionASKTest(BaseCommand command) [TestCase("ZRANDMEMBER")] [TestCase("ZDIFF")] [TestCase("ZDIFFSTORE")] + [TestCase("ZINTER")] + [TestCase("ZINTERCARD")] + [TestCase("ZINTERSTORE")] [TestCase("HSET")] [TestCase("HGET")] [TestCase("HGETALL")] diff --git a/test/Garnet.test/Resp/ACL/RespCommandTests.cs b/test/Garnet.test/Resp/ACL/RespCommandTests.cs index cda4ccae7e..ec98b645ce 100644 --- a/test/Garnet.test/Resp/ACL/RespCommandTests.cs +++ b/test/Garnet.test/Resp/ACL/RespCommandTests.cs @@ -6148,6 +6148,51 @@ static async Task DoZDiffStoreAsync(GarnetClient client) } } + [Test] + public async Task ZInterACLsAsync() + { + await CheckCommandsAsync( + "ZINTER", + [DoZInterAsync] + ); + + static async Task DoZInterAsync(GarnetClient client) + { + var val = await client.ExecuteForStringArrayResultAsync("ZINTER", ["2", "foo", "bar"]); + ClassicAssert.AreEqual(0, val.Length); + } + } + + [Test] + public async Task ZInterCardACLsAsync() + { + await CheckCommandsAsync( + "ZINTERCARD", + [DoZInterCardAsync] + ); + + static async Task DoZInterCardAsync(GarnetClient client) + { + var val = await client.ExecuteForLongResultAsync("ZINTERCARD", ["2", "foo", "bar"]); + ClassicAssert.AreEqual(0, val); + } + } + + [Test] + public async Task ZInterStoreACLsAsync() + { + await CheckCommandsAsync( + "ZINTERSTORE", + [DoZInterStoreAsync] + ); + + static async Task DoZInterStoreAsync(GarnetClient client) + { + var val = await client.ExecuteForLongResultAsync("ZINTERSTORE", ["keyZ", "2", "foo", "bar"]); + ClassicAssert.AreEqual(0, val); + } + } + [Test] public async Task ZScanACLsAsync() { diff --git a/test/Garnet.test/RespSortedSetTests.cs b/test/Garnet.test/RespSortedSetTests.cs index 13da291281..1742d6d107 100644 --- a/test/Garnet.test/RespSortedSetTests.cs +++ b/test/Garnet.test/RespSortedSetTests.cs @@ -1347,6 +1347,218 @@ public void SortedSetMultiPopWithFirstKeyEmptyOnSecondPopTest() ClassicAssert.AreEqual("board2", (string)popResult2[0]); } + [Test] + public void CanDoZInterWithSE() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Setup test data + db.SortedSetAdd("zset1", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("three", 3) + ]); + + db.SortedSetAdd("zset2", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("four", 4) + ]); + + db.SortedSetAdd("zset3", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("three", 3), + new SortedSetEntry("five", 5) + ]); + + // Test basic intersection + var result = db.SortedSetCombine(SetOperation.Intersect, [new RedisKey("zset1"), new RedisKey("zset2")]); + ClassicAssert.AreEqual(2, result.Length); + ClassicAssert.AreEqual("one", result[0].ToString()); + ClassicAssert.AreEqual("two", result[1].ToString()); + + // Test three-way intersection + result = db.SortedSetCombine(SetOperation.Intersect, [new RedisKey("zset1"), new RedisKey("zset2"), new RedisKey("zset3")]); + ClassicAssert.AreEqual(1, result.Length); + ClassicAssert.AreEqual("one", result[0].ToString()); + + // Test with scores + var resultWithScores = db.SortedSetCombineWithScores(SetOperation.Intersect, [new RedisKey("zset1"), new RedisKey("zset2")]); + ClassicAssert.AreEqual(2, resultWithScores.Length); + ClassicAssert.AreEqual("one", resultWithScores[0].Element.ToString()); + ClassicAssert.AreEqual(2, resultWithScores[0].Score); + ClassicAssert.AreEqual("two", resultWithScores[1].Element.ToString()); + ClassicAssert.AreEqual(4, resultWithScores[1].Score); + } + + [Test] + [TestCase(2, "ZINTER 2 zset1 zset2", new[] { "one", "two" }, new[] { 2.0, 4.0 }, Description = "Basic intersection")] + [TestCase(3, "ZINTER 3 zset1 zset2 zset3", new[] { "one" }, new[] { 3.0 }, Description = "Three-way intersection")] + [TestCase(2, "ZINTER 2 zset1 zset2 WITHSCORES", new[] { "one", "two" }, new[] { 2.0, 4.0 }, Description = "With scores")] + [TestCase(2, "ZINTER 2 zset1 zset2 WEIGHTS 2 3 WITHSCORES", new[] { "one", "two" }, new[] { 5.0, 10.0 }, Description = "With weights 2,3 multiplied by scores")] + [TestCase(2, "ZINTER 2 zset1 zset2 AGGREGATE MAX WITHSCORES", new[] { "one", "two" }, new[] { 1.0, 2.0 }, Description = "Using maximum of scores")] + [TestCase(2, "ZINTER 2 zset1 zset2 AGGREGATE MIN WITHSCORES", new[] { "one", "two" }, new[] { 1.0, 2.0 }, Description = "Using minimum of scores")] + [TestCase(2, "ZINTER 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM WITHSCORES", new[] { "one", "two" }, new[] { 5.0, 10.0 }, Description = "Weights with sum aggregation")] + public void CanDoZInterWithSE(int numKeys, string command, string[] expectedValues, double[] expectedScores) + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Setup test data + db.SortedSetAdd("zset1", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("three", 3) + ]); + + db.SortedSetAdd("zset2", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("four", 4) + ]); + + db.SortedSetAdd("zset3", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("three", 3), + new SortedSetEntry("five", 5) + ]); + + // Test intersection operation + if (command.Contains("WITHSCORES")) + { + var resultWithScores = db.SortedSetCombineWithScores(SetOperation.Intersect, + command.Contains("WEIGHTS") ? [new RedisKey("zset1"), new RedisKey("zset2")] : + Enumerable.Range(1, numKeys).Select(i => new RedisKey($"zset{i}")).ToArray(), + command.Contains("WEIGHTS") ? [2.0, 3.0] : null, + command.Contains("MAX") ? Aggregate.Max : + command.Contains("MIN") ? Aggregate.Min : Aggregate.Sum); + + ClassicAssert.AreEqual(expectedValues.Length, resultWithScores.Length); + for (int i = 0; i < expectedValues.Length; i++) + { + ClassicAssert.AreEqual(expectedValues[i], resultWithScores[i].Element.ToString()); + ClassicAssert.AreEqual(expectedScores[i], resultWithScores[i].Score); + } + } + else + { + var result = db.SortedSetCombine(SetOperation.Intersect, + Enumerable.Range(1, numKeys).Select(i => new RedisKey($"zset{i}")).ToArray()); + + ClassicAssert.AreEqual(expectedValues.Length, result.Length); + for (int i = 0; i < expectedValues.Length; i++) + { + ClassicAssert.AreEqual(expectedValues[i], result[i].ToString()); + } + } + } + + [Test] + public void CanDoZInterCardWithSE() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Setup test data + db.SortedSetAdd("zset1", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("three", 3) + ]); + + db.SortedSetAdd("zset2", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("four", 4) + ]); + + db.SortedSetAdd("zset3", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("three", 3), + new SortedSetEntry("five", 5) + ]); + + // Test basic intersection cardinality + var result = (long)db.Execute("ZINTERCARD", "2", "zset1", "zset2"); + ClassicAssert.AreEqual(2, result); + + // Test three-way intersection cardinality + result = (long)db.Execute("ZINTERCARD", "3", "zset1", "zset2", "zset3"); + ClassicAssert.AreEqual(1, result); + + // Test with limit + result = (long)db.Execute("ZINTERCARD", "2", "zset1", "zset2", "LIMIT", "1"); + ClassicAssert.AreEqual(1, result); + } + + [Test] + public void CanDoZInterStoreWithSE() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Setup test data + db.SortedSetAdd("zset1", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("three", 3) + ]); + + db.SortedSetAdd("zset2", + [ + new SortedSetEntry("one", 1), + new SortedSetEntry("two", 2), + new SortedSetEntry("four", 4) + ]); + + // Test basic intersection store + var result = db.SortedSetCombineAndStore(SetOperation.Intersect, "dest", [new RedisKey("zset1"), new RedisKey("zset2")]); + ClassicAssert.AreEqual(2, result); + + var storedValues = db.SortedSetRangeByScoreWithScores("dest"); + ClassicAssert.AreEqual(2, storedValues.Length); + ClassicAssert.AreEqual("one", storedValues[0].Element.ToString()); + ClassicAssert.AreEqual(2, storedValues[0].Score); // Sum of scores + ClassicAssert.AreEqual("two", storedValues[1].Element.ToString()); + ClassicAssert.AreEqual(4, storedValues[1].Score); // Sum of scores + + // Test with weights + var weights = new double[] { 2, 3 }; + result = db.SortedSetCombineAndStore(SetOperation.Intersect, "dest", [new RedisKey("zset1"), new RedisKey("zset2")], weights); + ClassicAssert.AreEqual(2, result); + + storedValues = db.SortedSetRangeByScoreWithScores("dest"); + ClassicAssert.AreEqual(2, storedValues.Length); + ClassicAssert.AreEqual("one", storedValues[0].Element.ToString()); + ClassicAssert.AreEqual(5, storedValues[0].Score); // Weighted sum + ClassicAssert.AreEqual("two", storedValues[1].Element.ToString()); + ClassicAssert.AreEqual(10, storedValues[1].Score); // Weighted sum + + // Test with MAX aggregate + var result2 = (long)db.Execute("ZINTERSTORE", "dest", "2", "zset1", "zset2", "AGGREGATE", "MAX"); + ClassicAssert.AreEqual(2, result2); + + storedValues = db.SortedSetRangeByScoreWithScores("dest"); + ClassicAssert.AreEqual(2, storedValues.Length); + ClassicAssert.AreEqual(1, storedValues[0].Score); // MAX of scores + ClassicAssert.AreEqual(2, storedValues[1].Score); // MAX of scores + + // Test error cases + var ex = Assert.Throws(() => db.Execute("ZINTERSTORE", "dest")); + ClassicAssert.AreEqual(string.Format(CmdStrings.GenericErrWrongNumArgs, "ZINTERSTORE"), ex.Message); + } + #endregion #region LightClientTests @@ -2436,6 +2648,7 @@ public void CanUseZDiff(int bytesSent) zdiffResult = lightClientRequest.SendCommandChunks("ZDIFF 2 dadi seconddadi", bytesSent, 3); expectedResponse = "*2\r\n$6\r\ncinque\r\n$3\r\nsei\r\n"; actualValue = Encoding.ASCII.GetString(zdiffResult).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); } [Test] @@ -2893,6 +3106,105 @@ private static void UpdateSortedSetKey(string keyName) } private static string FormatWrongNumOfArgsError(string commandName) => $"-{string.Format(CmdStrings.GenericErrWrongNumArgs, commandName)}\r\n"; + + [Test] + [TestCase(2, "ZINTER 2 zset1 zset2", Description = "Basic intersection")] + [TestCase(3, "ZINTER 3 zset1 zset2 zset3", Description = "Three-way intersection")] + [TestCase(2, "ZINTER 2 zset1 zset2 WITHSCORES", Description = "With scores")] + public void CanDoZInter(int numKeys, string command) + { + using var lightClientRequest = TestUtils.CreateRequest(); + + // Setup test data + lightClientRequest.SendCommand("ZADD zset1 1 one 2 two 3 three"); + lightClientRequest.SendCommand("ZADD zset2 1 one 2 two 4 four"); + lightClientRequest.SendCommand("ZADD zset3 1 one 3 three 5 five"); + + var response = lightClientRequest.SendCommand(command); + if (command.Contains("WITHSCORES")) + { + if (numKeys == 2) + { + var expectedResponse = "*4\r\n$3\r\none\r\n$1\r\n2\r\n$3\r\ntwo\r\n$1\r\n4\r\n"; + var actualValue = Encoding.ASCII.GetString(response).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); + } + } + else + { + if (numKeys == 2) + { + var expectedResponse = "*2\r\n$3\r\none\r\n$3\r\ntwo\r\n"; + var actualValue = Encoding.ASCII.GetString(response).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); + } + else if (numKeys == 3) + { + var expectedResponse = "*1\r\n$3\r\none\r\n"; + var actualValue = Encoding.ASCII.GetString(response).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); + } + } + } + + [Test] + [TestCase("ZINTERCARD 2 zset1 zset2", 2, Description = "Basic intersection cardinality")] + [TestCase("ZINTERCARD 3 zset1 zset2 zset3", 1, Description = "Three-way intersection cardinality")] + [TestCase("ZINTERCARD 2 zset1 zset2 LIMIT 1", 1, Description = "With limit")] + public void CanDoZInterCard(string command, int expectedCount) + { + using var lightClientRequest = TestUtils.CreateRequest(); + + // Setup test data + lightClientRequest.SendCommand("ZADD zset1 1 one 2 two 3 three"); + lightClientRequest.SendCommand("ZADD zset2 1 one 2 two 4 four"); + lightClientRequest.SendCommand("ZADD zset3 1 one 3 three 5 five"); + + var response = lightClientRequest.SendCommand(command); + var expectedResponse = $":{expectedCount}\r\n"; + var actualValue = Encoding.ASCII.GetString(response).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); + } + + [Test] + [TestCase("ZINTERSTORE dest 2 zset1 zset2", 2, Description = "Basic intersection store")] + [TestCase("ZINTERSTORE dest 2 zset1 zset2 WEIGHTS 2 3", 2, Description = "With weights")] + [TestCase("ZINTERSTORE dest 2 zset1 zset2 AGGREGATE MAX", 2, Description = "With MAX aggregation")] + [TestCase("ZINTERSTORE dest 2 zset1 zset2 AGGREGATE MIN", 2, Description = "With MIN aggregation")] + public void CanDoZInterStore(string command, int expectedCount) + { + using var lightClientRequest = TestUtils.CreateRequest(); + + // Setup test data + lightClientRequest.SendCommand("ZADD zset1 1 one 2 two 3 three"); + lightClientRequest.SendCommand("ZADD zset2 1 one 2 two 4 four"); + + var response = lightClientRequest.SendCommand(command); + var expectedResponse = $":{expectedCount}\r\n"; + var actualValue = Encoding.ASCII.GetString(response).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); + + // Verify stored results + response = lightClientRequest.SendCommand("ZRANGE dest 0 -1 WITHSCORES"); + if (command.Contains("WEIGHTS")) + { + expectedResponse = "*4\r\n$3\r\none\r\n$1\r\n5\r\n$3\r\ntwo\r\n$2\r\n10\r\n"; + } + else if (command.Contains("MAX")) + { + expectedResponse = "*4\r\n$3\r\none\r\n$1\r\n1\r\n$3\r\ntwo\r\n$1\r\n2\r\n"; + } + else if (command.Contains("MIN")) + { + expectedResponse = "*4\r\n$3\r\none\r\n$1\r\n1\r\n$3\r\ntwo\r\n$1\r\n2\r\n"; + } + else + { + expectedResponse = "*4\r\n$3\r\none\r\n$1\r\n2\r\n$3\r\ntwo\r\n$1\r\n4\r\n"; + } + actualValue = Encoding.ASCII.GetString(response).Substring(0, expectedResponse.Length); + ClassicAssert.AreEqual(expectedResponse, actualValue); + } } public class SortedSetComparer : IComparer<(double, byte[])> diff --git a/website/docs/commands/api-compatibility.md b/website/docs/commands/api-compatibility.md index 3e9ba3f4bc..033eee5270 100644 --- a/website/docs/commands/api-compatibility.md +++ b/website/docs/commands/api-compatibility.md @@ -327,9 +327,9 @@ Note that this list is subject to change as we continue to expand our API comman | | [ZDIFF](data-structures.md#zdiff) | ➕ | | | | [ZDIFFSTORE](data-structures.md#zdiffstore) | ➕ | | | | [ZINCRBY](data-structures.md#zincrby) | ➕ | | -| | ZINTER | ➖ | | -| | ZINTERCARD | ➖ | | -| | ZINTERSTORE | ➖ | | +| | [ZINTER](data-structures.md#zinter) | ➕ | | +| | [ZINTERCARD](data-structures.md#zintercard) | ➕ | | +| | [ZINTERSTORE](data-structures.md#zinterstore) | ➕ | | | | [ZLEXCOUNT](data-structures.md#zlexcount) | ➕ | | | | [ZMPOP](data-structures.md#zmpop) | ➕ | | | | [ZMSCORE](data-structures.md#zmscore) | ➕ | | diff --git a/website/docs/commands/data-structures.md b/website/docs/commands/data-structures.md index fccfe83e2e..426cabb293 100644 --- a/website/docs/commands/data-structures.md +++ b/website/docs/commands/data-structures.md @@ -835,6 +835,43 @@ An error is returned when **key** exists but does not hold a sorted set. The score value should be the string representation of a numeric value, and accepts double precision floating point numbers. It is possible to provide a negative value to decrement the score. +--- + +### ZINTER + +#### Syntax + +```bash + ZINTER numkeys key [key ...] [WEIGHTS weight [weight ...]] [AGGREGATE ] [WITHSCORES] +``` + +Computes the intersection of the sorted sets given by the specified keys and returns the result. It is possible to specify multiple keys. + +The result is a new sorted set with the same elements as the input sets, but with scores equal to the sum of the scores of the elements in the input sets. + +--- + +### ZINTERCARD + +#### Syntax + +```bash + ZINTERCARD numkeys key [key ...] [LIMIT limit] +``` + +Returns the number of elements in the intersection of the sorted sets given by the specified keys. + +--- + +### ZINTERSTORE + +#### Syntax + +```bash + ZINTERSTORE destination numkeys key [key ...] [WEIGHTS weight [weight ...]] [AGGREGATE ] +``` + +Computes the intersection of the sorted sets given by the specified keys and stores the result in the destination key. --- @@ -1297,4 +1334,3 @@ This command is like [GEOSEARCH](#geosearch), but stores the result in destinati Integer reply: the number of elements in the resulting set --- -