From da2dddfe316363c618e30420c8e9ee551e792f1f Mon Sep 17 00:00:00 2001 From: Hannes Herrmann Date: Thu, 29 Aug 2024 14:34:29 +0200 Subject: [PATCH] feat: Make api clients with interceptors cloneable (#566) Clients are cloneable now. BREAKING CHANGE: Return types of the client builder have changed. Please update to the newest returned types to use the client builder. This removed the ChainedInterceptor which essentially has hidden the cloneable feature for interceptors. --- src/api/clients.rs | 330 ++++++++++++++++------------------------ src/api/interceptors.rs | 156 +++++++++++++++---- 2 files changed, 261 insertions(+), 225 deletions(-) diff --git a/src/api/clients.rs b/src/api/clients.rs index 96bc4a2..d68a538 100644 --- a/src/api/clients.rs +++ b/src/api/clients.rs @@ -6,10 +6,9 @@ use std::error::Error; use custom_error::custom_error; -use tonic::codegen::InterceptedService; +use tonic::codegen::{Body, Bytes, InterceptedService, StdError}; use tonic::service::Interceptor; use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; -use tonic::{Request, Status}; #[cfg(feature = "interceptors")] use crate::api::interceptors::{AccessTokenInterceptor, ServiceAccountInterceptor}; @@ -45,63 +44,55 @@ custom_error! { TlsInitializationError = "could not setup tls connection", } -#[cfg(feature = "interceptors")] -enum AuthType { - None, - AccessToken(String), - ServiceAccount(ServiceAccount, Option), +/// A builder to create configured gRPC clients for ZITADEL API access. +/// The builder accepts the api endpoint and (depending on activated features) +/// an authentication method. +pub struct ClientBuilder { + api_endpoint: String, + interceptor: T, } -/// Helper [Interceptor] that allows chaining of multiple interceptors. -/// This is used to help return the same type in all builder methods like -/// [ClientBuilder::build_management_client]. Otherwise, each interceptor -/// would create its own return type. With this interceptor, the return type -/// stays the same and is not dependent on the authentication type used. -/// The builder can always return `Client>`. -pub struct ChainedInterceptor { - interceptors: Vec>, +pub trait BuildInterceptedService { + type Target; + fn build_service(self, channel: Channel) -> Self::Target; } -impl ChainedInterceptor { - pub(crate) fn new() -> Self { - Self { - interceptors: Vec::new(), - } - } +pub struct NoInterceptor; - #[cfg(feature = "interceptors")] - pub(crate) fn add_interceptor(mut self, interceptor: Box) -> Self { - self.interceptors.push(interceptor); - self +impl BuildInterceptedService for NoInterceptor { + type Target = Channel; + fn build_service(self, channel: Channel) -> Self::Target { + channel } } -impl Interceptor for ChainedInterceptor { - fn call(&mut self, request: Request<()>) -> Result, Status> { - let mut request = request; - for interceptor in &mut self.interceptors { - request = interceptor.call(request)?; - } - Ok(request) +impl BuildInterceptedService for T +where + T: Interceptor, +{ + type Target = InterceptedService; + fn build_service(self, channel: Channel) -> Self::Target { + InterceptedService::new(channel, self) } } -/// A builder to create configured gRPC clients for ZITADEL API access. -/// The builder accepts the api endpoint and (depending on activated features) -/// an authentication method. -pub struct ClientBuilder { - api_endpoint: String, - #[cfg(feature = "interceptors")] - auth_type: AuthType, -} - -impl ClientBuilder { - /// Create a new client builder with the the provided endpoint. - pub fn new(api_endpoint: &str) -> Self { - Self { +impl ClientBuilder { + /// Create a new client builder with the provided endpoint. + pub fn new(api_endpoint: &str) -> ClientBuilder { + ClientBuilder { api_endpoint: api_endpoint.to_string(), - #[cfg(feature = "interceptors")] - auth_type: AuthType::None, + interceptor: NoInterceptor, + } + } + + /// Configure the client builder to inject a custom interceptor, + /// which can be used to modify the [Request][tonic::request::Request] before being sent. + /// + /// See [Interceptor][tonic::service::Interceptor] for more details. + pub fn with_interceptor(self, interceptor: I) -> ClientBuilder { + ClientBuilder { + api_endpoint: self.api_endpoint, + interceptor, } } @@ -112,9 +103,8 @@ impl ClientBuilder { /// Clients with this authentication method will have the [`AccessTokenInterceptor`] /// attached. #[cfg(feature = "interceptors")] - pub fn with_access_token(mut self, access_token: &str) -> Self { - self.auth_type = AuthType::AccessToken(access_token.to_string()); - self + pub fn with_access_token(self, access_token: &str) -> ClientBuilder { + self.with_interceptor(AccessTokenInterceptor::new(access_token)) } /// Configure the client builder to use a [`ServiceAccount`][crate::credentials::ServiceAccount]. @@ -125,63 +115,60 @@ impl ClientBuilder { /// that fetches an access token from ZITADEL and renewes it when it expires. #[cfg(feature = "interceptors")] pub fn with_service_account( - mut self, + self, service_account: &ServiceAccount, auth_options: Option, - ) -> Self { - self.auth_type = AuthType::ServiceAccount(service_account.clone(), auth_options); - self + ) -> ClientBuilder { + let interceptor = ServiceAccountInterceptor::new( + &self.api_endpoint, + service_account, + auth_options.clone(), + ); + self.with_interceptor(interceptor) } +} +impl ClientBuilder +where + T: BuildInterceptedService, + T::Target: tonic::client::GrpcService, + >::ResponseBody: + Body + Send + 'static, + <>::ResponseBody as Body>::Error: + Into + Send, +{ /// Create a new [`AdminServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-admin-v1")] - pub async fn build_admin_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(AdminServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + pub async fn build_admin_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(AdminServiceClient::new(channel)) } /// Create a new [`AuthServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-auth-v1")] - pub async fn build_auth_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(AuthServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + pub async fn build_auth_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(AuthServiceClient::new(channel)) } /// Create a new [`ManagementServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -189,45 +176,31 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-management-v1")] pub async fn build_management_client( - &self, - ) -> Result< - ManagementServiceClient>, - Box, - > { - let channel = get_channel(&self.api_endpoint).await?; - Ok(ManagementServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(ManagementServiceClient::new(channel)) } /// Create a new [`OidcServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-oidc-v2")] - pub async fn build_oidc_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(OidcServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + pub async fn build_oidc_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(OidcServiceClient::new(channel)) } /// Create a new [`OrganizationServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -235,23 +208,16 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-org-v2")] pub async fn build_organization_client( - &self, - ) -> Result< - OrganizationServiceClient>, - Box, - > { - let channel = get_channel(&self.api_endpoint).await?; - Ok(OrganizationServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(OrganizationServiceClient::new(channel)) } /// Create a new [`SessionServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -259,21 +225,16 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-session-v2")] pub async fn build_session_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(SessionServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(SessionServiceClient::new(channel)) } /// Create a new [`SettingsServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -281,23 +242,16 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-settings-v2")] pub async fn build_settings_client( - &self, - ) -> Result< - SettingsServiceClient>, - Box, - > { - let channel = get_channel(&self.api_endpoint).await?; - Ok(SettingsServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(SettingsServiceClient::new(channel)) } /// Create a new [`SystemServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -305,59 +259,27 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-system-v1")] pub async fn build_system_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(SystemServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(SystemServiceClient::new(channel)) } /// Create a new [`UserServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-user-v2")] - pub async fn build_user_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(UserServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) - } - - fn get_chained_interceptor(&self) -> ChainedInterceptor { - #[allow(unused_mut)] - let mut interceptor = ChainedInterceptor::new(); - #[cfg(feature = "interceptors")] - match &self.auth_type { - AuthType::AccessToken(token) => { - interceptor = - interceptor.add_interceptor(Box::new(AccessTokenInterceptor::new(token))); - } - AuthType::ServiceAccount(service_account, auth_options) => { - interceptor = - interceptor.add_interceptor(Box::new(ServiceAccountInterceptor::new( - &self.api_endpoint, - service_account, - auth_options.clone(), - ))); - } - _ => {} - } - - interceptor + pub async fn build_user_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(UserServiceClient::new(channel)) } } @@ -378,6 +300,7 @@ async fn get_channel(api_endpoint: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use tonic::Request; const ZITADEL_URL: &str = "https://zitadel-libraries-l8boqa.zitadel.cloud"; const SERVICE_ACCOUNT: &str = r#" @@ -388,23 +311,33 @@ mod tests { "userId": "181828061098934529" }"#; - #[test] - fn client_builder_without_auth_passes_requests() { - let mut interceptor = ClientBuilder::new(ZITADEL_URL).get_chained_interceptor(); - let request = Request::new(()); - - assert!(request.metadata().is_empty()); - - let request = interceptor.call(request).unwrap(); - - assert!(request.metadata().is_empty()); + #[tokio::test] + async fn clients_are_cloneable() { + let access_token_client = ClientBuilder::new(ZITADEL_URL) + .with_access_token("token") + .build_user_client() + .await + .unwrap(); + let _cloned = access_token_client.clone(); + + let service_account_client = ClientBuilder::new(ZITADEL_URL) + .with_service_account( + &ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(), + None, + ) + .build_user_client() + .await + .unwrap(); + + let _cloned = service_account_client.clone(); } #[test] fn client_builder_with_access_token_attaches_it() { let mut interceptor = ClientBuilder::new(ZITADEL_URL) .with_access_token("token") - .get_chained_interceptor(); + .interceptor; + let request = Request::new(()); assert!(request.metadata().is_empty()); @@ -423,7 +356,8 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ClientBuilder::new(ZITADEL_URL) .with_service_account(&sa, None) - .get_chained_interceptor(); + .interceptor; + let request = Request::new(()); assert!(request.metadata().is_empty()); diff --git a/src/api/interceptors.rs b/src/api/interceptors.rs index 813b5e1..52f93d9 100644 --- a/src/api/interceptors.rs +++ b/src/api/interceptors.rs @@ -3,6 +3,8 @@ //! interceptors is to authenticate the clients to ZITADEL with //! provided credentials. +use std::ops::Deref; +use std::sync::{Arc, RwLock}; use std::thread; use tokio::runtime::Builder; @@ -41,6 +43,7 @@ use crate::credentials::{AuthenticationOptions, ServiceAccount}; /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct AccessTokenInterceptor { access_token: String, } @@ -125,12 +128,21 @@ impl Interceptor for AccessTokenInterceptor { /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct ServiceAccountInterceptor { + inner: Arc, +} + +struct ServiceAccountInterceptorInner { audience: String, service_account: ServiceAccount, auth_options: AuthenticationOptions, - token: Option, - token_expiry: Option, + state: RwLock>, +} + +struct ServiceAccountInterceptorState { + token: String, + token_expiry: time::OffsetDateTime, } impl ServiceAccountInterceptor { @@ -144,11 +156,12 @@ impl ServiceAccountInterceptor { auth_options: Option, ) -> Self { Self { - audience: audience.to_string(), - service_account: service_account.clone(), - auth_options: auth_options.unwrap_or_default(), - token: None, - token_expiry: None, + inner: Arc::new(ServiceAccountInterceptorInner { + audience: audience.to_string(), + service_account: service_account.clone(), + auth_options: auth_options.unwrap_or_default(), + state: RwLock::new(None), + }), } } } @@ -157,25 +170,32 @@ impl Interceptor for ServiceAccountInterceptor { fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { let meta = request.metadata_mut(); if !meta.contains_key("authorization") { - if let Some(token) = &self.token { - if let Some(expiry) = self.token_expiry { - if expiry > time::OffsetDateTime::now_utc() { - meta.insert( - "authorization", - format!("Bearer {}", token).parse().unwrap(), - ); - - return Ok(request); - } + // We unwrap the RWLock to propagate the error if any + // thread panics and the lock is poisoned + let state_read_guard = self.inner.state.read().unwrap(); + + if let Some(ServiceAccountInterceptorState { + token, + token_expiry, + }) = state_read_guard.deref() + { + if token_expiry > &time::OffsetDateTime::now_utc() { + meta.insert( + "authorization", + format!("Bearer {}", token).parse().unwrap(), + ); + + return Ok(request); } } + drop(state_read_guard); - let aud = self.audience.clone(); - let auth = self.auth_options.clone(); - let sa = self.service_account.clone(); + let aud = self.inner.audience.clone(); + let auth = self.inner.auth_options.clone(); + let sa = self.inner.service_account.clone(); let token = thread::spawn(move || { - let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + let rt = Builder::new_current_thread().enable_all().build().unwrap(); rt.block_on(async { sa.authenticate_with_options(&aud, &auth) .await @@ -187,8 +207,14 @@ impl Interceptor for ServiceAccountInterceptor { .join() .map_err(|_| Status::internal("could not fetch token"))??; - self.token = Some(token.clone()); - self.token_expiry = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(59)); + // We unwrap the RWLock to propagate the error if any + // thread panics and the lock is poisoned + let mut state_write_guard = self.inner.state.write().unwrap(); + + *state_write_guard = Some(ServiceAccountInterceptorState { + token: token.clone(), + token_expiry: time::OffsetDateTime::now_utc() + time::Duration::minutes(59), + }); meta.insert( "authorization", @@ -288,6 +314,46 @@ mod tests { .is_empty()); } + #[test] + fn service_account_interceptor_can_be_cloned_and_shares_token_sync_context() { + let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); + let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); + let mut second_interceptor = interceptor.clone(); + let request = Request::new(()); + let second_request = Request::new(()); + + assert!(request.metadata().is_empty()); + assert!(second_request.metadata().is_empty()); + + let request = interceptor.call(request).unwrap(); + let second_request = second_interceptor.call(second_request).unwrap(); + + assert_eq!( + request.metadata().get("authorization"), + second_request.metadata().get("authorization") + ); + } + + #[tokio::test] + async fn service_account_interceptor_can_be_cloned_and_shares_token_async_context() { + let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); + let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); + let mut second_interceptor = interceptor.clone(); + let request = Request::new(()); + let second_request = Request::new(()); + + assert!(request.metadata().is_empty()); + assert!(second_request.metadata().is_empty()); + + let request = interceptor.call(request).unwrap(); + let second_request = second_interceptor.call(second_request).unwrap(); + + assert_eq!( + request.metadata().get("authorization"), + second_request.metadata().get("authorization") + ); + } + #[test] fn service_account_interceptor_ignore_existing_auth_metadata_sync_context() { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); @@ -333,10 +399,28 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); interceptor.call(Request::new(())).unwrap(); - let token = interceptor.token.clone().unwrap(); + let token = interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + .clone(); interceptor.call(Request::new(())).unwrap(); - assert_eq!(token, interceptor.token.unwrap()); + assert_eq!( + token, + interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + ); } #[tokio::test] @@ -344,9 +428,27 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); interceptor.call(Request::new(())).unwrap(); - let token = interceptor.token.clone().unwrap(); + let token = interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + .clone(); interceptor.call(Request::new(())).unwrap(); - assert_eq!(token, interceptor.token.unwrap()); + assert_eq!( + token, + interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + ); } }