diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml b/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml index e38d2ee0..eca4899e 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml @@ -16,6 +16,7 @@ thiserror.workspace = true didcomm = { workspace = true, features = ["uniffi"] } hyper = { workspace = true, features = ["full"] } axum = { workspace = true, features = ["macros"] } +tokio = "1.27.0" [dev-dependencies] keystore = { workspace = true, features = ["test-utils"] } diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/src/error.rs b/crates/web-plugins/didcomm-messaging/protocols/forward/src/error.rs index 5076f3bb..376c2247 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/src/error.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/src/error.rs @@ -10,6 +10,8 @@ pub(crate) enum ForwardError { UncoordinatedSender, #[error("Internal server error")] InternalServerError, + #[error("Service unavailable")] + CircuitOpen, } impl IntoResponse for ForwardError { @@ -18,6 +20,7 @@ impl IntoResponse for ForwardError { ForwardError::MalformedBody => StatusCode::BAD_REQUEST, ForwardError::UncoordinatedSender => StatusCode::UNAUTHORIZED, ForwardError::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR, + ForwardError::CircuitOpen => StatusCode::SERVICE_UNAVAILABLE, }; let body = Json(serde_json::json!({ diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs b/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs index 2d1a86fc..7cb1aafd 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs @@ -4,48 +4,84 @@ use didcomm::{AttachmentData, Message}; use mongodb::bson::doc; use serde_json::{json, Value}; use shared::{ + circuit_breaker::CircuitBreaker, repository::entity::{Connection, RoutedMessage}, + retry::{retry_async, RetryOptions}, state::{AppState, AppStateRepository}, }; use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; /// Mediator receives forwarded messages, extract the next field in the message body, and the attachments in the message /// then stores the attachment with the next field as key for pickup pub(crate) async fn mediator_forward_process( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, ForwardError> { - let AppStateRepository { - message_repository, - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(ForwardError::InternalServerError)?; - - let next = match checks(&message, connection_repository).await.ok() { - Some(next) => Ok(next), - None => Err(ForwardError::InternalServerError), - }; - - let attachments = message.attachments.unwrap_or_default(); - for attachment in attachments { - let attached = match attachment.data { - AttachmentData::Json { value: data } => data.json, - AttachmentData::Base64 { value: data } => json!(data.base64), - AttachmentData::Links { value: data } => json!(data.links), - }; - message_repository - .store(RoutedMessage { - id: None, - message: attached, - recipient_did: next.as_ref().unwrap().to_owned(), - }) - .await - .map_err(|_| ForwardError::InternalServerError)?; + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = Arc::clone(&state); + let message = message.clone(); + async move { + let AppStateRepository { + message_repository, + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or_else(|| ForwardError::InternalServerError)?; + + let next = match checks(&message, connection_repository).await.ok() { + Some(next) => Ok(next), + None => Err(ForwardError::InternalServerError), + }?; + + let attachments = message.attachments.unwrap_or_default(); + for attachment in attachments { + let attached = match attachment.data { + AttachmentData::Json { value: data } => data.json, + AttachmentData::Base64 { value: data } => json!(data.base64), + AttachmentData::Links { value: data } => json!(data.links), + }; + retry_async( + || { + let attached = attached.clone(); + let recipient_did = next.to_owned(); + + async move { + message_repository + .store(RoutedMessage { + id: None, + message: attached, + recipient_did, + }) + .await + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(1)), + ) + .await + .map_err(|_| ForwardError::InternalServerError)?; + } + Ok::, ForwardError>(None) + } + }) + .await; + + match result { + Some(Ok(None)) => Ok(None), + Some(Ok(Some(_))) => Err(ForwardError::InternalServerError), + Some(Err(err)) => Err(err), + None => Err(ForwardError::CircuitOpen), } - Ok(None) } async fn checks( @@ -83,6 +119,7 @@ mod test { use keystore::Secrets; use serde_json::json; use shared::{ + circuit_breaker, repository::{ entity::Connection, tests::{MockConnectionRepository, MockMessagesRepository}, @@ -166,9 +203,16 @@ mod test { .await .expect("Unable unpack"); - let msg = mediator_forward_process(Arc::new(state.clone()), msg) - .await - .unwrap(); + // Wrap the CircuitBreaker in Arc and Mutex + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let msg: Option = mediator_forward_process( + Arc::new(state.clone()), + msg, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap(); println!("Mediator1 is forwarding message \n{:?}\n", msg); } diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs b/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs index 6f860a59..c2f2cf38 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs @@ -3,8 +3,9 @@ use async_trait::async_trait; use axum::response::{IntoResponse, Response}; use didcomm::Message; use message_api::{MessageHandler, MessagePlugin, MessageRouter}; -use shared::state::AppState; -use std::sync::Arc; +use shared::{circuit_breaker::CircuitBreaker, state::AppState}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; /// Represents the routing protocol plugin. pub struct RoutingProtocol; @@ -18,7 +19,13 @@ impl MessageHandler for ForwardHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::mediator_forward_process(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + // Pass the state, msg, and the circuit_breaker as arguments + crate::handler::mediator_forward_process(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } diff --git a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs index 7e01a08f..2a15b917 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs @@ -15,6 +15,8 @@ pub(crate) enum MediationError { UnexpectedMessageFormat, #[error("internal server error")] InternalServerError, + #[error("service unavailable")] + CircuitOpen, } impl IntoResponse for MediationError { @@ -26,6 +28,7 @@ impl IntoResponse for MediationError { MediationError::UncoordinatedSender => StatusCode::UNAUTHORIZED, MediationError::UnexpectedMessageFormat => StatusCode::BAD_REQUEST, MediationError::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR, + MediationError::CircuitOpen => StatusCode::SERVICE_UNAVAILABLE, }; let body = Json(serde_json::json!({ diff --git a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs index 2a135813..5a5ff766 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs @@ -19,17 +19,21 @@ use keystore::Secrets; use mongodb::bson::doc; use serde_json::json; use shared::{ + circuit_breaker::CircuitBreaker, midlw::ensure_transport_return_route_is_decorated_all, repository::entity::Connection, + retry::{retry_async, RetryOptions}, state::{AppState, AppStateRepository}, }; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; use uuid::Uuid; /// Process a DIDComm mediate request pub(crate) async fn process_mediate_request( state: Arc, plain_message: Message, + circuit_breaker: Arc>, ) -> Result, MediationError> { // This is to Check message type compliance ensure_jwm_type_is_mediation_request(&plain_message)?; @@ -42,116 +46,164 @@ pub(crate) async fn process_mediate_request( let sender_did = plain_message.from.as_ref().unwrap(); - // Retrieve repository to connection entities - - let AppStateRepository { - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - // If there is already mediation, send mediate deny - if let Some(_connection) = connection_repository - .find_one_by(doc! { "client_did": sender_did}) - .await - .map_err(|_| MediationError::InternalServerError)? - { - tracing::info!("Sending mediate deny."); - Ok(Some( - Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - MEDIATE_DENY_2_0.to_string(), - json!(MediationDeny { - id: format!("urn:uuid:{}", Uuid::new_v4()), - message_type: MEDIATE_DENY_2_0.to_string(), - }), - ) - .to(sender_did.clone()) - .from(mediator_did.clone()) - .finalize(), - )) - } else { - /* Issue mediate grant response */ - tracing::info!("Sending mediate grant."); - // Create routing, store it and send mediation grant - let (routing_did, auth_keys, agreem_keys) = - generate_did_peer(state.public_domain.to_string()); - - let AppStateRepository { keystore, .. } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - let diddoc = state - .did_resolver - .resolve(&routing_did) - .await - .map_err(|err| { - tracing::error!("Failed to resolve DID: {:?}", err); - MediationError::InternalServerError - })? - .ok_or(MediationError::InternalServerError)?; - - let agreem_keys_jwk: Jwk = agreem_keys.try_into().unwrap(); - - let agreem_keys_secret = Secrets { - id: None, - kid: diddoc.key_agreement.first().unwrap().clone(), - secret_material: agreem_keys_jwk, - }; - - match keystore.store(agreem_keys_secret).await { - Ok(_stored_connection) => { - tracing::info!("Successfully stored agreement keys.") - } - Err(error) => tracing::error!("Error storing agreement keys: {:?}", error), - } + // Acquire the CircuitBreaker lock + let mut cb = circuit_breaker.lock().await; + + // Wrap the process logic in the CircuitBreaker call + let result = cb + .call_async(|| { + let state = state.clone(); + let sender_did = sender_did.clone(); + let mediator_did = mediator_did.clone(); + + async move { + // Retrieve repository to connection entities + // Retrieve repository to connection entities + + // Retrieve repository to connection entities + + let AppStateRepository { + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + // If there is already mediation, send mediate deny + if let Some(_connection) = retry_async( + || { + let sender_did = sender_did.clone(); + let connection_repository = connection_repository.clone(); + async move { + connection_repository + .find_one_by(doc! { "client_did": sender_did }) + .await + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(1)), + ) + .await + .map_err(|_| MediationError::InternalServerError)? + { + tracing::info!("Sending mediate deny."); + return Ok(Some( + Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + MEDIATE_DENY_2_0.to_string(), + json!(MediationDeny { + id: format!("urn:uuid:{}", Uuid::new_v4()), + message_type: MEDIATE_DENY_2_0.to_string(), + ..Default::default() + }), + ) + .to(sender_did.clone()) + .from(mediator_did.clone()) + .finalize(), + )); + } else { + // Issue mediate grant response + tracing::info!("Sending mediate grant."); + // Create routing, store it and send mediation grant + let (routing_did, auth_keys, agreem_keys) = + generate_did_peer(state.public_domain.to_string()); + + let AppStateRepository { keystore, .. } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + let diddoc = retry_async( + || { + let did_resolver = state.did_resolver.clone(); + let routing_did = routing_did.clone(); + + async move { did_resolver.resolve(&routing_did).await.map_err(|_| ()) } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|err| { + tracing::error!("Failed to resolve DID: {:?}", err); + MediationError::InternalServerError + })? + .ok_or(MediationError::InternalServerError)?; + + let agreem_keys_jwk: Jwk = agreem_keys.try_into().unwrap(); + + let agreem_keys_secret = Secrets { + id: None, + kid: diddoc.key_agreement.get(0).unwrap().clone(), + secret_material: agreem_keys_jwk, + }; + + match keystore.store(agreem_keys_secret).await { + Ok(_stored_connection) => { + tracing::info!("Successfully stored agreement keys.") + } + Err(error) => tracing::error!("Error storing agreement keys: {:?}", error), + } - let auth_keys_jwk: Jwk = auth_keys.try_into().unwrap(); + let auth_keys_jwk: Jwk = auth_keys.try_into().unwrap(); + + let auth_keys_secret = Secrets { + id: None, + kid: diddoc.authentication.get(0).unwrap().clone(), + secret_material: auth_keys_jwk, + }; + + match keystore.store(auth_keys_secret).await { + Ok(_stored_connection) => { + tracing::info!("Successfully stored authentication keys.") + } + Err(error) => { + tracing::error!("Error storing authentication keys: {:?}", error) + } + } - let auth_keys_secret = Secrets { - id: None, - kid: diddoc.authentication.first().unwrap().clone(), - secret_material: auth_keys_jwk, - }; + let mediation_grant = create_mediation_grant(&routing_did); + + let new_connection = Connection { + id: None, + client_did: sender_did.to_string(), + mediator_did: mediator_did.to_string(), + keylist: vec!["".to_string()], + routing_did: routing_did, + }; + + // Use store_one to store the sample connection + match connection_repository.store(new_connection).await { + Ok(_stored_connection) => { + tracing::info!("Successfully stored connection: ") + } + Err(error) => tracing::error!("Error storing connection: {:?}", error), + } - match keystore.store(auth_keys_secret).await { - Ok(_stored_connection) => { - tracing::info!("Successfully stored authentication keys.") + Ok(Some( + Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + mediation_grant.message_type.clone(), + json!(mediation_grant), + ) + .to(sender_did.clone()) + .from(mediator_did.clone()) + .finalize(), + )) + } } - Err(error) => tracing::error!("Error storing authentication keys: {:?}", error), - } - - let mediation_grant = create_mediation_grant(&routing_did); - - let new_connection = Connection { - id: None, - client_did: sender_did.to_string(), - mediator_did: mediator_did.to_string(), - keylist: vec!["".to_string()], - routing_did, - }; + }) + .await; - // Use store_one to store the sample connection - match connection_repository.store(new_connection).await { - Ok(_stored_connection) => { - tracing::info!("Successfully stored connection: ") - } - Err(error) => tracing::error!("Error storing connection: {:?}", error), - } - - Ok(Some( - Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - mediation_grant.message_type.clone(), - json!(mediation_grant), - ) - .to(sender_did.clone()) - .from(mediator_did.clone()) - .finalize(), - )) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(MediationError::CircuitOpen), } } @@ -203,6 +255,7 @@ fn generate_did_peer(service_endpoint: String) -> (String, Ed25519KeyPair, X2551 pub(crate) async fn process_plain_keylist_update_message( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, MediationError> { // Extract message sender @@ -215,178 +268,247 @@ pub(crate) async fn process_plain_keylist_update_message( let keylist_update_body: KeylistUpdateBody = serde_json::from_value(message.body) .map_err(|_| MediationError::UnexpectedMessageFormat)?; - // Retrieve repository to connection entities - - let AppStateRepository { - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - // Find connection for this keylist update - - let connection = connection_repository - .find_one_by(doc! { "client_did": &sender }) - .await - .unwrap() - .ok_or(MediationError::UncoordinatedSender)?; - - // Prepare handles to relevant collections - - let mut updated_keylist = connection.keylist.clone(); - let updates = keylist_update_body.updates; - - // Closure to check if a specific key is duplicated across commands - - let key_is_duplicate = |recipient_did| { - updates - .iter() - .filter(|e| &e.recipient_did == recipient_did) - .count() - > 1 - }; - - // Perform updates to persist - - let confirmations: Vec<_> = updates - .iter() - .map(|update| KeylistUpdateConfirmation { - recipient_did: update.recipient_did.clone(), - action: update.action.clone(), - result: { - if let KeylistUpdateAction::Unknown(_) = &update.action { - KeylistUpdateResult::ClientError - } else if key_is_duplicate(&update.recipient_did) { - KeylistUpdateResult::ClientError - } else { - match connection - .keylist + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = state.clone(); + let sender = sender.clone(); + let keylist_update_body = keylist_update_body.clone(); + + async move { + let AppStateRepository { + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + // Find connection for this keylist update + + let connection = retry_async( + || { + let connection_repository = connection_repository.clone(); + let sender = sender.clone(); + + async move { + connection_repository + .find_one_by(doc! { "client_did": &sender }) + .await + .map_err(|_| ()) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|err| { + tracing::error!("Failed to find connection after retries: {:?}", err); + MediationError::InternalServerError + })? + .ok_or_else(|| MediationError::UncoordinatedSender)?; + + // Prepare handles to relevant collections + + let mut updated_keylist = connection.keylist.clone(); + let updates = keylist_update_body.updates; + + // Closure to check if a specific key is duplicated across commands + + let key_is_duplicate = |recipient_did| { + updates .iter() - .position(|x| x == &update.recipient_did) - { - Some(index) => match &update.action { - KeylistUpdateAction::Add => KeylistUpdateResult::NoChange, - KeylistUpdateAction::Remove => { - updated_keylist.swap_remove(index); - KeylistUpdateResult::Success + .filter(|e| &e.recipient_did == recipient_did) + .count() + > 1 + }; + + // Process keylist updates + let confirmations: Vec<_> = updates + .iter() + .map(|update| KeylistUpdateConfirmation { + recipient_did: update.recipient_did.clone(), + action: update.action.clone(), + result: { + if let KeylistUpdateAction::Unknown(_) = &update.action { + KeylistUpdateResult::ClientError + } else if key_is_duplicate(&update.recipient_did) { + KeylistUpdateResult::ClientError + } else { + match connection + .keylist + .iter() + .position(|x| x == &update.recipient_did) + { + Some(index) => match &update.action { + KeylistUpdateAction::Add => KeylistUpdateResult::NoChange, + KeylistUpdateAction::Remove => { + updated_keylist.swap_remove(index); + KeylistUpdateResult::Success + } + KeylistUpdateAction::Unknown(_) => unreachable!(), + }, + None => match &update.action { + KeylistUpdateAction::Add => { + updated_keylist.push(update.recipient_did.clone()); + KeylistUpdateResult::Success + } + KeylistUpdateAction::Remove => { + KeylistUpdateResult::NoChange + } + KeylistUpdateAction::Unknown(_) => unreachable!(), + }, + } } - KeylistUpdateAction::Unknown(_) => unreachable!(), }, - None => match &update.action { - KeylistUpdateAction::Add => { - updated_keylist.push(update.recipient_did.clone()); - KeylistUpdateResult::Success + }) + .collect(); + + let confirmations = match connection_repository + .update(Connection { + keylist: updated_keylist, + ..connection + }) + .await + { + Ok(_) => confirmations, + Err(_) => confirmations + .into_iter() + .map(|mut confirmation| { + if confirmation.result != KeylistUpdateResult::ClientError { + confirmation.result = KeylistUpdateResult::ServerError } - KeylistUpdateAction::Remove => KeylistUpdateResult::NoChange, - KeylistUpdateAction::Unknown(_) => unreachable!(), - }, - } - } - }, - }) - .collect(); - - // Persist updated keylist, update confirmations if server error - let confirmations = match connection_repository - .update(Connection { - keylist: updated_keylist, - ..connection + confirmation + }) + .collect(), + }; + + // Build response + + let mediator_did = &state.diddoc.id; + + Ok(Some( + Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + KEYLIST_UPDATE_RESPONSE_2_0.to_string(), + json!(KeylistUpdateResponseBody { + updated: confirmations + }), + ) + .to(sender) + .from(mediator_did.to_owned()) + .finalize(), + )) + } }) - .await - { - Ok(_) => confirmations, - Err(_) => confirmations - .into_iter() - .map(|mut confirmation| { - if confirmation.result != KeylistUpdateResult::ClientError { - confirmation.result = KeylistUpdateResult::ServerError - } - - confirmation - }) - .collect(), - }; - - // Build response - - let mediator_did = &state.diddoc.id; + .await; - Ok(Some( - Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - KEYLIST_UPDATE_RESPONSE_2_0.to_string(), - json!(KeylistUpdateResponseBody { - updated: confirmations - }), - ) - .to(sender) - .from(mediator_did.to_owned()) - .finalize(), - )) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(MediationError::CircuitOpen), + } } pub(crate) async fn process_plain_keylist_query_message( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, MediationError> { println!("Processing keylist query..."); let sender = message .from .expect("unpacking middleware failed to prevent anonymous senders"); - let AppStateRepository { - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - let connection = connection_repository - .find_one_by(doc! { "client_did": &sender }) - .await - .unwrap() - .ok_or(MediationError::UncoordinatedSender)?; - - println!("keylist: {:?}", connection); - - let keylist_entries = connection - .keylist - .iter() - .map(|key| KeylistEntry { - recipient_did: key.clone(), - }) - .collect::>(); - - let body = KeylistBody { - keys: keylist_entries, - pagination: None, - }; + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = state.clone(); + let sender = sender.clone(); + + async move { + let AppStateRepository { + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + let connection = retry_async( + || { + let connection_repository = connection_repository.clone(); + let sender = sender.clone(); + + async move { + connection_repository + .find_one_by(doc! { "client_did": &sender }) + .await + .map_err(|_| ()) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|err| { + tracing::error!("Failed to find connection after retries: {:?}", err); + MediationError::InternalServerError + })? + .ok_or_else(|| MediationError::UncoordinatedSender)?; + + println!("keylist: {:?}", connection); + + let keylist_entries = connection + .keylist + .iter() + .map(|key| KeylistEntry { + recipient_did: key.clone(), + }) + .collect::>(); + + let body = KeylistBody { + keys: keylist_entries, + pagination: None, + }; + + let keylist_object = Keylist { + id: format!("urn:uuid:{}", Uuid::new_v4()), + message_type: KEYLIST_2_0.to_string(), + body: body, + additional_properties: None, + }; - let keylist_object = Keylist { - id: format!("urn:uuid:{}", Uuid::new_v4()), - message_type: KEYLIST_2_0.to_string(), - body, - additional_properties: None, - }; + let mediator_did = &state.diddoc.id; - let mediator_did = &state.diddoc.id; + let message = Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + KEYLIST_2_0.to_string(), + json!(keylist_object), + ) + .to(sender.clone()) + .from(mediator_did.clone()) + .finalize(); - let message = Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - KEYLIST_2_0.to_string(), - json!(keylist_object), - ) - .to(sender.clone()) - .from(mediator_did.clone()) - .finalize(); + println!("message: {:?}", message); - println!("message: {:?}", message); + Ok(Some(message)) + } + }) + .await; - Ok(Some(message)) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(MediationError::CircuitOpen), + } } #[cfg(test)] @@ -394,7 +516,8 @@ mod tests { use super::*; use shared::{ - repository::tests::MockConnectionRepository, utils::tests_utils::tests as global, + circuit_breaker, repository::tests::MockConnectionRepository, + utils::tests_utils::tests as global, }; #[allow(clippy::needless_update)] @@ -428,11 +551,17 @@ mod tests { .from(global::_edge_did()) .finalize(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + // Process request - let response = process_plain_keylist_query_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let response = process_plain_keylist_query_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, KEYLIST_2_0); assert_eq!(response.from.unwrap(), global::_mediator_did(&state)); @@ -452,10 +581,16 @@ mod tests { .from("did:example:uncoordinated_sender".to_string()) .finalize(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + // Process request - let err = process_plain_keylist_query_message(Arc::clone(&state), message) - .await - .unwrap_err(); + let err = process_plain_keylist_query_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap_err(); // Assert issued error for uncoordinated sender assert_eq!(err, MediationError::UncoordinatedSender,); } @@ -488,10 +623,16 @@ mod tests { // Process request - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); let response = response; // Assert metadata @@ -587,10 +728,16 @@ mod tests { // Process request - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); // Assert updates assert_eq!( @@ -650,11 +797,16 @@ mod tests { .finalize(); // Process request + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); // Assert updates assert_eq!( @@ -710,11 +862,16 @@ mod tests { .finalize(); // Process request + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); // Assert updates @@ -748,9 +905,15 @@ mod tests { .finalize(); // Process request - let err = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap_err(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let err = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap_err(); // Assert issued error assert_eq!(err, MediationError::UnexpectedMessageFormat,); @@ -794,9 +957,15 @@ mod tests { .finalize(); // Process request - let err = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap_err(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let err = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap_err(); // Assert issued error assert_eq!(err, MediationError::UncoordinatedSender,); diff --git a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs index b0693bc1..0a955ff1 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs @@ -3,8 +3,9 @@ use async_trait::async_trait; use axum::response::{IntoResponse, Response}; use didcomm::Message; use message_api::{MessageHandler, MessagePlugin, MessageRouter}; -use shared::state::AppState; -use std::sync::Arc; +use shared::{circuit_breaker::CircuitBreaker, state::AppState}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; /// Represents the routing protocol plugin. pub struct MediatorCoordinationProtocol; @@ -20,7 +21,12 @@ impl MessageHandler for MediateRequestHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::stateful::process_mediate_request(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::stateful::process_mediate_request(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -33,7 +39,12 @@ impl MessageHandler for KeylistUpdateHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::stateful::process_plain_keylist_update_message(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::stateful::process_plain_keylist_update_message(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -46,7 +57,12 @@ impl MessageHandler for KeylistQueryHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::stateful::process_plain_keylist_query_message(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::stateful::process_plain_keylist_query_message(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml b/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml index 9ef52784..1796def2 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml @@ -15,6 +15,7 @@ thiserror.workspace = true async-trait.workspace = true uuid = { workspace = true, features = ["v4"] } axum = { workspace = true, features = ["macros"] } +tokio = "1.27.0" [dev-dependencies] shared = { workspace = true, features = ["test-utils"] } diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/error.rs b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/error.rs index 4fd25b89..da67b383 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/error.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/error.rs @@ -14,6 +14,9 @@ pub(crate) enum PickupError { #[error("Malformed request. {0}")] MalformedRequest(String), + + #[error("Service unavailable")] + CircuitOpen, } impl IntoResponse for PickupError { @@ -22,6 +25,7 @@ impl IntoResponse for PickupError { PickupError::MissingSenderDID | PickupError::MalformedRequest(_) => { StatusCode::BAD_REQUEST } + PickupError::CircuitOpen => StatusCode::SERVICE_UNAVAILABLE, PickupError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, PickupError::MissingClientConnection => StatusCode::UNAUTHORIZED, }; diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs index 9b71840a..dc68ad4a 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs @@ -10,191 +10,302 @@ use didcomm::{Attachment, Message, MessageBuilder}; use mongodb::bson::{doc, oid::ObjectId}; use serde_json::Value; use shared::{ + circuit_breaker::CircuitBreaker, midlw::ensure_transport_return_route_is_decorated_all, repository::entity::{Connection, RoutedMessage}, + retry::{retry_async, RetryOptions}, state::{AppState, AppStateRepository}, }; -use std::{str::FromStr, sync::Arc}; +use std::{str::FromStr, sync::Arc, time::Duration}; +use tokio::sync::Mutex; use uuid::Uuid; // Process pickup status request pub(crate) async fn handle_status_request( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, PickupError> { // Validate the return_route header ensure_transport_return_route_is_decorated_all(&message) .map_err(|_| PickupError::MalformedRequest("Missing return_route header".to_owned()))?; let mediator_did = &state.diddoc.id; - let recipient_did = message - .body - .get("recipient_did") - .and_then(|val| val.as_str()); let sender_did = sender_did(&message)?; - let repository = repository(state.clone())?; - let connection = client_connection(&repository, sender_did).await?; - - let message_count = count_messages(repository, recipient_did, connection).await?; - - let id = Uuid::new_v4().urn().to_string(); - let response_builder: MessageBuilder = StatusResponse { - id: id.as_str(), - type_: STATUS_RESPONSE_3_0, - body: BodyStatusResponse { - recipient_did, - message_count, - live_delivery: Some(false), - ..Default::default() - }, - } - .into(); - - let response = response_builder - .to(sender_did.to_owned()) - .from(mediator_did.to_owned()) - .finalize(); + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = state.clone(); + let message = message.clone(); + async move { + let recipient_did = message + .body + .get("recipient_did") + .and_then(|val| val.as_str()); + + let repository = repository(state.clone())?; + + let connection = retry_async( + || { + let repository = repository.clone(); + async move { client_connection(&repository, sender_did).await } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|_| { + PickupError::InternalError("Failed to retrieve client connection".to_owned()) + })?; + + // Pass `recipient_did` to count_messages, allowing it to handle `None` + let message_count = count_messages(repository, recipient_did, connection).await?; + + let id = Uuid::new_v4().urn().to_string(); + let response_builder: MessageBuilder = StatusResponse { + id: id.as_str(), + type_: STATUS_RESPONSE_3_0, + body: BodyStatusResponse { + recipient_did: recipient_did.to_owned(), + message_count, + live_delivery: Some(false), + ..Default::default() + }, + } + .into(); + + let response = response_builder + .to(sender_did.to_owned()) + .from(mediator_did.to_owned()) + .finalize(); + + Ok(Some(response)) + } + }) + .await; - Ok(Some(response)) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), + } } // Process pickup delivery request pub(crate) async fn handle_delivery_request( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, PickupError> { // Validate the return_route header ensure_transport_return_route_is_decorated_all(&message) .map_err(|_| PickupError::MalformedRequest("Missing return_route header".to_owned()))?; let mediator_did = &state.diddoc.id; - let recipient_did = message - .body - .get("recipient_did") - .and_then(|val| val.as_str()); - let sender_did = sender_did(&message)?; - - // Get the messages limit - let limit = message - .body - .get("limit") - .and_then(Value::as_u64) - .ok_or_else(|| PickupError::MalformedRequest("Invalid \"limit\" specifier".to_owned()))?; - - let repository = repository(state.clone())?; - let connection = client_connection(&repository, sender_did).await?; - - let messages = messages(repository, recipient_did, connection, limit as usize).await?; - let id = Uuid::new_v4().urn().to_string(); - - let response_builder: MessageBuilder = if messages.is_empty() { - StatusResponse { - id: id.as_str(), - type_: STATUS_RESPONSE_3_0, - body: BodyStatusResponse { - recipient_did, - message_count: 0, - live_delivery: Some(false), - ..Default::default() - }, - } - .into() - } else { - let mut attachments: Vec = Vec::with_capacity(messages.len()); - for message in messages { - let attached = Attachment::json(message.message) - .id(message.id.map(|id| id.to_string()).ok_or_else(|| { - PickupError::InternalError( - "Failed to load requested messages. Please try again later.".to_owned(), - ) - })?) - .finalize(); - - attachments.push(attached); - } - - DeliveryResponse { - id: id.as_str(), - thid: id.as_str(), - type_: MESSAGE_DELIVERY_3_0, - body: BodyDeliveryResponse { recipient_did }, - attachments, - } - .into() - }; + let sender_did = sender_did(&message)?; - let response = response_builder - .to(sender_did.to_owned()) - .from(mediator_did.to_owned()) - .finalize(); + let message_body = message.body.clone(); + + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = state.clone(); + let message_body = message_body.clone(); + async move { + let recipient_did = message_body.get("recipient_did").and_then(Value::as_str); + + let limit = retry_async( + || { + let message_body = message_body.clone(); + async move { + message_body + .get("limit") + .and_then(Value::as_u64) + .ok_or_else(|| { + PickupError::MalformedRequest( + "Invalid \"limit\" specifier".to_owned(), + ) + }) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await?; + + let repository = repository(state.clone())?; + let connection = retry_async( + || { + let repository = repository.clone(); + async move { client_connection(&repository, sender_did).await } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|_| { + PickupError::InternalError("Failed to retrieve client connection".to_owned()) + })?; + let messages = + messages(repository, recipient_did, connection, limit as usize).await?; + + let response_builder: MessageBuilder; + let id = Uuid::new_v4().urn().to_string(); + if messages.is_empty() { + response_builder = StatusResponse { + id: id.as_str(), + type_: STATUS_RESPONSE_3_0, + body: BodyStatusResponse { + recipient_did, + message_count: 0, + live_delivery: Some(false), + ..Default::default() + }, + } + .into(); + } else { + let mut attachments: Vec = Vec::with_capacity(messages.len()); + + for message in messages { + let attached = Attachment::json(message.message) + .id(message.id.map(|id| id.to_string()).ok_or_else(|| { + PickupError::InternalError( + "Failed to load requested messages. Please try again later." + .to_owned(), + ) + })?) + .finalize(); + + attachments.push(attached); + } + + response_builder = DeliveryResponse { + id: id.as_str(), + thid: id.as_str(), + type_: MESSAGE_DELIVERY_3_0, + body: BodyDeliveryResponse { recipient_did }, + attachments, + } + .into(); + } + + let response = response_builder + .to(sender_did.to_owned()) + .from(mediator_did.to_owned()) + .finalize(); + + Ok(Some(response)) + } + }) + .await; - Ok(Some(response)) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), + } } // Process pickup messages acknowledgement pub(crate) async fn handle_message_acknowledgement( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, PickupError> { // Validate the return_route header ensure_transport_return_route_is_decorated_all(&message) .map_err(|_| PickupError::MalformedRequest("Missing return_route header".to_owned()))?; - let mediator_did = &state.diddoc.id; - let repository = repository(state.clone())?; - let sender_did = sender_did(&message)?; - let connection = client_connection(&repository, sender_did).await?; - - // Get the message id list - let message_id_list = message - .body - .get("message_id_list") - .and_then(|v| v.as_array()) - .map(|a| a.iter().filter_map(|v| v.as_str()).collect::>()) - .ok_or_else(|| { - PickupError::MalformedRequest("Invalid \"message_id_list\" specifier".to_owned()) - })?; - - for id in message_id_list { - let msg_id = ObjectId::from_str(id); - if msg_id.is_err() { - return Err(PickupError::MalformedRequest(format!( - "Invalid message id: {id}" - ))); - } - repository - .message_repository - .delete_one(msg_id.unwrap()) - .await - .map_err(|_| { - PickupError::InternalError( - "Failed to process the request. Please try again later.".to_owned(), + // Acquire the CircuitBreaker lock + let mut cb = circuit_breaker.lock().await; + + // Wrap the message acknowledgement logic in the CircuitBreaker call + let result = cb + .call_async(|| { + let state = state.clone(); + let message = message.clone(); + async move { + let mediator_did = &state.diddoc.id; + let repository = repository(state.clone())?; + let sender_did = sender_did(&message)?; + let connection = client_connection(&repository, sender_did).await?; + + // Get the message ID list + let message_id_list = message + .body + .get("message_id_list") + .and_then(|v| v.as_array()) + .map(|a| a.iter().filter_map(|v| v.as_str()).collect::>()) + .ok_or_else(|| { + PickupError::MalformedRequest( + "Invalid \"message_id_list\" specifier".to_owned(), + ) + })?; + + for id in message_id_list { + let msg_id = ObjectId::from_str(id).map_err(|_| { + PickupError::MalformedRequest(format!("Invalid message id: {id}")) + })?; + + retry_async( + || { + let message_repository = repository.message_repository.clone(); + let msg_id = msg_id.clone(); + + async move { message_repository.delete_one(msg_id).await.map_err(|_| ()) } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), ) - })?; - } + .await + .map_err(|_| { + PickupError::InternalError( + "Failed to process the request. Please try again later.".to_owned(), + ) + })?; + } + + let message_count = count_messages(repository, None, connection).await?; + + let id = Uuid::new_v4().urn().to_string(); + let response_builder: MessageBuilder = StatusResponse { + id: id.as_str(), + type_: STATUS_RESPONSE_3_0, + body: BodyStatusResponse { + message_count, + live_delivery: Some(false), + ..Default::default() + }, + } + .into(); + + let response = response_builder + .to(sender_did.to_owned()) + .from(mediator_did.to_owned()) + .finalize(); + + Ok(Some(response)) + } + }) + .await; - let message_count = count_messages(repository, None, connection).await?; - - let id = Uuid::new_v4().urn().to_string(); - let response_builder: MessageBuilder = StatusResponse { - id: id.as_str(), - type_: STATUS_RESPONSE_3_0, - body: BodyStatusResponse { - message_count, - live_delivery: Some(false), - ..Default::default() - }, + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), } - .into(); - - let response = response_builder - .to(sender_did.to_owned()) - .from(mediator_did.to_owned()) - .finalize(); - - Ok(Some(response)) } // Process live delivery change request @@ -242,6 +353,7 @@ async fn count_messages( repository: AppStateRepository, recipient_did: Option<&str>, connection: Connection, + // circuit_breaker: Arc>, ) -> Result { let recipients = recipients(recipient_did, &connection); @@ -263,6 +375,7 @@ async fn messages( recipient_did: Option<&str>, connection: Connection, limit: usize, + // circuit_breaker: Arc>, ) -> Result, PickupError> { let recipients = recipients(recipient_did, &connection); @@ -333,6 +446,7 @@ mod tests { }; use serde_json::json; use shared::{ + circuit_breaker, repository::tests::{MockConnectionRepository, MockMessagesRepository}, utils::tests_utils::tests as global, }; @@ -397,10 +511,15 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_status_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_status_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, STATUS_RESPONSE_3_0); assert_eq!(response.from.unwrap(), global::_mediator_did(&state)); @@ -426,10 +545,16 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_status_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let response = handle_status_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!( response.body, @@ -453,10 +578,15 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_status_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_status_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!( response.body, @@ -479,11 +609,13 @@ mod tests { .from("did:key:invalid".to_owned()) .finalize(); - let error = handle_status_request(state, invalid_request) + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let error = handle_status_request(state, invalid_request, Arc::new(circuit_breaker.into())) .await .unwrap_err(); - assert_eq!(error, PickupError::MissingClientConnection); + assert_eq!(error.to_string(), "Failed to retrieve client connection"); + // assert_eq!(error, PickupError::MissingClientConnection); } #[tokio::test] @@ -500,10 +632,15 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_delivery_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); let expected_attachments = vec![ Attachment::json(json!("test1")) @@ -545,7 +682,8 @@ mod tests { // When the specified recipient did is not in the keylist, // it should return a status response with a message count of 0 - let response = handle_delivery_request(state, request) + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request(state, request, Arc::new(circuit_breaker.into())) .await .unwrap() .expect("Response should not be None"); @@ -574,7 +712,9 @@ mod tests { // When the limit is set to 0, it should return all the messages in the queue // and since the recipient did is not specified, it should return the messages // for all the dids in the keylist for that sender connection - let response = handle_delivery_request(state, request) + + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request(state, request, Arc::new(circuit_breaker.into())) .await .unwrap() .expect("Response should not be None"); @@ -614,7 +754,9 @@ mod tests { // Since the recipient did is not specified, it should return the messages // for all the dids in the keylist for that sender connection (2 in this case) // The limit is set to 1 so it should return the first message in the queue - let response = handle_delivery_request(state, request) + + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request(state, request, Arc::new(circuit_breaker.into())) .await .unwrap() .expect("Response should not be None"); @@ -645,10 +787,13 @@ mod tests { .finalize(); // Should return 2 since these ids are not associated with any message - let response = handle_message_acknowledgement(state, request) - .await - .unwrap() - .expect("Response should not be None"); + + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = + handle_message_acknowledgement(state, request, Arc::new(circuit_breaker.into())) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, STATUS_RESPONSE_3_0); assert_eq!( @@ -673,10 +818,12 @@ mod tests { // Should return 1 since one id in the list is associated // to the first message in the queue and then will be deleted - let response = handle_message_acknowledgement(state, request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = + handle_message_acknowledgement(state, request, Arc::new(circuit_breaker.into())) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, STATUS_RESPONSE_3_0); assert_eq!( diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs index b91353b9..1167b64c 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs @@ -5,8 +5,9 @@ use async_trait::async_trait; use axum::response::{IntoResponse, Response}; use didcomm::Message; use message_api::{MessageHandler, MessagePlugin, MessageRouter}; -use shared::state::AppState; -use std::sync::Arc; +use shared::{circuit_breaker::CircuitBreaker, state::AppState}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; /// Represents the pickup protocol plugin. pub struct PickupProtocol; @@ -23,7 +24,12 @@ impl MessageHandler for StatusRequestHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::handle_status_request(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::handle_status_request(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -36,7 +42,12 @@ impl MessageHandler for DeliveryRequestHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::handle_delivery_request(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::handle_delivery_request(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -49,7 +60,12 @@ impl MessageHandler for MessageReceivedHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::handle_message_acknowledgement(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::handle_message_acknowledgement(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } diff --git a/crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs b/crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs new file mode 100644 index 00000000..2238cd16 --- /dev/null +++ b/crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs @@ -0,0 +1,90 @@ +use std::time::{Duration, Instant}; + +#[derive(Debug)] +enum State { + // The circuit breaker is closed and allowing requests + // to pass through + Closed, + // The circuit breaker is open and blocking requests + Open, + // The circuit breaker is half-open and allowing a limited + // number of requests to pass through + HalfOpen, +} + +pub struct CircuitBreaker { + state: State, + // The duration to wait before transitioning from the + // open state to the half-open state + trip_timeout: Duration, + // The maximum number of requests allowed through in + // the closed state + max_failures: usize, + // The number of consecutive failures in the closed + // state + consecutive_failures: usize, + // The time when the circuit breaker transitioned to the + // open state + opened_at: Option, +} + +impl CircuitBreaker { + pub fn new(max_failures: usize, trip_timeout: Duration) -> CircuitBreaker { + CircuitBreaker { + state: State::Closed, + max_failures, + trip_timeout, + consecutive_failures: 0, + opened_at: None, + } + } + + pub async fn call_async(&mut self, f: F) -> Option> + where + F: FnOnce() -> Fut, + Fut: std::future::Future>, + { + match self.state { + State::Closed => { + if self.consecutive_failures < self.max_failures { + let result = f().await; + if result.is_err() { + self.record_failure(); + } + Some(result) + } else { + self.opened_at = Some(Instant::now()); + self.state = State::Open; + self.consecutive_failures = 0; + None + } + } + State::Open => { + if let Some(opened_at) = self.opened_at { + if Instant::now().duration_since(opened_at) >= self.trip_timeout { + self.state = State::HalfOpen; + self.opened_at = None; + } + } + None + } + State::HalfOpen => { + let result = f().await; + if result.is_err() { + self.state = State::Open; + } else { + self.state = State::Closed; + } + Some(result) + } + } + } + + fn record_failure(&mut self) { + match self.state { + State::Closed => self.consecutive_failures += 1, + State::Open => (), + State::HalfOpen => self.consecutive_failures += 1, + } + } +} diff --git a/crates/web-plugins/didcomm-messaging/shared/src/lib.rs b/crates/web-plugins/didcomm-messaging/shared/src/lib.rs index 530e4728..77ffa0c8 100644 --- a/crates/web-plugins/didcomm-messaging/shared/src/lib.rs +++ b/crates/web-plugins/didcomm-messaging/shared/src/lib.rs @@ -1,5 +1,7 @@ +pub mod circuit_breaker; pub mod errors; pub mod midlw; pub mod repository; +pub mod retry; pub mod state; pub mod utils; diff --git a/crates/web-plugins/didcomm-messaging/shared/src/retry.rs b/crates/web-plugins/didcomm-messaging/shared/src/retry.rs new file mode 100644 index 00000000..3eab1a0a --- /dev/null +++ b/crates/web-plugins/didcomm-messaging/shared/src/retry.rs @@ -0,0 +1,75 @@ +use std::time::Duration; +use tokio::time::sleep; + +pub struct RetryOptions { + retries: usize, + fixed_backoff: Option, + exponential_backoff: Option, + max_delay: Option, +} + +impl RetryOptions { + pub fn new() -> Self { + Self { + retries: 3, + fixed_backoff: None, + exponential_backoff: None, + max_delay: None, + } + } + + pub fn retries(mut self, count: usize) -> Self { + self.retries = count; + self + } + + pub fn fixed_backoff(mut self, delay: Duration) -> Self { + self.fixed_backoff = Some(delay); + self + } + + pub fn exponential_backoff(mut self, initial_delay: Duration) -> Self { + self.exponential_backoff = Some(initial_delay); + self + } + + pub fn max_delay(mut self, delay: Duration) -> Self { + self.max_delay = Some(delay); + self + } +} + +pub async fn retry_async(mut operation: F, options: RetryOptions) -> Result +where + F: FnMut() -> Fut, + Fut: std::future::Future>, +{ + let RetryOptions { + retries, + fixed_backoff, + exponential_backoff, + max_delay, + } = options; + + let mut attempt = 0; + let mut delay = exponential_backoff.unwrap_or_default(); + let max_delay = max_delay.unwrap_or_else(|| Duration::from_secs(60)); // Default max delay of 60 seconds + + loop { + attempt += 1; + + match operation().await { + Ok(result) => return Ok(result), + Err(_err) if attempt <= retries => { + if let Some(fixed) = fixed_backoff { + sleep(fixed).await; + } else if delay > Duration::ZERO { + let next_delay = delay.min(max_delay); + sleep(next_delay).await; + delay = (delay * 2).min(max_delay); + } + } + Err(_err) => return Err(_err), + } + } +} diff --git a/crates/web-plugins/didcomm-messaging/shared/src/state.rs b/crates/web-plugins/didcomm-messaging/shared/src/state.rs index 17aa6b02..6f498b17 100644 --- a/crates/web-plugins/didcomm-messaging/shared/src/state.rs +++ b/crates/web-plugins/didcomm-messaging/shared/src/state.rs @@ -1,9 +1,10 @@ use database::Repository; use did_utils::didcore::Document; use keystore::Secrets; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use crate::{ + circuit_breaker::CircuitBreaker, repository::entity::{Connection, RoutedMessage}, utils::resolvers::{LocalDIDResolver, LocalSecretsResolver}, }; @@ -23,6 +24,8 @@ pub struct AppState { // Persistence layer pub repository: Option, + pub circuit_breaker: Arc, + // disclosed protocols `https://org.didcomm.com/{protocol-name}/{version}/{request-type}`` pub supported_protocols: Option>, } @@ -55,6 +58,7 @@ impl AppState { did_resolver, secrets_resolver, repository, + circuit_breaker: Arc::new(CircuitBreaker::new(3, Duration::from_secs(10))), supported_protocols: disclose_protocols, }) }