diff --git a/pom.xml b/pom.xml
index a8930f2306..4ca1343f19 100644
--- a/pom.xml
+++ b/pom.xml
@@ -156,6 +156,13 @@
test
+
+ redis.clients.authentication
+ redis-authx-entraid
+ 0.1.0-SNAPSHOT
+ test
+
+
io.github.resilience4j
diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java
index ce4a10cb7b..cdbe2ab5c7 100644
--- a/src/main/java/redis/clients/jedis/ConnectionFactory.java
+++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java
@@ -11,7 +11,8 @@
import java.util.function.Supplier;
import redis.clients.jedis.annots.Experimental;
-import redis.clients.jedis.authentication.JedisAuthXManager;
+import redis.clients.jedis.authentication.AuthXManager;
+import redis.clients.jedis.authentication.AuthXEventListener;
import redis.clients.jedis.csc.Cache;
import redis.clients.jedis.csc.CacheConnection;
import redis.clients.jedis.exceptions.JedisException;
@@ -28,41 +29,43 @@ public class ConnectionFactory implements PooledObjectFactory {
private final Cache clientSideCache;
private final Supplier objectMaker;
+ private final AuthXEventListener authenticationEventListener;
+
public ConnectionFactory(final HostAndPort hostAndPort) {
- this(hostAndPort, DefaultJedisClientConfig.builder().build(), null, null);
+ this(hostAndPort, DefaultJedisClientConfig.builder().build(), null);
}
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) {
- this(hostAndPort, clientConfig, null, null);
+ this(hostAndPort, clientConfig, null);
}
@Experimental
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig,
- Cache csCache, JedisAuthXManager authXManager) {
- this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache,
- authXManager);
+ Cache csCache) {
+ this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache);
}
public ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
final JedisClientConfig clientConfig) {
- this(jedisSocketFactory, clientConfig, null, null);
+ this(jedisSocketFactory, clientConfig, null);
}
private ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
- final JedisClientConfig clientConfig, Cache csCache, JedisAuthXManager authXManager) {
+ final JedisClientConfig clientConfig, Cache csCache) {
this.jedisSocketFactory = jedisSocketFactory;
this.clientSideCache = csCache;
+ AuthXManager authXManager = clientConfig.getAuthXManager();
if (authXManager == null) {
this.clientConfig = clientConfig;
this.objectMaker = connectionSupplier();
+ this.authenticationEventListener = AuthXEventListener.NOOP_LISTENER;
} else {
- this.clientConfig = replaceCredentialsProvider(clientConfig,
- authXManager);
+ this.clientConfig = replaceCredentialsProvider(clientConfig, authXManager);
Supplier supplier = connectionSupplier();
this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get());
-
+ this.authenticationEventListener = authXManager.getListener();
try {
authXManager.start(true);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
@@ -114,7 +117,12 @@ public PooledObject makeObject() throws Exception {
public void passivateObject(PooledObject pooledConnection) throws Exception {
// TODO maybe should select db 0? Not sure right now.
Connection jedis = pooledConnection.getObject();
- jedis.reAuth();
+ try {
+ jedis.reAuth();
+ } catch (Exception e) {
+ authenticationEventListener.onConnectionAuthenticationError(e);
+ throw e;
+ }
}
@Override
@@ -122,7 +130,12 @@ public boolean validateObject(PooledObject pooledConnection) {
final Connection jedis = pooledConnection.getObject();
try {
// check HostAndPort ??
- jedis.reAuth();
+ try {
+ jedis.reAuth();
+ } catch (Exception e) {
+ authenticationEventListener.onConnectionAuthenticationError(e);
+ throw e;
+ }
return jedis.isConnected() && jedis.ping();
} catch (final Exception e) {
logger.warn("Error while validating pooled Connection object.", e);
diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java
index d7dc0d85f7..536b3a6484 100644
--- a/src/main/java/redis/clients/jedis/ConnectionPool.java
+++ b/src/main/java/redis/clients/jedis/ConnectionPool.java
@@ -4,36 +4,25 @@
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import redis.clients.jedis.annots.Experimental;
-import redis.clients.jedis.authentication.JedisAuthXManager;
+import redis.clients.jedis.authentication.AuthXManager;
import redis.clients.jedis.csc.Cache;
import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.util.Pool;
public class ConnectionPool extends Pool {
- private JedisAuthXManager authXManager;
+ private AuthXManager authXManager;
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
- this(hostAndPort, clientConfig, createAuthXManager(clientConfig));
- }
-
- public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
- JedisAuthXManager authXManager) {
- this(new ConnectionFactory(hostAndPort, clientConfig, null, authXManager));
- attachAuthenticationListener(authXManager);
+ this(new ConnectionFactory(hostAndPort, clientConfig));
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
@Experimental
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
Cache clientSideCache) {
- this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig));
- }
-
- @Experimental
- public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
- Cache clientSideCache, JedisAuthXManager authXManager) {
- this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager));
- attachAuthenticationListener(authXManager);
+ this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache));
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
public ConnectionPool(PooledObjectFactory factory) {
@@ -42,22 +31,15 @@ public ConnectionPool(PooledObjectFactory factory) {
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
GenericObjectPoolConfig poolConfig) {
- this(hostAndPort, clientConfig, null, createAuthXManager(clientConfig), poolConfig);
+ this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig);
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
@Experimental
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
Cache clientSideCache, GenericObjectPoolConfig poolConfig) {
- this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig), poolConfig);
- }
-
- @Experimental
- public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
- Cache clientSideCache, JedisAuthXManager authXManager,
- GenericObjectPoolConfig poolConfig) {
- this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager),
- poolConfig);
- attachAuthenticationListener(authXManager);
+ this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig);
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
public ConnectionPool(PooledObjectFactory factory,
@@ -80,17 +62,10 @@ public void close() {
super.close();
}
- private static JedisAuthXManager createAuthXManager(JedisClientConfig config) {
- if (config.getTokenAuthConfig() != null) {
- return new JedisAuthXManager(config.getTokenAuthConfig());
- }
- return null;
- }
-
- private void attachAuthenticationListener(JedisAuthXManager authXManager) {
+ private void attachAuthenticationListener(AuthXManager authXManager) {
this.authXManager = authXManager;
if (authXManager != null) {
- authXManager.setListener(token -> {
+ authXManager.addPostAuthenticationHook(token -> {
try {
// this is to trigger validations on each connection via ConnectionFactory
evict();
diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
index 8b161ca7ff..5f0e050ef4 100644
--- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
+++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
@@ -5,7 +5,7 @@
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
-import redis.clients.authentication.core.TokenAuthConfig;
+import redis.clients.jedis.authentication.AuthXManager;
public final class DefaultJedisClientConfig implements JedisClientConfig {
@@ -30,7 +30,7 @@ public final class DefaultJedisClientConfig implements JedisClientConfig {
private final boolean readOnlyForRedisClusterReplicas;
- private final TokenAuthConfig tokenAuthConfig;
+ private final AuthXManager authXManager;
private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMillis,
int soTimeoutMillis, int blockingSocketTimeoutMillis,
@@ -38,7 +38,7 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi
SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper,
ClientSetInfoConfig clientSetInfoConfig, boolean readOnlyForRedisClusterReplicas,
- TokenAuthConfig tokenAuthConfig) {
+ AuthXManager authXManager) {
this.redisProtocol = protocol;
this.connectionTimeoutMillis = connectionTimeoutMillis;
this.socketTimeoutMillis = soTimeoutMillis;
@@ -53,7 +53,8 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi
this.hostAndPortMapper = hostAndPortMapper;
this.clientSetInfoConfig = clientSetInfoConfig;
this.readOnlyForRedisClusterReplicas = readOnlyForRedisClusterReplicas;
- this.tokenAuthConfig = tokenAuthConfig;
+ this.authXManager = authXManager;
+
}
@Override
@@ -93,8 +94,8 @@ public Supplier getCredentialsProvider() {
}
@Override
- public TokenAuthConfig getTokenAuthConfig() {
- return tokenAuthConfig;
+ public AuthXManager getAuthXManager() {
+ return authXManager;
}
@Override
@@ -171,7 +172,7 @@ public static class Builder {
private boolean readOnlyForRedisClusterReplicas = false;
- private TokenAuthConfig tokenAuthConfig = null;
+ private AuthXManager authXManager;
private Builder() {
}
@@ -185,7 +186,7 @@ public DefaultJedisClientConfig build() {
return new DefaultJedisClientConfig(redisProtocol, connectionTimeoutMillis,
socketTimeoutMillis, blockingSocketTimeoutMillis, credentialsProvider, database,
clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper,
- clientSetInfoConfig, readOnlyForRedisClusterReplicas, tokenAuthConfig);
+ clientSetInfoConfig, readOnlyForRedisClusterReplicas, authXManager);
}
/**
@@ -287,8 +288,8 @@ public Builder readOnlyForRedisClusterReplicas() {
return this;
}
- public Builder tokenAuthConfig(TokenAuthConfig tokenAuthConfig) {
- this.tokenAuthConfig = tokenAuthConfig;
+ public Builder authXManager(AuthXManager authXManager) {
+ this.authXManager = authXManager;
return this;
}
@@ -307,7 +308,7 @@ public Builder from(JedisClientConfig instance) {
this.hostAndPortMapper = instance.getHostAndPortMapper();
this.clientSetInfoConfig = instance.getClientSetInfoConfig();
this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas();
- this.tokenAuthConfig = instance.getTokenAuthConfig();
+ this.authXManager = instance.getAuthXManager();
return this;
}
}
@@ -316,12 +317,12 @@ public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int s
int blockingSocketTimeoutMillis, String user, String password, int database,
String clientName, boolean ssl, SSLSocketFactory sslSocketFactory,
SSLParameters sslParameters, HostnameVerifier hostnameVerifier,
- HostAndPortMapper hostAndPortMapper, TokenAuthConfig tokenAuthConfig) {
+ HostAndPortMapper hostAndPortMapper, AuthXManager authXManager) {
return new DefaultJedisClientConfig(null, connectionTimeoutMillis, soTimeoutMillis,
blockingSocketTimeoutMillis,
new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(user, password)), database,
clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, null,
- false, tokenAuthConfig);
+ false, authXManager);
}
public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) {
@@ -330,6 +331,6 @@ public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) {
copy.getCredentialsProvider(), copy.getDatabase(), copy.getClientName(), copy.isSsl(),
copy.getSslSocketFactory(), copy.getSslParameters(), copy.getHostnameVerifier(),
copy.getHostAndPortMapper(), copy.getClientSetInfoConfig(),
- copy.isReadOnlyForRedisClusterReplicas(), copy.getTokenAuthConfig());
+ copy.isReadOnlyForRedisClusterReplicas(), copy.getAuthXManager());
}
}
diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java
index a8046694bf..82e9eb8e7f 100644
--- a/src/main/java/redis/clients/jedis/JedisClientConfig.java
+++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java
@@ -5,7 +5,7 @@
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
-import redis.clients.authentication.core.TokenAuthConfig;
+import redis.clients.jedis.authentication.AuthXManager;
public interface JedisClientConfig {
@@ -47,10 +47,11 @@ default String getPassword() {
}
default Supplier getCredentialsProvider() {
- return new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(getUser(), getPassword()));
+ return new DefaultRedisCredentialsProvider(
+ new DefaultRedisCredentials(getUser(), getPassword()));
}
- default TokenAuthConfig getTokenAuthConfig() {
+ default AuthXManager getAuthXManager() {
return null;
}
diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java
new file mode 100644
index 0000000000..4750404157
--- /dev/null
+++ b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java
@@ -0,0 +1,21 @@
+package redis.clients.jedis.authentication;
+
+public interface AuthXEventListener {
+
+ static AuthXEventListener NOOP_LISTENER = new AuthXEventListener() {
+
+ @Override
+ public void onIdentityProviderError(Exception reason) {
+ }
+
+ @Override
+ public void onConnectionAuthenticationError(Exception reason) {
+ }
+
+ };
+
+ public void onIdentityProviderError(Exception reason);
+
+ public void onConnectionAuthenticationError(Exception reason);
+
+}
diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java
similarity index 63%
rename from src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java
rename to src/main/java/redis/clients/jedis/authentication/AuthXManager.java
index a210f7fb42..d66bddeb4c 100644
--- a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java
+++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java
@@ -6,6 +6,7 @@
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
+import java.util.function.Consumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
@@ -18,25 +19,22 @@
import redis.clients.jedis.Connection;
import redis.clients.jedis.RedisCredentials;
-public class JedisAuthXManager implements Supplier {
+public final class AuthXManager implements Supplier {
- private static final Logger log = LoggerFactory.getLogger(JedisAuthXManager.class);
+ private static final Logger log = LoggerFactory.getLogger(AuthXManager.class);
private TokenManager tokenManager;
private List> connections = Collections
.synchronizedList(new ArrayList<>());
private Token currentToken;
- private AuthenticationListener listener;
+ private AuthXEventListener listener = AuthXEventListener.NOOP_LISTENER;
+ private List> postAuthenticateHooks = new ArrayList<>();
- public interface AuthenticationListener {
- public void onAuthenticate(Token token);
- }
-
- public JedisAuthXManager(TokenManager tokenManager) {
+ protected AuthXManager(TokenManager tokenManager) {
this.tokenManager = tokenManager;
}
- public JedisAuthXManager(TokenAuthConfig tokenAuthConfig) {
+ public AuthXManager(TokenAuthConfig tokenAuthConfig) {
this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(),
tokenAuthConfig.getTokenManagerConfig()));
}
@@ -53,7 +51,8 @@ public void onTokenRenewed(Token token) {
@Override
public void onError(Exception reason) {
- JedisAuthXManager.this.onError(reason);
+ listener.onIdentityProviderError(reason);
+ AuthXManager.this.onError(reason);
}
}, blockForInitialToken);
}
@@ -63,23 +62,18 @@ public void authenticateConnections(Token token) {
for (WeakReference connectionRef : connections) {
Connection connection = connectionRef.get();
if (connection != null) {
- try {
- connection.setCredentials(credentialsFromToken);
- } catch (Exception e) {
- log.error("Failed to authenticate connection!", e);
- }
+ connection.setCredentials(credentialsFromToken);
} else {
connections.remove(connectionRef);
}
}
- if (listener != null) {
- listener.onAuthenticate(token);
- }
+ postAuthenticateHooks.forEach(hook -> hook.accept(token));
}
public void onError(Exception reason) {
- throw new JedisAuthenticationException(
- "Token manager failed to acquire new token!", reason);
+ log.error("Token manager failed to acquire new token!", reason);
+ throw new JedisAuthenticationException("Token manager failed to acquire new token!",
+ reason);
}
public Connection addConnection(Connection connection) {
@@ -91,8 +85,22 @@ public void stop() {
tokenManager.stop();
}
- public void setListener(AuthenticationListener listener) {
- this.listener = listener;
+ public void setListener(AuthXEventListener listener) {
+ if (listener != null) {
+ this.listener = listener;
+ }
+ }
+
+ public void addPostAuthenticationHook(Consumer postAuthenticateHook) {
+ postAuthenticateHooks.add(postAuthenticateHook);
+ }
+
+ public void removePostAuthenticationHook(Consumer postAuthenticateHook) {
+ postAuthenticateHooks.remove(postAuthenticateHook);
+ }
+
+ public AuthXEventListener getListener() {
+ return listener;
}
@Override
diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java
index adc421e790..c70ab98720 100644
--- a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java
+++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java
@@ -1,6 +1,8 @@
package redis.clients.jedis.authentication;
-public class JedisAuthenticationException extends RuntimeException {
+import redis.clients.jedis.exceptions.JedisException;
+
+public class JedisAuthenticationException extends JedisException {
public JedisAuthenticationException(String message) {
super(message);
diff --git a/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java
new file mode 100644
index 0000000000..e0cde9cfef
--- /dev/null
+++ b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java
@@ -0,0 +1,112 @@
+package redis.clients.jedis.authentication;
+
+import java.io.ByteArrayInputStream;
+import java.security.KeyFactory;
+import java.security.PrivateKey;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.security.spec.PKCS8EncodedKeySpec;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.HashSet;
+import java.util.Set;
+
+public class EntraIDTestContext {
+ private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID";
+ private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY";
+ private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET";
+ private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY";
+ private static final String AZURE_CERT = "AZURE_CERT";
+ private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES";
+
+ private String clientId;
+ private String authority;
+ private String clientSecret;
+ private PrivateKey privateKey;
+ private X509Certificate cert;
+ private Set redisScopes;
+
+ public static final EntraIDTestContext DEFAULT = new EntraIDTestContext();
+
+ private EntraIDTestContext() {
+ clientId = System.getenv(AZURE_CLIENT_ID);
+ authority = System.getenv(AZURE_AUTHORITY);
+ clientSecret = System.getenv(AZURE_CLIENT_SECRET);
+ }
+
+ public EntraIDTestContext(String clientId, String authority, String clientSecret,
+ Set redisScopes) {
+ this.clientId = clientId;
+ this.authority = authority;
+ this.clientSecret = clientSecret;
+ this.redisScopes = redisScopes;
+ }
+
+ public String getClientId() {
+ return clientId;
+ }
+
+ public String getAuthority() {
+ return authority;
+ }
+
+ public String getClientSecret() {
+ return clientSecret;
+ }
+
+ public PrivateKey getPrivateKey() {
+ if (privateKey == null) {
+ this.privateKey = getPrivateKey(System.getenv(AZURE_PRIVATE_KEY));
+ }
+ return privateKey;
+ }
+
+ public X509Certificate getCert() {
+ if (cert == null) {
+ this.cert = getCert(System.getenv(AZURE_CERT));
+ }
+ return cert;
+ }
+
+ public Set getRedisScopes() {
+ if (redisScopes == null) {
+ String redisScopesEnv = System.getenv(AZURE_REDIS_SCOPES);
+ this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";")));
+ }
+ return redisScopes;
+ }
+
+ private PrivateKey getPrivateKey(String privateKey) {
+ try {
+ // Decode the base64 encoded key into a byte array
+ byte[] decodedKey = Base64.getDecoder().decode(privateKey);
+
+ // Generate the private key from the decoded byte array using PKCS8EncodedKeySpec
+ PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKey);
+ KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // Use the correct algorithm (e.g., "RSA", "EC", "DSA")
+ PrivateKey key = keyFactory.generatePrivate(keySpec);
+ return key;
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+ }
+
+ private X509Certificate getCert(String cert) {
+ try {
+ // Convert the Base64 encoded string into a byte array
+ byte[] encoded = java.util.Base64.getDecoder().decode(cert);
+
+ // Create a CertificateFactory for X.509 certificates
+ CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
+
+ // Generate the certificate from the byte array
+ X509Certificate certificate = (X509Certificate) certificateFactory
+ .generateCertificate(new ByteArrayInputStream(encoded));
+ return certificate;
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java
new file mode 100644
index 0000000000..4639e34b02
--- /dev/null
+++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java
@@ -0,0 +1,404 @@
+package redis.clients.jedis.authentication;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockConstruction;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+
+import org.awaitility.Awaitility;
+import org.awaitility.Durations;
+import org.junit.BeforeClass;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
+import org.junit.runners.MethodSorters;
+import org.mockito.MockedConstruction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import redis.clients.authentication.core.IdentityProvider;
+import redis.clients.authentication.core.IdentityProviderConfig;
+import redis.clients.authentication.core.SimpleToken;
+import redis.clients.authentication.core.Token;
+import redis.clients.authentication.core.TokenAuthConfig;
+import redis.clients.authentication.entraid.EntraIDIdentityProvider;
+import redis.clients.authentication.entraid.EntraIDIdentityProviderConfig;
+import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;
+import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType;
+import redis.clients.authentication.entraid.ServicePrincipalInfo;
+import redis.clients.jedis.Connection;
+import redis.clients.jedis.DefaultJedisClientConfig;
+import redis.clients.jedis.EndpointConfig;
+import redis.clients.jedis.HostAndPort;
+import redis.clients.jedis.HostAndPorts;
+import redis.clients.jedis.JedisPooled;
+import redis.clients.jedis.exceptions.JedisAccessControlException;
+import redis.clients.jedis.exceptions.JedisConnectionException;
+import redis.clients.jedis.scenario.FaultInjectionClient;
+
+@FixMethodOrder(MethodSorters.NAME_ASCENDING)
+public class RedisEntraIDIntegrationTests {
+ private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class);
+
+ private static EntraIDTestContext testCtx;
+ private static EndpointConfig endpointConfig;
+ private static HostAndPort hnp;
+
+ private final FaultInjectionClient faultClient = new FaultInjectionClient();
+
+ @BeforeClass
+ public static void before() {
+ try {
+ testCtx = EntraIDTestContext.DEFAULT;
+ endpointConfig = HostAndPorts.getRedisEndpoint("standalone-entraid-acl");
+ hnp = endpointConfig.getHostAndPort();
+ } catch (IllegalArgumentException e) {
+ log.warn("Skipping test because no Redis endpoint is configured");
+ org.junit.Assume.assumeTrue(false);
+ }
+ }
+
+ @Test
+ public void testJedisConfig() {
+ AtomicInteger counter = new AtomicInteger(0);
+ try (MockedConstruction mockedConstructor = mockConstruction(
+ EntraIDIdentityProvider.class, (mock, context) -> {
+ ServicePrincipalInfo info = (ServicePrincipalInfo) context.arguments().get(0);
+
+ assertEquals(testCtx.getClientId(), info.getClientId());
+ assertEquals(testCtx.getAuthority(), info.getAuthority());
+ assertEquals(testCtx.getClientSecret(), info.getSecret());
+ assertEquals(testCtx.getRedisScopes(), context.arguments().get(1));
+ assertNotNull(mock);
+ doAnswer(invocation -> {
+ counter.incrementAndGet();
+ return new SimpleToken("token1", System.currentTimeMillis() + 5 * 60 * 1000,
+ System.currentTimeMillis(), Collections.singletonMap("oid", "default"));
+ }).when(mock).requestToken();
+ })) {
+
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .authority(testCtx.getAuthority()).clientId(testCtx.getClientId())
+ .secret(testCtx.getClientSecret()).scopes(testCtx.getRedisScopes()).build();
+
+ DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder()
+ .authXManager(new AuthXManager(tokenAuthConfig)).build();
+
+ JedisPooled jedis = new JedisPooled(new HostAndPort("localhost", 6379), jedisConfig);
+ assertNotNull(jedis);
+ assertEquals(1, counter.get());
+
+ }
+ }
+
+ // T.1.1
+ // Verify authentication using Azure AD with managed identities
+ // @Test
+ public void withUserAssignedId_azureManagedIdentityIntegrationTest() {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId())
+ .userAssignedManagedIdentity(UserManagedIdentityType.CLIENT_ID, "userManagedAuthxId")
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build();
+
+ DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder()
+ .authXManager(new AuthXManager(tokenAuthConfig)).build();
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ }
+
+ // T.1.1
+ // Verify authentication using Azure AD with managed identities
+ // @Test
+ public void withSystemAssignedId_azureManagedIdentityIntegrationTest() {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId()).systemAssignedManagedIdentity()
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build();
+
+ DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder()
+ .authXManager(new AuthXManager(tokenAuthConfig)).build();
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ }
+
+ // T.1.1
+ // Verify authentication using Azure AD with service principals
+ @Test
+ public void withSecret_azureServicePrincipalIntegrationTest() {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build();
+
+ DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder()
+ .authXManager(new AuthXManager(tokenAuthConfig)).build();
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ }
+
+ // T.1.1
+ // Verify authentication using Azure AD with service principals
+ @Test
+ public void withCertificate_azureServicePrincipalIntegrationTest() {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build();
+
+ DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder()
+ .authXManager(new AuthXManager(tokenAuthConfig)).build();
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ }
+
+ // T.2.2
+ // Test that the Redis client is not blocked/interrupted during token renewal.
+ @Test
+ public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException {
+ // set the stage with consecutive get/set operations with unique keys which takes at least for 2000 ms with a jedispooled instace,
+ // configure token manager to renew token approximately every 100ms
+ // wait till all operations are completed and verify that token was renewed at least 20 times after initial token acquisition
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
+ .expirationRefreshRatio(0.000001F).build();
+
+ AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
+ Consumer hook = mock(Consumer.class);
+ authXManager.addPostAuthenticationHook(hook);
+
+ DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
+ .authXManager(authXManager).build();
+
+ long startTime = System.currentTimeMillis();
+ List> futures = new ArrayList<>();
+ ExecutorService executor = Executors.newFixedThreadPool(5);
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
+ for (int i = 0; i < 5; i++) {
+ Future> future = executor.submit(() -> {
+ for (; System.currentTimeMillis() - startTime < 2000;) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ });
+ futures.add(future);
+ }
+ for (Future> task : futures) {
+ task.get();
+ }
+
+ verify(hook, atLeast(20)).accept(any());
+ executor.shutdown();
+ }
+ }
+
+ // T.3.2
+ // Verify that all existing connections can be re-authenticated when a new token is received.
+ @Test
+ public void allConnectionsReauthTest() throws InterruptedException, ExecutionException {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
+ .expirationRefreshRatio(0.000001F).build();
+
+ AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
+ authXManager = spy(authXManager);
+
+ List connections = new ArrayList<>();
+
+ doAnswer(invocation -> {
+ Connection connection = spy((Connection) invocation.getArgument(0));
+ invocation.getArguments()[0] = connection;
+ connections.add(connection);
+ Object result = invocation.callRealMethod();
+ return result;
+ }).when(authXManager).addConnection(any(Connection.class));
+
+ DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
+ .authXManager(authXManager).build();
+
+ long startTime = System.currentTimeMillis();
+ List> futures = new ArrayList<>();
+ ExecutorService executor = Executors.newFixedThreadPool(5);
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
+ for (int i = 0; i < 5; i++) {
+ Future> future = executor.submit(() -> {
+ for (; System.currentTimeMillis() - startTime < 2000;) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ });
+ futures.add(future);
+ }
+ for (Future> task : futures) {
+ task.get();
+ }
+
+ connections.forEach(conn -> {
+ verify(conn, atLeast(1)).reAuth();
+ });
+ executor.shutdown();
+ }
+ }
+
+ // T.3.2
+ // Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them
+ @Test
+ public void partialReauthFailureTest() {
+
+ }
+
+ // T.3.3
+ // Verify behavior when attempting to authenticate a single connection with an expired token.
+ @Test
+ public void connectionAuthWithExpiredTokenTest() {
+ IdentityProvider idp = new EntraIDIdentityProviderConfig(
+ new ServicePrincipalInfo(testCtx.getClientId(), testCtx.getClientSecret(),
+ testCtx.getAuthority()),
+ testCtx.getRedisScopes()).getProvider();
+
+ IdentityProvider mockIdentityProvider = mock(IdentityProvider.class);
+ AtomicReference token = new AtomicReference<>();
+ doAnswer(invocation -> {
+ if (token.get() == null) {
+ token.set(idp.requestToken());
+ }
+ return token.get();
+ }).when(mockIdentityProvider).requestToken();
+ IdentityProviderConfig idpConfig = mock(IdentityProviderConfig.class);
+ when(idpConfig.getProvider()).thenReturn(mockIdentityProvider);
+
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .identityProviderConfig(idpConfig).expirationRefreshRatio(0.000001F).build();
+ AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
+ DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
+ .authXManager(authXManager).build();
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
+ for (int i = 0; i < 50; i++) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+
+ token
+ .set(new SimpleToken("token1", System.currentTimeMillis() - 1, System.currentTimeMillis(),
+ Collections.singletonMap("oid", idp.requestToken().tryGet("oid"))));
+
+ JedisAccessControlException aclException = assertThrows(JedisAccessControlException.class,
+ () -> {
+ for (int i = 0; i < 50; i++) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ });
+
+ assertEquals("WRONGPASS invalid username-password pair", aclException.getMessage());
+ }
+ }
+
+ // T.3.4
+ // Verify handling of reconnection and re-authentication after a network partition. (use cached token)
+ // @Test
+ public void networkPartitionEvictionTest() {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
+ .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
+ .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
+ .expirationRefreshRatio(0.5F).build();
+ AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
+ DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
+ .authXManager(authXManager).build();
+
+ try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
+ for (int i = 0; i < 5; i++) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+
+ triggerNetworkFailure();
+
+ JedisConnectionException aclException = assertThrows(JedisConnectionException.class, () -> {
+ for (int i = 0; i < 50; i++) {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ }
+ });
+
+ assertEquals("Unexpected end of stream.", aclException.getMessage());
+ Awaitility.await().pollDelay(Durations.ONE_HUNDRED_MILLISECONDS).atMost(Durations.TWO_SECONDS)
+ .until(() -> {
+ String key = UUID.randomUUID().toString();
+ jedis.set(key, "value");
+ assertEquals("value", jedis.get(key));
+ jedis.del(key);
+ return true;
+ });
+ }
+ }
+
+ private void triggerNetworkFailure() {
+ HashMap params = new HashMap<>();
+ params.put("bdb_id", endpointConfig.getBdbId());
+
+ FaultInjectionClient.TriggerActionResponse actionResponse = null;
+ String action = "network_failure";
+ try {
+ log.info("Triggering {}", action);
+ actionResponse = faultClient.triggerAction(action, params);
+ } catch (IOException e) {
+ fail("Fault Injection Server error:" + e.getMessage());
+ }
+ log.info("Action id: {}", actionResponse.getActionId());
+ }
+}
diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java
index 4cbf155dd3..780c82c781 100644
--- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java
+++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java
@@ -9,15 +9,17 @@
import java.util.Arrays;
import java.util.Collections;
-import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
+import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.IdentityProviderConfig;
@@ -36,52 +38,65 @@
import redis.clients.jedis.commands.ProtocolCommand;
public class TokenBasedAuthenticationIntegrationTests {
+ private static final Logger log = LoggerFactory
+ .getLogger(TokenBasedAuthenticationIntegrationTests.class);
+
+ private static EndpointConfig endpointConfig;
+
+ @BeforeClass
+ public static void before() {
+ try {
+ endpointConfig = HostAndPorts.getRedisEndpoint("standalone0");
+ } catch (IllegalArgumentException e) {
+ try {
+ endpointConfig = HostAndPorts.getRedisEndpoint("standalone");
+ } catch (IllegalArgumentException ex) {
+ log.warn("Skipping test because no Redis endpoint is configured");
+ org.junit.Assume.assumeTrue(false);
+ }
+ }
+ }
- protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0");
-
- @Test
- public void testJedisPooledAuth() {
- String user = "default";
- String password = endpoint.getPassword();
+ @Test
+ public void testJedisPooledForInitialAuth() {
+ String user = "default";
+ String password = endpointConfig.getPassword();
- IdentityProvider idProvider = mock(IdentityProvider.class);
- when(idProvider.requestToken())
- .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000,
- System.currentTimeMillis(), Collections.singletonMap("oid", user)));
+ IdentityProvider idProvider = mock(IdentityProvider.class);
+ when(idProvider.requestToken())
+ .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000,
+ System.currentTimeMillis(), Collections.singletonMap("oid", user)));
- IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
- when(idProviderConfig.getProvider()).thenReturn(idProvider);
+ IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
+ when(idProviderConfig.getProvider()).thenReturn(idProvider);
- TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
- .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
- .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build();
+ TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
+ .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
+ .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build();
- JedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
- .tokenAuthConfig(tokenAuthConfig).build();
+ JedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
+ .authXManager(new AuthXManager(tokenAuthConfig)).build();
- try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) {
- ArgumentCaptor captor = ArgumentCaptor
- .forClass(CommandArguments.class);
+ try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) {
+ ArgumentCaptor captor = ArgumentCaptor.forClass(CommandArguments.class);
- try (JedisPooled jedis = new JedisPooled(endpoint.getHostAndPort(), clientConfig)) {
- jedis.get("key1");
- }
+ try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) {
+ jedis.get("key1");
+ }
- // Verify that the static method was called
- mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()),
- Mockito.atLeast(4));
+ // Verify that the static method was called
+ mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), Mockito.atLeast(4));
- CommandArguments commandArgs = captor.getAllValues().get(0);
- List args = StreamSupport.stream(commandArgs.spliterator(), false)
- .map(Rawable::getRaw).collect(Collectors.toList());
+ CommandArguments commandArgs = captor.getAllValues().get(0);
+ List args = StreamSupport.stream(commandArgs.spliterator(), false)
+ .map(Rawable::getRaw).collect(Collectors.toList());
- assertThat(args,
- contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes()));
+ assertThat(args,
+ contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes()));
- List cmds = captor.getAllValues().stream()
- .map(item -> item.getCommand()).collect(Collectors.toList());
- assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET),
- cmds);
- }
+ List cmds = captor.getAllValues().stream().map(item -> item.getCommand())
+ .collect(Collectors.toList());
+ assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), cmds);
}
+ }
}
diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java
index 7ac68361aa..83c441d492 100644
--- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java
+++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java
@@ -22,6 +22,7 @@
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
+import org.awaitility.Durations;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
@@ -36,12 +37,15 @@
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;
import redis.clients.authentication.core.TokenManagerConfig;
+import redis.clients.authentication.core.TokenRequestException;
import redis.clients.jedis.ConnectionPool;
import redis.clients.jedis.EndpointConfig;
-import redis.clients.jedis.HostAndPorts;
+import redis.clients.jedis.HostAndPort;
public class TokenBasedAuthenticationUnitTests {
- protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0");
+
+ private HostAndPort hnp = new HostAndPort("localhost", 6379);
+ private EndpointConfig endpoint = new EndpointConfig(hnp, null, null, false);
@Test
public void testJedisAuthXManagerInstance() {
@@ -57,25 +61,51 @@ public void testJedisAuthXManagerInstance() {
assertEquals(tokenManagerConfig, context.arguments().get(1));
})) {
- new JedisAuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig));
+ new AuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig));
}
}
@Test
- public void testJedisAuthXManagerTriggersEvict() throws Exception {
+ public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() throws Exception {
IdentityProvider idProvider = mock(IdentityProvider.class);
when(idProvider.requestToken())
- .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 100000,
+ .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000,
System.currentTimeMillis(), Collections.singletonMap("oid", "default")));
TokenManager tokenManager = new TokenManager(idProvider,
- new TokenManagerConfig(0.5F, 1000, 1000, null));
- JedisAuthXManager jedisAuthXManager = new JedisAuthXManager(tokenManager);
+ new TokenManagerConfig(0.4F, 100, 1000, null));
+ AuthXManager jedisAuthXManager = new AuthXManager(tokenManager);
+
+ AtomicInteger numberOfEvictions = new AtomicInteger(0);
+ ConnectionPool pool = new ConnectionPool(hnp,
+ endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) {
+ @Override
+ public void evict() throws Exception {
+ numberOfEvictions.incrementAndGet();
+ super.evict();
+ }
+ };
+
+ await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS)
+ .atMost(Durations.FIVE_HUNDRED_MILLISECONDS)
+ .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1));
+ }
+
+ public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws Exception {
+
+ IdentityProvider idProvider = mock(IdentityProvider.class);
+ when(idProvider.requestToken())
+ .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000,
+ System.currentTimeMillis(), Collections.singletonMap("oid", "default")));
+
+ TokenManager tokenManager = new TokenManager(idProvider,
+ new TokenManagerConfig(0.9F, 600, 1000, null));
+ AuthXManager jedisAuthXManager = new AuthXManager(tokenManager);
AtomicInteger numberOfEvictions = new AtomicInteger(0);
ConnectionPool pool = new ConnectionPool(endpoint.getHostAndPort(),
- endpoint.getClientConfigBuilder().build(), jedisAuthXManager) {
+ endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) {
@Override
public void evict() throws Exception {
numberOfEvictions.incrementAndGet();
@@ -83,8 +113,9 @@ public void evict() throws Exception {
}
};
- jedisAuthXManager.start(true);
- assertEquals(1, numberOfEvictions.get());
+ await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS)
+ .atMost(Durations.FIVE_HUNDRED_MILLISECONDS)
+ .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1));
}
public static class TokenManagerConfigWrapper extends TokenManagerConfig {
@@ -190,7 +221,7 @@ public void testAuthXManagerReceivesNewToken()
TokenManager tokenManager = new TokenManager(identityProvider,
new TokenManagerConfig(0.7F, 200, 2000, null));
- JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager));
+ AuthXManager manager = spy(new AuthXManager(tokenManager));
final Token[] tokenHolder = new Token[1];
doAnswer(invocation -> {
@@ -213,11 +244,10 @@ public void testBlockForInitialToken() {
TokenManager tokenManager = new TokenManager(identityProvider,
new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100)));
- JedisAuthXManager manager = new JedisAuthXManager(tokenManager);
- ExecutionException e = assertThrows(ExecutionException.class, () -> manager.start(true));
+ AuthXManager manager = new AuthXManager(tokenManager);
+ TokenRequestException e = assertThrows(TokenRequestException.class, () -> manager.start(true));
- assertEquals(exceptionMessage,
- e.getCause().getCause().getCause().getCause().getMessage());
+ assertEquals(exceptionMessage, e.getCause().getCause().getCause().getMessage());
}
@Test
@@ -231,9 +261,9 @@ public void testNoBlockForInitialToken()
};
TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200,
- 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 100)));
+ 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 0)));
- JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager));
+ AuthXManager manager = spy(new AuthXManager(tokenManager));
manager.start(false);
requesLatch.await();
@@ -296,7 +326,7 @@ public void testTokenManagerWithHangingTokenRequest()
TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200,
executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100)));
- JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager));
+ AuthXManager manager = spy(new AuthXManager(tokenManager));
manager.start(false);
requesLatch.await();
verify(manager, never()).onError(any());