diff --git a/pom.xml b/pom.xml index a8930f2306..4ca1343f19 100644 --- a/pom.xml +++ b/pom.xml @@ -156,6 +156,13 @@ test + + redis.clients.authentication + redis-authx-entraid + 0.1.0-SNAPSHOT + test + + io.github.resilience4j diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index ce4a10cb7b..cdbe2ab5c7 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -11,7 +11,8 @@ import java.util.function.Supplier; import redis.clients.jedis.annots.Experimental; -import redis.clients.jedis.authentication.JedisAuthXManager; +import redis.clients.jedis.authentication.AuthXManager; +import redis.clients.jedis.authentication.AuthXEventListener; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.csc.CacheConnection; import redis.clients.jedis.exceptions.JedisException; @@ -28,41 +29,43 @@ public class ConnectionFactory implements PooledObjectFactory { private final Cache clientSideCache; private final Supplier objectMaker; + private final AuthXEventListener authenticationEventListener; + public ConnectionFactory(final HostAndPort hostAndPort) { - this(hostAndPort, DefaultJedisClientConfig.builder().build(), null, null); + this(hostAndPort, DefaultJedisClientConfig.builder().build(), null); } public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) { - this(hostAndPort, clientConfig, null, null); + this(hostAndPort, clientConfig, null); } @Experimental public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, - Cache csCache, JedisAuthXManager authXManager) { - this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache, - authXManager); + Cache csCache) { + this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache); } public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) { - this(jedisSocketFactory, clientConfig, null, null); + this(jedisSocketFactory, clientConfig, null); } private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, - final JedisClientConfig clientConfig, Cache csCache, JedisAuthXManager authXManager) { + final JedisClientConfig clientConfig, Cache csCache) { this.jedisSocketFactory = jedisSocketFactory; this.clientSideCache = csCache; + AuthXManager authXManager = clientConfig.getAuthXManager(); if (authXManager == null) { this.clientConfig = clientConfig; this.objectMaker = connectionSupplier(); + this.authenticationEventListener = AuthXEventListener.NOOP_LISTENER; } else { - this.clientConfig = replaceCredentialsProvider(clientConfig, - authXManager); + this.clientConfig = replaceCredentialsProvider(clientConfig, authXManager); Supplier supplier = connectionSupplier(); this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get()); - + this.authenticationEventListener = authXManager.getListener(); try { authXManager.start(true); } catch (InterruptedException | ExecutionException | TimeoutException e) { @@ -114,7 +117,12 @@ public PooledObject makeObject() throws Exception { public void passivateObject(PooledObject pooledConnection) throws Exception { // TODO maybe should select db 0? Not sure right now. Connection jedis = pooledConnection.getObject(); - jedis.reAuth(); + try { + jedis.reAuth(); + } catch (Exception e) { + authenticationEventListener.onConnectionAuthenticationError(e); + throw e; + } } @Override @@ -122,7 +130,12 @@ public boolean validateObject(PooledObject pooledConnection) { final Connection jedis = pooledConnection.getObject(); try { // check HostAndPort ?? - jedis.reAuth(); + try { + jedis.reAuth(); + } catch (Exception e) { + authenticationEventListener.onConnectionAuthenticationError(e); + throw e; + } return jedis.isConnected() && jedis.ping(); } catch (final Exception e) { logger.warn("Error while validating pooled Connection object.", e); diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index d7dc0d85f7..536b3a6484 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -4,36 +4,25 @@ import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import redis.clients.jedis.annots.Experimental; -import redis.clients.jedis.authentication.JedisAuthXManager; +import redis.clients.jedis.authentication.AuthXManager; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.Pool; public class ConnectionPool extends Pool { - private JedisAuthXManager authXManager; + private AuthXManager authXManager; public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { - this(hostAndPort, clientConfig, createAuthXManager(clientConfig)); - } - - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, - JedisAuthXManager authXManager) { - this(new ConnectionFactory(hostAndPort, clientConfig, null, authXManager)); - attachAuthenticationListener(authXManager); + this(new ConnectionFactory(hostAndPort, clientConfig)); + attachAuthenticationListener(clientConfig.getAuthXManager()); } @Experimental public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) { - this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig)); - } - - @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, - Cache clientSideCache, JedisAuthXManager authXManager) { - this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager)); - attachAuthenticationListener(authXManager); + this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache)); + attachAuthenticationListener(clientConfig.getAuthXManager()); } public ConnectionPool(PooledObjectFactory factory) { @@ -42,22 +31,15 @@ public ConnectionPool(PooledObjectFactory factory) { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, GenericObjectPoolConfig poolConfig) { - this(hostAndPort, clientConfig, null, createAuthXManager(clientConfig), poolConfig); + this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig); + attachAuthenticationListener(clientConfig.getAuthXManager()); } @Experimental public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache, GenericObjectPoolConfig poolConfig) { - this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig), poolConfig); - } - - @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, - Cache clientSideCache, JedisAuthXManager authXManager, - GenericObjectPoolConfig poolConfig) { - this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager), - poolConfig); - attachAuthenticationListener(authXManager); + this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig); + attachAuthenticationListener(clientConfig.getAuthXManager()); } public ConnectionPool(PooledObjectFactory factory, @@ -80,17 +62,10 @@ public void close() { super.close(); } - private static JedisAuthXManager createAuthXManager(JedisClientConfig config) { - if (config.getTokenAuthConfig() != null) { - return new JedisAuthXManager(config.getTokenAuthConfig()); - } - return null; - } - - private void attachAuthenticationListener(JedisAuthXManager authXManager) { + private void attachAuthenticationListener(AuthXManager authXManager) { this.authXManager = authXManager; if (authXManager != null) { - authXManager.setListener(token -> { + authXManager.addPostAuthenticationHook(token -> { try { // this is to trigger validations on each connection via ConnectionFactory evict(); diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java index 8b161ca7ff..5f0e050ef4 100644 --- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java @@ -5,7 +5,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; -import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.jedis.authentication.AuthXManager; public final class DefaultJedisClientConfig implements JedisClientConfig { @@ -30,7 +30,7 @@ public final class DefaultJedisClientConfig implements JedisClientConfig { private final boolean readOnlyForRedisClusterReplicas; - private final TokenAuthConfig tokenAuthConfig; + private final AuthXManager authXManager; private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMillis, int soTimeoutMillis, int blockingSocketTimeoutMillis, @@ -38,7 +38,7 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper, ClientSetInfoConfig clientSetInfoConfig, boolean readOnlyForRedisClusterReplicas, - TokenAuthConfig tokenAuthConfig) { + AuthXManager authXManager) { this.redisProtocol = protocol; this.connectionTimeoutMillis = connectionTimeoutMillis; this.socketTimeoutMillis = soTimeoutMillis; @@ -53,7 +53,8 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi this.hostAndPortMapper = hostAndPortMapper; this.clientSetInfoConfig = clientSetInfoConfig; this.readOnlyForRedisClusterReplicas = readOnlyForRedisClusterReplicas; - this.tokenAuthConfig = tokenAuthConfig; + this.authXManager = authXManager; + } @Override @@ -93,8 +94,8 @@ public Supplier getCredentialsProvider() { } @Override - public TokenAuthConfig getTokenAuthConfig() { - return tokenAuthConfig; + public AuthXManager getAuthXManager() { + return authXManager; } @Override @@ -171,7 +172,7 @@ public static class Builder { private boolean readOnlyForRedisClusterReplicas = false; - private TokenAuthConfig tokenAuthConfig = null; + private AuthXManager authXManager; private Builder() { } @@ -185,7 +186,7 @@ public DefaultJedisClientConfig build() { return new DefaultJedisClientConfig(redisProtocol, connectionTimeoutMillis, socketTimeoutMillis, blockingSocketTimeoutMillis, credentialsProvider, database, clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, - clientSetInfoConfig, readOnlyForRedisClusterReplicas, tokenAuthConfig); + clientSetInfoConfig, readOnlyForRedisClusterReplicas, authXManager); } /** @@ -287,8 +288,8 @@ public Builder readOnlyForRedisClusterReplicas() { return this; } - public Builder tokenAuthConfig(TokenAuthConfig tokenAuthConfig) { - this.tokenAuthConfig = tokenAuthConfig; + public Builder authXManager(AuthXManager authXManager) { + this.authXManager = authXManager; return this; } @@ -307,7 +308,7 @@ public Builder from(JedisClientConfig instance) { this.hostAndPortMapper = instance.getHostAndPortMapper(); this.clientSetInfoConfig = instance.getClientSetInfoConfig(); this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas(); - this.tokenAuthConfig = instance.getTokenAuthConfig(); + this.authXManager = instance.getAuthXManager(); return this; } } @@ -316,12 +317,12 @@ public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int s int blockingSocketTimeoutMillis, String user, String password, int database, String clientName, boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, HostnameVerifier hostnameVerifier, - HostAndPortMapper hostAndPortMapper, TokenAuthConfig tokenAuthConfig) { + HostAndPortMapper hostAndPortMapper, AuthXManager authXManager) { return new DefaultJedisClientConfig(null, connectionTimeoutMillis, soTimeoutMillis, blockingSocketTimeoutMillis, new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(user, password)), database, clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, null, - false, tokenAuthConfig); + false, authXManager); } public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { @@ -330,6 +331,6 @@ public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { copy.getCredentialsProvider(), copy.getDatabase(), copy.getClientName(), copy.isSsl(), copy.getSslSocketFactory(), copy.getSslParameters(), copy.getHostnameVerifier(), copy.getHostAndPortMapper(), copy.getClientSetInfoConfig(), - copy.isReadOnlyForRedisClusterReplicas(), copy.getTokenAuthConfig()); + copy.isReadOnlyForRedisClusterReplicas(), copy.getAuthXManager()); } } diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java index a8046694bf..82e9eb8e7f 100644 --- a/src/main/java/redis/clients/jedis/JedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java @@ -5,7 +5,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; -import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.jedis.authentication.AuthXManager; public interface JedisClientConfig { @@ -47,10 +47,11 @@ default String getPassword() { } default Supplier getCredentialsProvider() { - return new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(getUser(), getPassword())); + return new DefaultRedisCredentialsProvider( + new DefaultRedisCredentials(getUser(), getPassword())); } - default TokenAuthConfig getTokenAuthConfig() { + default AuthXManager getAuthXManager() { return null; } diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java new file mode 100644 index 0000000000..4750404157 --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java @@ -0,0 +1,21 @@ +package redis.clients.jedis.authentication; + +public interface AuthXEventListener { + + static AuthXEventListener NOOP_LISTENER = new AuthXEventListener() { + + @Override + public void onIdentityProviderError(Exception reason) { + } + + @Override + public void onConnectionAuthenticationError(Exception reason) { + } + + }; + + public void onIdentityProviderError(Exception reason); + + public void onConnectionAuthenticationError(Exception reason); + +} diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java similarity index 63% rename from src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java rename to src/main/java/redis/clients/jedis/authentication/AuthXManager.java index a210f7fb42..d66bddeb4c 100644 --- a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; +import java.util.function.Consumer; import java.util.function.Supplier; import org.slf4j.Logger; @@ -18,25 +19,22 @@ import redis.clients.jedis.Connection; import redis.clients.jedis.RedisCredentials; -public class JedisAuthXManager implements Supplier { +public final class AuthXManager implements Supplier { - private static final Logger log = LoggerFactory.getLogger(JedisAuthXManager.class); + private static final Logger log = LoggerFactory.getLogger(AuthXManager.class); private TokenManager tokenManager; private List> connections = Collections .synchronizedList(new ArrayList<>()); private Token currentToken; - private AuthenticationListener listener; + private AuthXEventListener listener = AuthXEventListener.NOOP_LISTENER; + private List> postAuthenticateHooks = new ArrayList<>(); - public interface AuthenticationListener { - public void onAuthenticate(Token token); - } - - public JedisAuthXManager(TokenManager tokenManager) { + protected AuthXManager(TokenManager tokenManager) { this.tokenManager = tokenManager; } - public JedisAuthXManager(TokenAuthConfig tokenAuthConfig) { + public AuthXManager(TokenAuthConfig tokenAuthConfig) { this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(), tokenAuthConfig.getTokenManagerConfig())); } @@ -53,7 +51,8 @@ public void onTokenRenewed(Token token) { @Override public void onError(Exception reason) { - JedisAuthXManager.this.onError(reason); + listener.onIdentityProviderError(reason); + AuthXManager.this.onError(reason); } }, blockForInitialToken); } @@ -63,23 +62,18 @@ public void authenticateConnections(Token token) { for (WeakReference connectionRef : connections) { Connection connection = connectionRef.get(); if (connection != null) { - try { - connection.setCredentials(credentialsFromToken); - } catch (Exception e) { - log.error("Failed to authenticate connection!", e); - } + connection.setCredentials(credentialsFromToken); } else { connections.remove(connectionRef); } } - if (listener != null) { - listener.onAuthenticate(token); - } + postAuthenticateHooks.forEach(hook -> hook.accept(token)); } public void onError(Exception reason) { - throw new JedisAuthenticationException( - "Token manager failed to acquire new token!", reason); + log.error("Token manager failed to acquire new token!", reason); + throw new JedisAuthenticationException("Token manager failed to acquire new token!", + reason); } public Connection addConnection(Connection connection) { @@ -91,8 +85,22 @@ public void stop() { tokenManager.stop(); } - public void setListener(AuthenticationListener listener) { - this.listener = listener; + public void setListener(AuthXEventListener listener) { + if (listener != null) { + this.listener = listener; + } + } + + public void addPostAuthenticationHook(Consumer postAuthenticateHook) { + postAuthenticateHooks.add(postAuthenticateHook); + } + + public void removePostAuthenticationHook(Consumer postAuthenticateHook) { + postAuthenticateHooks.remove(postAuthenticateHook); + } + + public AuthXEventListener getListener() { + return listener; } @Override diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java index adc421e790..c70ab98720 100644 --- a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java @@ -1,6 +1,8 @@ package redis.clients.jedis.authentication; -public class JedisAuthenticationException extends RuntimeException { +import redis.clients.jedis.exceptions.JedisException; + +public class JedisAuthenticationException extends JedisException { public JedisAuthenticationException(String message) { super(message); diff --git a/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java new file mode 100644 index 0000000000..e0cde9cfef --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java @@ -0,0 +1,112 @@ +package redis.clients.jedis.authentication; + +import java.io.ByteArrayInputStream; +import java.security.KeyFactory; +import java.security.PrivateKey; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashSet; +import java.util.Set; + +public class EntraIDTestContext { + private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID"; + private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY"; + private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"; + private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY"; + private static final String AZURE_CERT = "AZURE_CERT"; + private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES"; + + private String clientId; + private String authority; + private String clientSecret; + private PrivateKey privateKey; + private X509Certificate cert; + private Set redisScopes; + + public static final EntraIDTestContext DEFAULT = new EntraIDTestContext(); + + private EntraIDTestContext() { + clientId = System.getenv(AZURE_CLIENT_ID); + authority = System.getenv(AZURE_AUTHORITY); + clientSecret = System.getenv(AZURE_CLIENT_SECRET); + } + + public EntraIDTestContext(String clientId, String authority, String clientSecret, + Set redisScopes) { + this.clientId = clientId; + this.authority = authority; + this.clientSecret = clientSecret; + this.redisScopes = redisScopes; + } + + public String getClientId() { + return clientId; + } + + public String getAuthority() { + return authority; + } + + public String getClientSecret() { + return clientSecret; + } + + public PrivateKey getPrivateKey() { + if (privateKey == null) { + this.privateKey = getPrivateKey(System.getenv(AZURE_PRIVATE_KEY)); + } + return privateKey; + } + + public X509Certificate getCert() { + if (cert == null) { + this.cert = getCert(System.getenv(AZURE_CERT)); + } + return cert; + } + + public Set getRedisScopes() { + if (redisScopes == null) { + String redisScopesEnv = System.getenv(AZURE_REDIS_SCOPES); + this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";"))); + } + return redisScopes; + } + + private PrivateKey getPrivateKey(String privateKey) { + try { + // Decode the base64 encoded key into a byte array + byte[] decodedKey = Base64.getDecoder().decode(privateKey); + + // Generate the private key from the decoded byte array using PKCS8EncodedKeySpec + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKey); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // Use the correct algorithm (e.g., "RSA", "EC", "DSA") + PrivateKey key = keyFactory.generatePrivate(keySpec); + return key; + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private X509Certificate getCert(String cert) { + try { + // Convert the Base64 encoded string into a byte array + byte[] encoded = java.util.Base64.getDecoder().decode(cert); + + // Create a CertificateFactory for X.509 certificates + CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + + // Generate the certificate from the byte array + X509Certificate certificate = (X509Certificate) certificateFactory + .generateCertificate(new ByteArrayInputStream(encoded)); + return certificate; + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java new file mode 100644 index 0000000000..4639e34b02 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -0,0 +1,404 @@ +package redis.clients.jedis.authentication; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.awaitility.Awaitility; +import org.awaitility.Durations; +import org.junit.BeforeClass; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; +import org.mockito.MockedConstruction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.entraid.EntraIDIdentityProvider; +import redis.clients.authentication.entraid.EntraIDIdentityProviderConfig; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType; +import redis.clients.authentication.entraid.ServicePrincipalInfo; +import redis.clients.jedis.Connection; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.exceptions.JedisAccessControlException; +import redis.clients.jedis.exceptions.JedisConnectionException; +import redis.clients.jedis.scenario.FaultInjectionClient; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class RedisEntraIDIntegrationTests { + private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class); + + private static EntraIDTestContext testCtx; + private static EndpointConfig endpointConfig; + private static HostAndPort hnp; + + private final FaultInjectionClient faultClient = new FaultInjectionClient(); + + @BeforeClass + public static void before() { + try { + testCtx = EntraIDTestContext.DEFAULT; + endpointConfig = HostAndPorts.getRedisEndpoint("standalone-entraid-acl"); + hnp = endpointConfig.getHostAndPort(); + } catch (IllegalArgumentException e) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + + @Test + public void testJedisConfig() { + AtomicInteger counter = new AtomicInteger(0); + try (MockedConstruction mockedConstructor = mockConstruction( + EntraIDIdentityProvider.class, (mock, context) -> { + ServicePrincipalInfo info = (ServicePrincipalInfo) context.arguments().get(0); + + assertEquals(testCtx.getClientId(), info.getClientId()); + assertEquals(testCtx.getAuthority(), info.getAuthority()); + assertEquals(testCtx.getClientSecret(), info.getSecret()); + assertEquals(testCtx.getRedisScopes(), context.arguments().get(1)); + assertNotNull(mock); + doAnswer(invocation -> { + counter.incrementAndGet(); + return new SimpleToken("token1", System.currentTimeMillis() + 5 * 60 * 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default")); + }).when(mock).requestToken(); + })) { + + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .authority(testCtx.getAuthority()).clientId(testCtx.getClientId()) + .secret(testCtx.getClientSecret()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + JedisPooled jedis = new JedisPooled(new HostAndPort("localhost", 6379), jedisConfig); + assertNotNull(jedis); + assertEquals(1, counter.get()); + + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + // @Test + public void withUserAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()) + .userAssignedManagedIdentity(UserManagedIdentityType.CLIENT_ID, "userManagedAuthxId") + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + // @Test + public void withSystemAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).systemAssignedManagedIdentity() + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals + @Test + public void withSecret_azureServicePrincipalIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals + @Test + public void withCertificate_azureServicePrincipalIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.2.2 + // Test that the Redis client is not blocked/interrupted during token renewal. + @Test + public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException { + // set the stage with consecutive get/set operations with unique keys which takes at least for 2000 ms with a jedispooled instace, + // configure token manager to renew token approximately every 100ms + // wait till all operations are completed and verify that token was renewed at least 20 times after initial token acquisition + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.000001F).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + Consumer hook = mock(Consumer.class); + authXManager.addPostAuthenticationHook(hook); + + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + long startTime = System.currentTimeMillis(); + List> futures = new ArrayList<>(); + ExecutorService executor = Executors.newFixedThreadPool(5); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + Future future = executor.submit(() -> { + for (; System.currentTimeMillis() - startTime < 2000;) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + task.get(); + } + + verify(hook, atLeast(20)).accept(any()); + executor.shutdown(); + } + } + + // T.3.2 + // Verify that all existing connections can be re-authenticated when a new token is received. + @Test + public void allConnectionsReauthTest() throws InterruptedException, ExecutionException { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.000001F).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + + List connections = new ArrayList<>(); + + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + long startTime = System.currentTimeMillis(); + List> futures = new ArrayList<>(); + ExecutorService executor = Executors.newFixedThreadPool(5); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + Future future = executor.submit(() -> { + for (; System.currentTimeMillis() - startTime < 2000;) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + task.get(); + } + + connections.forEach(conn -> { + verify(conn, atLeast(1)).reAuth(); + }); + executor.shutdown(); + } + } + + // T.3.2 + // Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them + @Test + public void partialReauthFailureTest() { + + } + + // T.3.3 + // Verify behavior when attempting to authenticate a single connection with an expired token. + @Test + public void connectionAuthWithExpiredTokenTest() { + IdentityProvider idp = new EntraIDIdentityProviderConfig( + new ServicePrincipalInfo(testCtx.getClientId(), testCtx.getClientSecret(), + testCtx.getAuthority()), + testCtx.getRedisScopes()).getProvider(); + + IdentityProvider mockIdentityProvider = mock(IdentityProvider.class); + AtomicReference token = new AtomicReference<>(); + doAnswer(invocation -> { + if (token.get() == null) { + token.set(idp.requestToken()); + } + return token.get(); + }).when(mockIdentityProvider).requestToken(); + IdentityProviderConfig idpConfig = mock(IdentityProviderConfig.class); + when(idpConfig.getProvider()).thenReturn(mockIdentityProvider); + + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .identityProviderConfig(idpConfig).expirationRefreshRatio(0.000001F).build(); + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + + token + .set(new SimpleToken("token1", System.currentTimeMillis() - 1, System.currentTimeMillis(), + Collections.singletonMap("oid", idp.requestToken().tryGet("oid")))); + + JedisAccessControlException aclException = assertThrows(JedisAccessControlException.class, + () -> { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + + assertEquals("WRONGPASS invalid username-password pair", aclException.getMessage()); + } + } + + // T.3.4 + // Verify handling of reconnection and re-authentication after a network partition. (use cached token) + // @Test + public void networkPartitionEvictionTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.5F).build(); + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + + triggerNetworkFailure(); + + JedisConnectionException aclException = assertThrows(JedisConnectionException.class, () -> { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + + assertEquals("Unexpected end of stream.", aclException.getMessage()); + Awaitility.await().pollDelay(Durations.ONE_HUNDRED_MILLISECONDS).atMost(Durations.TWO_SECONDS) + .until(() -> { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + return true; + }); + } + } + + private void triggerNetworkFailure() { + HashMap params = new HashMap<>(); + params.put("bdb_id", endpointConfig.getBdbId()); + + FaultInjectionClient.TriggerActionResponse actionResponse = null; + String action = "network_failure"; + try { + log.info("Triggering {}", action); + actionResponse = faultClient.triggerAction(action, params); + } catch (IOException e) { + fail("Fault Injection Server error:" + e.getMessage()); + } + log.info("Action id: {}", actionResponse.getActionId()); + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index 4cbf155dd3..780c82c781 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -9,15 +9,17 @@ import java.util.Arrays; import java.util.Collections; -import java.util.Date; import java.util.List; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import redis.clients.authentication.core.IdentityProvider; import redis.clients.authentication.core.IdentityProviderConfig; @@ -36,52 +38,65 @@ import redis.clients.jedis.commands.ProtocolCommand; public class TokenBasedAuthenticationIntegrationTests { + private static final Logger log = LoggerFactory + .getLogger(TokenBasedAuthenticationIntegrationTests.class); + + private static EndpointConfig endpointConfig; + + @BeforeClass + public static void before() { + try { + endpointConfig = HostAndPorts.getRedisEndpoint("standalone0"); + } catch (IllegalArgumentException e) { + try { + endpointConfig = HostAndPorts.getRedisEndpoint("standalone"); + } catch (IllegalArgumentException ex) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + } - protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); - - @Test - public void testJedisPooledAuth() { - String user = "default"; - String password = endpoint.getPassword(); + @Test + public void testJedisPooledForInitialAuth() { + String user = "default"; + String password = endpointConfig.getPassword(); - IdentityProvider idProvider = mock(IdentityProvider.class); - when(idProvider.requestToken()) - .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000, - System.currentTimeMillis(), Collections.singletonMap("oid", user))); + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000, + System.currentTimeMillis(), Collections.singletonMap("oid", user))); - IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); - when(idProviderConfig.getProvider()).thenReturn(idProvider); + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); - TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() - .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) - .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); - JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() - .tokenAuthConfig(tokenAuthConfig).build(); + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); - try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) { - ArgumentCaptor captor = ArgumentCaptor - .forClass(CommandArguments.class); + try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) { + ArgumentCaptor captor = ArgumentCaptor.forClass(CommandArguments.class); - try (JedisPooled jedis = new JedisPooled(endpoint.getHostAndPort(), clientConfig)) { - jedis.get("key1"); - } + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + jedis.get("key1"); + } - // Verify that the static method was called - mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), - Mockito.atLeast(4)); + // Verify that the static method was called + mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), Mockito.atLeast(4)); - CommandArguments commandArgs = captor.getAllValues().get(0); - List args = StreamSupport.stream(commandArgs.spliterator(), false) - .map(Rawable::getRaw).collect(Collectors.toList()); + CommandArguments commandArgs = captor.getAllValues().get(0); + List args = StreamSupport.stream(commandArgs.spliterator(), false) + .map(Rawable::getRaw).collect(Collectors.toList()); - assertThat(args, - contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes())); + assertThat(args, + contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes())); - List cmds = captor.getAllValues().stream() - .map(item -> item.getCommand()).collect(Collectors.toList()); - assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), - cmds); - } + List cmds = captor.getAllValues().stream().map(item -> item.getCommand()) + .collect(Collectors.toList()); + assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), cmds); } + } } diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 7ac68361aa..83c441d492 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import org.awaitility.Durations; import org.hamcrest.Matchers; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -36,12 +37,15 @@ import redis.clients.authentication.core.TokenListener; import redis.clients.authentication.core.TokenManager; import redis.clients.authentication.core.TokenManagerConfig; +import redis.clients.authentication.core.TokenRequestException; import redis.clients.jedis.ConnectionPool; import redis.clients.jedis.EndpointConfig; -import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.HostAndPort; public class TokenBasedAuthenticationUnitTests { - protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); + + private HostAndPort hnp = new HostAndPort("localhost", 6379); + private EndpointConfig endpoint = new EndpointConfig(hnp, null, null, false); @Test public void testJedisAuthXManagerInstance() { @@ -57,25 +61,51 @@ public void testJedisAuthXManagerInstance() { assertEquals(tokenManagerConfig, context.arguments().get(1)); })) { - new JedisAuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig)); + new AuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig)); } } @Test - public void testJedisAuthXManagerTriggersEvict() throws Exception { + public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() throws Exception { IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 100000, + .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, - new TokenManagerConfig(0.5F, 1000, 1000, null)); - JedisAuthXManager jedisAuthXManager = new JedisAuthXManager(tokenManager); + new TokenManagerConfig(0.4F, 100, 1000, null)); + AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); + + AtomicInteger numberOfEvictions = new AtomicInteger(0); + ConnectionPool pool = new ConnectionPool(hnp, + endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) { + @Override + public void evict() throws Exception { + numberOfEvictions.incrementAndGet(); + super.evict(); + } + }; + + await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS) + .atMost(Durations.FIVE_HUNDRED_MILLISECONDS) + .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); + } + + public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws Exception { + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); + + TokenManager tokenManager = new TokenManager(idProvider, + new TokenManagerConfig(0.9F, 600, 1000, null)); + AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); AtomicInteger numberOfEvictions = new AtomicInteger(0); ConnectionPool pool = new ConnectionPool(endpoint.getHostAndPort(), - endpoint.getClientConfigBuilder().build(), jedisAuthXManager) { + endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) { @Override public void evict() throws Exception { numberOfEvictions.incrementAndGet(); @@ -83,8 +113,9 @@ public void evict() throws Exception { } }; - jedisAuthXManager.start(true); - assertEquals(1, numberOfEvictions.get()); + await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS) + .atMost(Durations.FIVE_HUNDRED_MILLISECONDS) + .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); } public static class TokenManagerConfigWrapper extends TokenManagerConfig { @@ -190,7 +221,7 @@ public void testAuthXManagerReceivesNewToken() TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, 2000, null)); - JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + AuthXManager manager = spy(new AuthXManager(tokenManager)); final Token[] tokenHolder = new Token[1]; doAnswer(invocation -> { @@ -213,11 +244,10 @@ public void testBlockForInitialToken() { TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100))); - JedisAuthXManager manager = new JedisAuthXManager(tokenManager); - ExecutionException e = assertThrows(ExecutionException.class, () -> manager.start(true)); + AuthXManager manager = new AuthXManager(tokenManager); + TokenRequestException e = assertThrows(TokenRequestException.class, () -> manager.start(true)); - assertEquals(exceptionMessage, - e.getCause().getCause().getCause().getCause().getMessage()); + assertEquals(exceptionMessage, e.getCause().getCause().getCause().getMessage()); } @Test @@ -231,9 +261,9 @@ public void testNoBlockForInitialToken() }; TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, - 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 100))); + 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 0))); - JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + AuthXManager manager = spy(new AuthXManager(tokenManager)); manager.start(false); requesLatch.await(); @@ -296,7 +326,7 @@ public void testTokenManagerWithHangingTokenRequest() TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100))); - JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + AuthXManager manager = spy(new AuthXManager(tokenManager)); manager.start(false); requesLatch.await(); verify(manager, never()).onError(any());