diff --git a/src/api/clients.rs b/src/api/clients.rs index 96bc4a2..d68a538 100644 --- a/src/api/clients.rs +++ b/src/api/clients.rs @@ -6,10 +6,9 @@ use std::error::Error; use custom_error::custom_error; -use tonic::codegen::InterceptedService; +use tonic::codegen::{Body, Bytes, InterceptedService, StdError}; use tonic::service::Interceptor; use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; -use tonic::{Request, Status}; #[cfg(feature = "interceptors")] use crate::api::interceptors::{AccessTokenInterceptor, ServiceAccountInterceptor}; @@ -45,63 +44,55 @@ custom_error! { TlsInitializationError = "could not setup tls connection", } -#[cfg(feature = "interceptors")] -enum AuthType { - None, - AccessToken(String), - ServiceAccount(ServiceAccount, Option), +/// A builder to create configured gRPC clients for ZITADEL API access. +/// The builder accepts the api endpoint and (depending on activated features) +/// an authentication method. +pub struct ClientBuilder { + api_endpoint: String, + interceptor: T, } -/// Helper [Interceptor] that allows chaining of multiple interceptors. -/// This is used to help return the same type in all builder methods like -/// [ClientBuilder::build_management_client]. Otherwise, each interceptor -/// would create its own return type. With this interceptor, the return type -/// stays the same and is not dependent on the authentication type used. -/// The builder can always return `Client>`. -pub struct ChainedInterceptor { - interceptors: Vec>, +pub trait BuildInterceptedService { + type Target; + fn build_service(self, channel: Channel) -> Self::Target; } -impl ChainedInterceptor { - pub(crate) fn new() -> Self { - Self { - interceptors: Vec::new(), - } - } +pub struct NoInterceptor; - #[cfg(feature = "interceptors")] - pub(crate) fn add_interceptor(mut self, interceptor: Box) -> Self { - self.interceptors.push(interceptor); - self +impl BuildInterceptedService for NoInterceptor { + type Target = Channel; + fn build_service(self, channel: Channel) -> Self::Target { + channel } } -impl Interceptor for ChainedInterceptor { - fn call(&mut self, request: Request<()>) -> Result, Status> { - let mut request = request; - for interceptor in &mut self.interceptors { - request = interceptor.call(request)?; - } - Ok(request) +impl BuildInterceptedService for T +where + T: Interceptor, +{ + type Target = InterceptedService; + fn build_service(self, channel: Channel) -> Self::Target { + InterceptedService::new(channel, self) } } -/// A builder to create configured gRPC clients for ZITADEL API access. -/// The builder accepts the api endpoint and (depending on activated features) -/// an authentication method. -pub struct ClientBuilder { - api_endpoint: String, - #[cfg(feature = "interceptors")] - auth_type: AuthType, -} - -impl ClientBuilder { - /// Create a new client builder with the the provided endpoint. - pub fn new(api_endpoint: &str) -> Self { - Self { +impl ClientBuilder { + /// Create a new client builder with the provided endpoint. + pub fn new(api_endpoint: &str) -> ClientBuilder { + ClientBuilder { api_endpoint: api_endpoint.to_string(), - #[cfg(feature = "interceptors")] - auth_type: AuthType::None, + interceptor: NoInterceptor, + } + } + + /// Configure the client builder to inject a custom interceptor, + /// which can be used to modify the [Request][tonic::request::Request] before being sent. + /// + /// See [Interceptor][tonic::service::Interceptor] for more details. + pub fn with_interceptor(self, interceptor: I) -> ClientBuilder { + ClientBuilder { + api_endpoint: self.api_endpoint, + interceptor, } } @@ -112,9 +103,8 @@ impl ClientBuilder { /// Clients with this authentication method will have the [`AccessTokenInterceptor`] /// attached. #[cfg(feature = "interceptors")] - pub fn with_access_token(mut self, access_token: &str) -> Self { - self.auth_type = AuthType::AccessToken(access_token.to_string()); - self + pub fn with_access_token(self, access_token: &str) -> ClientBuilder { + self.with_interceptor(AccessTokenInterceptor::new(access_token)) } /// Configure the client builder to use a [`ServiceAccount`][crate::credentials::ServiceAccount]. @@ -125,63 +115,60 @@ impl ClientBuilder { /// that fetches an access token from ZITADEL and renewes it when it expires. #[cfg(feature = "interceptors")] pub fn with_service_account( - mut self, + self, service_account: &ServiceAccount, auth_options: Option, - ) -> Self { - self.auth_type = AuthType::ServiceAccount(service_account.clone(), auth_options); - self + ) -> ClientBuilder { + let interceptor = ServiceAccountInterceptor::new( + &self.api_endpoint, + service_account, + auth_options.clone(), + ); + self.with_interceptor(interceptor) } +} +impl ClientBuilder +where + T: BuildInterceptedService, + T::Target: tonic::client::GrpcService, + >::ResponseBody: + Body + Send + 'static, + <>::ResponseBody as Body>::Error: + Into + Send, +{ /// Create a new [`AdminServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-admin-v1")] - pub async fn build_admin_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(AdminServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + pub async fn build_admin_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(AdminServiceClient::new(channel)) } /// Create a new [`AuthServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-auth-v1")] - pub async fn build_auth_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(AuthServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + pub async fn build_auth_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(AuthServiceClient::new(channel)) } /// Create a new [`ManagementServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -189,45 +176,31 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-management-v1")] pub async fn build_management_client( - &self, - ) -> Result< - ManagementServiceClient>, - Box, - > { - let channel = get_channel(&self.api_endpoint).await?; - Ok(ManagementServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(ManagementServiceClient::new(channel)) } /// Create a new [`OidcServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-oidc-v2")] - pub async fn build_oidc_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(OidcServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + pub async fn build_oidc_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(OidcServiceClient::new(channel)) } /// Create a new [`OrganizationServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -235,23 +208,16 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-org-v2")] pub async fn build_organization_client( - &self, - ) -> Result< - OrganizationServiceClient>, - Box, - > { - let channel = get_channel(&self.api_endpoint).await?; - Ok(OrganizationServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(OrganizationServiceClient::new(channel)) } /// Create a new [`SessionServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -259,21 +225,16 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-session-v2")] pub async fn build_session_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(SessionServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(SessionServiceClient::new(channel)) } /// Create a new [`SettingsServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -281,23 +242,16 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-settings-v2")] pub async fn build_settings_client( - &self, - ) -> Result< - SettingsServiceClient>, - Box, - > { - let channel = get_channel(&self.api_endpoint).await?; - Ok(SettingsServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(SettingsServiceClient::new(channel)) } /// Create a new [`SystemServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint @@ -305,59 +259,27 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-system-v1")] pub async fn build_system_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(SystemServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) + self, + ) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(SystemServiceClient::new(channel)) } /// Create a new [`UserServiceClient`]. /// - /// Depending on the configured authentication method, the client has - /// specialised interceptors attached. - /// /// ### Errors /// /// This function returns a [`ClientError`] if the provided API endpoint /// cannot be parsed into a valid URL or if the connection to the endpoint /// is not possible. #[cfg(feature = "api-user-v2")] - pub async fn build_user_client( - &self, - ) -> Result>, Box> - { - let channel = get_channel(&self.api_endpoint).await?; - Ok(UserServiceClient::with_interceptor( - channel, - self.get_chained_interceptor(), - )) - } - - fn get_chained_interceptor(&self) -> ChainedInterceptor { - #[allow(unused_mut)] - let mut interceptor = ChainedInterceptor::new(); - #[cfg(feature = "interceptors")] - match &self.auth_type { - AuthType::AccessToken(token) => { - interceptor = - interceptor.add_interceptor(Box::new(AccessTokenInterceptor::new(token))); - } - AuthType::ServiceAccount(service_account, auth_options) => { - interceptor = - interceptor.add_interceptor(Box::new(ServiceAccountInterceptor::new( - &self.api_endpoint, - service_account, - auth_options.clone(), - ))); - } - _ => {} - } - - interceptor + pub async fn build_user_client(self) -> Result, Box> { + let channel = self + .interceptor + .build_service(get_channel(&self.api_endpoint).await?); + Ok(UserServiceClient::new(channel)) } } @@ -378,6 +300,7 @@ async fn get_channel(api_endpoint: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use tonic::Request; const ZITADEL_URL: &str = "https://zitadel-libraries-l8boqa.zitadel.cloud"; const SERVICE_ACCOUNT: &str = r#" @@ -388,23 +311,33 @@ mod tests { "userId": "181828061098934529" }"#; - #[test] - fn client_builder_without_auth_passes_requests() { - let mut interceptor = ClientBuilder::new(ZITADEL_URL).get_chained_interceptor(); - let request = Request::new(()); - - assert!(request.metadata().is_empty()); - - let request = interceptor.call(request).unwrap(); - - assert!(request.metadata().is_empty()); + #[tokio::test] + async fn clients_are_cloneable() { + let access_token_client = ClientBuilder::new(ZITADEL_URL) + .with_access_token("token") + .build_user_client() + .await + .unwrap(); + let _cloned = access_token_client.clone(); + + let service_account_client = ClientBuilder::new(ZITADEL_URL) + .with_service_account( + &ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(), + None, + ) + .build_user_client() + .await + .unwrap(); + + let _cloned = service_account_client.clone(); } #[test] fn client_builder_with_access_token_attaches_it() { let mut interceptor = ClientBuilder::new(ZITADEL_URL) .with_access_token("token") - .get_chained_interceptor(); + .interceptor; + let request = Request::new(()); assert!(request.metadata().is_empty()); @@ -423,7 +356,8 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ClientBuilder::new(ZITADEL_URL) .with_service_account(&sa, None) - .get_chained_interceptor(); + .interceptor; + let request = Request::new(()); assert!(request.metadata().is_empty()); diff --git a/src/api/interceptors.rs b/src/api/interceptors.rs index 813b5e1..52f93d9 100644 --- a/src/api/interceptors.rs +++ b/src/api/interceptors.rs @@ -3,6 +3,8 @@ //! interceptors is to authenticate the clients to ZITADEL with //! provided credentials. +use std::ops::Deref; +use std::sync::{Arc, RwLock}; use std::thread; use tokio::runtime::Builder; @@ -41,6 +43,7 @@ use crate::credentials::{AuthenticationOptions, ServiceAccount}; /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct AccessTokenInterceptor { access_token: String, } @@ -125,12 +128,21 @@ impl Interceptor for AccessTokenInterceptor { /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct ServiceAccountInterceptor { + inner: Arc, +} + +struct ServiceAccountInterceptorInner { audience: String, service_account: ServiceAccount, auth_options: AuthenticationOptions, - token: Option, - token_expiry: Option, + state: RwLock>, +} + +struct ServiceAccountInterceptorState { + token: String, + token_expiry: time::OffsetDateTime, } impl ServiceAccountInterceptor { @@ -144,11 +156,12 @@ impl ServiceAccountInterceptor { auth_options: Option, ) -> Self { Self { - audience: audience.to_string(), - service_account: service_account.clone(), - auth_options: auth_options.unwrap_or_default(), - token: None, - token_expiry: None, + inner: Arc::new(ServiceAccountInterceptorInner { + audience: audience.to_string(), + service_account: service_account.clone(), + auth_options: auth_options.unwrap_or_default(), + state: RwLock::new(None), + }), } } } @@ -157,25 +170,32 @@ impl Interceptor for ServiceAccountInterceptor { fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { let meta = request.metadata_mut(); if !meta.contains_key("authorization") { - if let Some(token) = &self.token { - if let Some(expiry) = self.token_expiry { - if expiry > time::OffsetDateTime::now_utc() { - meta.insert( - "authorization", - format!("Bearer {}", token).parse().unwrap(), - ); - - return Ok(request); - } + // We unwrap the RWLock to propagate the error if any + // thread panics and the lock is poisoned + let state_read_guard = self.inner.state.read().unwrap(); + + if let Some(ServiceAccountInterceptorState { + token, + token_expiry, + }) = state_read_guard.deref() + { + if token_expiry > &time::OffsetDateTime::now_utc() { + meta.insert( + "authorization", + format!("Bearer {}", token).parse().unwrap(), + ); + + return Ok(request); } } + drop(state_read_guard); - let aud = self.audience.clone(); - let auth = self.auth_options.clone(); - let sa = self.service_account.clone(); + let aud = self.inner.audience.clone(); + let auth = self.inner.auth_options.clone(); + let sa = self.inner.service_account.clone(); let token = thread::spawn(move || { - let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + let rt = Builder::new_current_thread().enable_all().build().unwrap(); rt.block_on(async { sa.authenticate_with_options(&aud, &auth) .await @@ -187,8 +207,14 @@ impl Interceptor for ServiceAccountInterceptor { .join() .map_err(|_| Status::internal("could not fetch token"))??; - self.token = Some(token.clone()); - self.token_expiry = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(59)); + // We unwrap the RWLock to propagate the error if any + // thread panics and the lock is poisoned + let mut state_write_guard = self.inner.state.write().unwrap(); + + *state_write_guard = Some(ServiceAccountInterceptorState { + token: token.clone(), + token_expiry: time::OffsetDateTime::now_utc() + time::Duration::minutes(59), + }); meta.insert( "authorization", @@ -288,6 +314,46 @@ mod tests { .is_empty()); } + #[test] + fn service_account_interceptor_can_be_cloned_and_shares_token_sync_context() { + let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); + let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); + let mut second_interceptor = interceptor.clone(); + let request = Request::new(()); + let second_request = Request::new(()); + + assert!(request.metadata().is_empty()); + assert!(second_request.metadata().is_empty()); + + let request = interceptor.call(request).unwrap(); + let second_request = second_interceptor.call(second_request).unwrap(); + + assert_eq!( + request.metadata().get("authorization"), + second_request.metadata().get("authorization") + ); + } + + #[tokio::test] + async fn service_account_interceptor_can_be_cloned_and_shares_token_async_context() { + let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); + let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); + let mut second_interceptor = interceptor.clone(); + let request = Request::new(()); + let second_request = Request::new(()); + + assert!(request.metadata().is_empty()); + assert!(second_request.metadata().is_empty()); + + let request = interceptor.call(request).unwrap(); + let second_request = second_interceptor.call(second_request).unwrap(); + + assert_eq!( + request.metadata().get("authorization"), + second_request.metadata().get("authorization") + ); + } + #[test] fn service_account_interceptor_ignore_existing_auth_metadata_sync_context() { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); @@ -333,10 +399,28 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); interceptor.call(Request::new(())).unwrap(); - let token = interceptor.token.clone().unwrap(); + let token = interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + .clone(); interceptor.call(Request::new(())).unwrap(); - assert_eq!(token, interceptor.token.unwrap()); + assert_eq!( + token, + interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + ); } #[tokio::test] @@ -344,9 +428,27 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); interceptor.call(Request::new(())).unwrap(); - let token = interceptor.token.clone().unwrap(); + let token = interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + .clone(); interceptor.call(Request::new(())).unwrap(); - assert_eq!(token, interceptor.token.unwrap()); + assert_eq!( + token, + interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + ); } }