Skip to content

Commit

Permalink
Implement ZAdd and Zscore (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
the123saurav authored and tuhuynh27 committed Dec 19, 2021
1 parent 93a0093 commit 6586a84
Show file tree
Hide file tree
Showing 14 changed files with 742 additions and 9 deletions.
1 change: 1 addition & 0 deletions core/src/main/java/dev/keva/core/aof/AOFContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public List<Command> read() throws IOException {
byte[][] objects = (byte[][]) input.readObject();
commands.add(Command.newInstance(objects, false));
} catch (EOFException e) {
log.error("Error while reading AOF command", e);
fis.close();
return commands;
} catch (ClassNotFoundException e) {
Expand Down
123 changes: 123 additions & 0 deletions core/src/main/java/dev/keva/core/command/impl/zset/ZAdd.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package dev.keva.core.command.impl.zset;

import dev.keva.core.command.annotation.CommandImpl;
import dev.keva.core.command.annotation.Execute;
import dev.keva.core.command.annotation.Mutate;
import dev.keva.core.command.annotation.ParamLength;
import dev.keva.ioc.annotation.Autowired;
import dev.keva.ioc.annotation.Component;
import dev.keva.protocol.resp.reply.BulkReply;
import dev.keva.protocol.resp.reply.ErrorReply;
import dev.keva.protocol.resp.reply.IntegerReply;
import dev.keva.protocol.resp.reply.Reply;
import dev.keva.store.KevaDatabase;
import dev.keva.util.DoubleUtil;
import dev.keva.util.hashbytes.BytesKey;

import java.nio.charset.StandardCharsets;
import java.util.AbstractMap.SimpleEntry;

import static dev.keva.util.Constants.FLAG_CH;
import static dev.keva.util.Constants.FLAG_GT;
import static dev.keva.util.Constants.FLAG_INCR;
import static dev.keva.util.Constants.FLAG_LT;
import static dev.keva.util.Constants.FLAG_NX;
import static dev.keva.util.Constants.FLAG_XX;

@Component
@CommandImpl("zadd")
@ParamLength(type = ParamLength.Type.AT_LEAST, value = 3)
@Mutate
public final class ZAdd {
private static final String XX = "xx";
private static final String NX = "nx";
private static final String GT = "gt";
private static final String LT = "lt";
private static final String INCR = "incr";
private static final String CH = "ch";

private final KevaDatabase database;

@Autowired
public ZAdd(KevaDatabase database) {
this.database = database;
}

@Execute
public Reply<?> execute(byte[][] params) {
// Parse the flags, if any
boolean xx = false, nx = false, gt = false, lt = false, incr = false;
int argPos = 1, flags = 0;
String arg;
while (argPos < params.length) {
arg = new String(params[argPos], StandardCharsets.UTF_8);
if (XX.equalsIgnoreCase(arg)) {
xx = true;
flags |= FLAG_XX;
} else if (NX.equalsIgnoreCase(arg)) {
nx = true;
flags |= FLAG_NX;
} else if (GT.equalsIgnoreCase(arg)) {
gt = true;
flags |= FLAG_GT;
} else if (LT.equalsIgnoreCase(arg)) {
lt = true;
flags |= FLAG_LT;
} else if (INCR.equalsIgnoreCase(arg)) {
incr = true;
flags |= FLAG_INCR;
} else if (CH.equalsIgnoreCase(arg)) {
flags |= FLAG_CH;
} else {
break;
}
++argPos;
}

int numMembers = params.length - argPos;
if (numMembers % 2 != 0) {
return ErrorReply.SYNTAX_ERROR;
}
numMembers /= 2;

if (nx && xx) {
return ErrorReply.ZADD_NX_XX_ERROR;
}
if ((gt && nx) || (lt && nx) || (gt && lt)) {
return ErrorReply.ZADD_GT_LT_NX_ERROR;
}
if (incr && numMembers > 1) {
return ErrorReply.ZADD_INCR_ERROR;
}

// Parse the key and value
final SimpleEntry<Double, BytesKey>[] members = new SimpleEntry[numMembers];
double score;
String rawScore;
for (int memberIndex = 0; memberIndex < numMembers; ++memberIndex) {
try {
rawScore = new String(params[argPos++], StandardCharsets.UTF_8);
if (rawScore.equalsIgnoreCase("inf") || rawScore.equalsIgnoreCase("infinity")
|| rawScore.equalsIgnoreCase("+inf") || rawScore.equalsIgnoreCase("+infinity")
) {
score = Double.POSITIVE_INFINITY;
} else if (rawScore.equalsIgnoreCase("-inf") || rawScore.equalsIgnoreCase("-infinity")) {
score = Double.NEGATIVE_INFINITY;
} else {
score = Double.parseDouble(rawScore);
}
} catch (final NumberFormatException ignored) {
// return on first bad input
return ErrorReply.ZADD_SCORE_FLOAT_ERROR;
}
members[memberIndex] = new SimpleEntry<>(score, new BytesKey(params[argPos++]));
}

if (incr) {
Double result = database.zincrby(params[0], members[0].getKey(), members[0].getValue(), flags);
return result == null ? BulkReply.NIL_REPLY : new BulkReply(DoubleUtil.toString(result));
}
int result = database.zadd(params[0], members, flags);
return new IntegerReply(result);
}
}
36 changes: 36 additions & 0 deletions core/src/main/java/dev/keva/core/command/impl/zset/ZScore.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package dev.keva.core.command.impl.zset;

import dev.keva.core.command.annotation.CommandImpl;
import dev.keva.core.command.annotation.Execute;
import dev.keva.core.command.annotation.ParamLength;
import dev.keva.ioc.annotation.Autowired;
import dev.keva.ioc.annotation.Component;
import dev.keva.protocol.resp.reply.BulkReply;
import dev.keva.store.KevaDatabase;

@Component
@CommandImpl("zscore")
@ParamLength(type = ParamLength.Type.EXACT, value = 2)
public final class ZScore {
private final KevaDatabase database;

@Autowired
public ZScore(KevaDatabase database) {
this.database = database;
}

@Execute
public BulkReply execute(byte[] key, byte[] member) {
final Double result = database.zscore(key, member);
if(result == null){
return BulkReply.NIL_REPLY;
}
if (result.equals(Double.POSITIVE_INFINITY)) {
return BulkReply.POSITIVE_INFINITY_REPLY;
}
if (result.equals(Double.NEGATIVE_INFINITY)) {
return BulkReply.NEGATIVE_INFINITY_REPLY;
}
return new BulkReply(result.toString());
}
}
2 changes: 1 addition & 1 deletion core/src/test/java/dev/keva/core/server/AOFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Server startServer(int port) throws Exception {
.persistence(false)
.aof(true)
.aofInterval(1000)
.workDirectory("./")
.workDirectory(System.getProperty("java.io.tmpdir"))
.build();
val server = KevaServer.of(config);
new Thread(() -> {
Expand Down
98 changes: 98 additions & 0 deletions core/src/test/java/dev/keva/core/server/AbstractServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import lombok.var;
import redis.clients.jedis.params.ZAddParams;

import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -827,6 +830,101 @@ void setrange() {
}
}

@Test
void zaddWithXXAndNXErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().xx().nx());
});
}

@Test
void zaddSingleWithNxAndGtErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().gt().nx());
});
}

@Test
void zaddSingleWithNxAndLtErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().nx());
});
}

@Test
void zaddSingleWithGtAndLtErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().gt());
});
}

@Test
void zaddSingleWithoutOptions() {
try {
var result = jedis.zadd("zset", 1.0, "val");
assertEquals(1, result);

result = jedis.zadd("zset", 1.0, "val");
assertEquals(0, result);
} catch (Exception e) {
fail(e);
}
}

@Test
void zaddMultipleWithoutOptions() {
try {
Map<String, Double> members = new HashMap<>();
int numMembers = 100;
for(int i=0; i<numMembers; ++i) {
members.put(Integer.toString(i), (double) i);
}
var result = jedis.zadd("zset", members);
assertEquals(numMembers, result);

result = jedis.zadd("zset", members);
assertEquals(0, result);
} catch (Exception e) {
fail(e);
}
}

@Test
void zaddCh() {
try {
var result = jedis.zadd("zset", 1.0, "mem", new ZAddParams().ch());
assertEquals(1, result);

result = jedis.zadd("zset", 1.0, "mem", new ZAddParams().ch());
assertEquals(0, result);

result = jedis.zadd("zset", 2.0, "mem", new ZAddParams().ch());
assertEquals(1, result);
} catch (Exception e) {
fail(e);
}
}

@Test
void zscoreNonExistingKey() {
val result = jedis.zscore("key", "mem");
assertNull(result);
}

@Test
void zscoreNonExistingMember() {
jedis.zadd("zset", 1.0, "mem");
val result = jedis.zscore("zset", "foo");
assertNull(result);
}

@Test
void zscoreExistingMember() {
jedis.zadd("zset", 1.0, "mem");
val result = jedis.zscore("zset", "mem");
assertEquals(result, 1.0);
}

@Test
void dumpAndRestore() {
try {
Expand Down
8 changes: 8 additions & 0 deletions docs/src/guide/overview/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ Implemented commands:

</details>

<details>
<summary>SortedSet</summary>

- ZADD
- ZSCORE

</details>

<details>
<summary>Pub/Sub</summary>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

public class BulkReply implements Reply<ByteBuf> {
public static final BulkReply NIL_REPLY = new BulkReply();
public static final BulkReply POSITIVE_INFINITY_REPLY = new BulkReply("inf");
public static final BulkReply NEGATIVE_INFINITY_REPLY = new BulkReply("-inf");

public static final char MARKER = '$';
private final ByteBuf bytes;
Expand All @@ -22,11 +24,7 @@ private BulkReply() {
}

public BulkReply(byte[] bytes) {
if (bytes.length == 0) {
this.bytes = Unpooled.EMPTY_BUFFER;
} else {
this.bytes = Unpooled.wrappedBuffer(bytes);
}
this.bytes = Unpooled.wrappedBuffer(bytes);
capacity = bytes.length;
}

Expand Down Expand Up @@ -59,7 +57,7 @@ public void write(ByteBuf os) throws IOException {
os.writeByte(MARKER);
os.writeBytes(numToBytes(capacity, true));
if (capacity > 0) {
os.writeBytes(bytes);
os.writeBytes(bytes.array());
os.writeBytes(CRLF);
}
if (capacity == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

public class ErrorReply implements Reply<String> {
public static final char MARKER = '-';
// Pre-defined errors
public static final ErrorReply SYNTAX_ERROR = new ErrorReply("ERR syntax error");
public static final ErrorReply ZADD_NX_XX_ERROR = new ErrorReply("ERR XX and NX options at the same time are not compatible");
public static final ErrorReply ZADD_GT_LT_NX_ERROR = new ErrorReply("GT, LT, and/or NX options at the same time are not compatible");
public static final ErrorReply ZADD_INCR_ERROR = new ErrorReply("INCR option supports a single increment-element pair");
public static final ErrorReply ZADD_SCORE_FLOAT_ERROR = new ErrorReply("value is not a valid float");

private final String error;

public ErrorReply(String error) {
Expand Down
8 changes: 8 additions & 0 deletions store/src/main/java/dev/keva/store/KevaDatabase.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package dev.keva.store;

import dev.keva.util.hashbytes.BytesKey;

import java.util.AbstractMap;
import java.util.concurrent.locks.Lock;

public interface KevaDatabase {
Expand Down Expand Up @@ -69,4 +72,9 @@ public interface KevaDatabase {

byte[][] mget(byte[]... keys);

int zadd(byte[] key, AbstractMap.SimpleEntry<Double, BytesKey>[] members, int flags);

Double zincrby(byte[] key, Double score, BytesKey e, int flags);

Double zscore(byte[] key, byte[] member);
}
Loading

0 comments on commit 6586a84

Please sign in to comment.