diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index b2315188388c..97947e4e1b3c 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -1,7 +1,7 @@ use super::{ AfterDisconnect, ConnectingState, DisconnectingState, ErrorState, EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, - TunnelStateTransition, TunnelStateWrapper, + TunnelStateTransition, }; use crate::{ firewall::FirewallPolicy, @@ -27,14 +27,6 @@ use super::connecting_state::TunnelCloseEvent; pub(crate) type TunnelEventsReceiver = Fuse)>>; -pub struct ConnectedStateBootstrap { - pub metadata: TunnelMetadata, - pub tunnel_events: TunnelEventsReceiver, - pub tunnel_parameters: TunnelParameters, - pub tunnel_close_event: TunnelCloseEvent, - pub tunnel_close_tx: oneshot::Sender<()>, -} - /// The tunnel is up and working. pub struct ConnectedState { metadata: TunnelMetadata, @@ -45,13 +37,47 @@ pub struct ConnectedState { } impl ConnectedState { - fn from(bootstrap: ConnectedStateBootstrap) -> Self { - ConnectedState { - metadata: bootstrap.metadata, - tunnel_events: bootstrap.tunnel_events, - tunnel_parameters: bootstrap.tunnel_parameters, - tunnel_close_event: bootstrap.tunnel_close_event, - tunnel_close_tx: bootstrap.tunnel_close_tx, + #[cfg_attr(target_os = "android", allow(unused_variables))] + pub fn enter( + shared_values: &mut SharedTunnelStateValues, + metadata: TunnelMetadata, + tunnel_events: TunnelEventsReceiver, + tunnel_parameters: TunnelParameters, + tunnel_close_event: TunnelCloseEvent, + tunnel_close_tx: oneshot::Sender<()>, + ) -> (Box, TunnelStateTransition) { + let connected_state = ConnectedState { + metadata, + tunnel_events, + tunnel_parameters, + tunnel_close_event, + tunnel_close_tx, + }; + + let tunnel_interface = Some(connected_state.metadata.interface.clone()); + let tunnel_endpoint = talpid_types::net::TunnelEndpoint { + tunnel_interface, + ..connected_state.tunnel_parameters.get_tunnel_endpoint() + }; + + if let Err(error) = connected_state.set_firewall_policy(shared_values) { + DisconnectingState::enter( + connected_state.tunnel_close_tx, + connected_state.tunnel_close_event, + AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), + ) + } else if let Err(error) = connected_state.set_dns(shared_values) { + log::error!("{}", error.display_chain_with_msg("Failed to set DNS")); + DisconnectingState::enter( + connected_state.tunnel_close_tx, + connected_state.tunnel_close_event, + AfterDisconnect::Block(ErrorStateCause::SetDnsError), + ) + } else { + ( + Box::new(connected_state), + TunnelStateTransition::Connected(tunnel_endpoint), + ) } } @@ -173,17 +199,14 @@ impl ConnectedState { Self::reset_routes(shared_values); EventConsequence::NewState(DisconnectingState::enter( - shared_values, - ( - self.tunnel_close_tx, - self.tunnel_close_event, - after_disconnect, - ), + self.tunnel_close_tx, + self.tunnel_close_event, + after_disconnect, )) } fn handle_commands( - self, + self: Box, command: Option, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { @@ -199,7 +222,7 @@ impl ConnectedState { if cfg!(target_os = "android") { self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } else { - SameState(self.into()) + SameState(self) } } Err(error) => self.disconnect( @@ -212,7 +235,7 @@ impl ConnectedState { Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { shared_values.allowed_endpoint = endpoint; let _ = tx.send(()); - SameState(self.into()) + SameState(self) } Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) { Ok(true) => { @@ -227,7 +250,7 @@ impl ConnectedState { #[cfg(target_os = "android")] Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), #[cfg(not(target_os = "android"))] - Ok(()) => SameState(self.into()), + Ok(()) => SameState(self), Err(error) => { log::error!("{}", error.display_chain_with_msg("Failed to set DNS")); self.disconnect( @@ -237,14 +260,14 @@ impl ConnectedState { } } } - Ok(false) => SameState(self.into()), + Ok(false) => SameState(self), Err(error_cause) => { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } }, Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; - SameState(self.into()) + SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; @@ -254,7 +277,7 @@ impl ConnectedState { AfterDisconnect::Block(ErrorStateCause::IsOffline), ) } else { - SameState(self.into()) + SameState(self) } } Some(TunnelCommand::Connect) => { @@ -269,18 +292,18 @@ impl ConnectedState { #[cfg(target_os = "android")] Some(TunnelCommand::BypassSocket(fd, done_tx)) => { shared_values.bypass_socket(fd, done_tx); - SameState(self.into()) + SameState(self) } #[cfg(windows)] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { shared_values.split_tunnel.set_paths(&paths, result_tx); - SameState(self.into()) + SameState(self) } } } fn handle_tunnel_events( - self, + self: Box, event: Option<(TunnelEvent, oneshot::Sender<()>)>, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { @@ -290,7 +313,7 @@ impl ConnectedState { Some((TunnelEvent::Down, _)) | None => { self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } - Some(_) => SameState(self.into()), + Some(_) => SameState(self), } } @@ -315,49 +338,8 @@ impl ConnectedState { } impl TunnelState for ConnectedState { - type Bootstrap = ConnectedStateBootstrap; - - #[cfg_attr(target_os = "android", allow(unused_variables))] - fn enter( - shared_values: &mut SharedTunnelStateValues, - bootstrap: Self::Bootstrap, - ) -> (TunnelStateWrapper, TunnelStateTransition) { - let connected_state = ConnectedState::from(bootstrap); - let tunnel_interface = Some(connected_state.metadata.interface.clone()); - let tunnel_endpoint = talpid_types::net::TunnelEndpoint { - tunnel_interface, - ..connected_state.tunnel_parameters.get_tunnel_endpoint() - }; - - if let Err(error) = connected_state.set_firewall_policy(shared_values) { - DisconnectingState::enter( - shared_values, - ( - connected_state.tunnel_close_tx, - connected_state.tunnel_close_event, - AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), - ), - ) - } else if let Err(error) = connected_state.set_dns(shared_values) { - log::error!("{}", error.display_chain_with_msg("Failed to set DNS")); - DisconnectingState::enter( - shared_values, - ( - connected_state.tunnel_close_tx, - connected_state.tunnel_close_event, - AfterDisconnect::Block(ErrorStateCause::SetDnsError), - ), - ) - } else { - ( - TunnelStateWrapper::from(connected_state), - TunnelStateTransition::Connected(tunnel_endpoint), - ) - } - } - fn handle_event( - mut self, + mut self: Box, runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 28731d617d6d..e9b652dd0cc6 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -1,7 +1,7 @@ use super::{ - AfterDisconnect, ConnectedState, ConnectedStateBootstrap, DisconnectingState, ErrorState, - EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, - TunnelState, TunnelStateTransition, TunnelStateWrapper, + AfterDisconnect, ConnectedState, DisconnectingState, ErrorState, EventConsequence, EventResult, + SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, TunnelState, + TunnelStateTransition, }; use crate::{ firewall::FirewallPolicy, @@ -53,6 +53,84 @@ pub struct ConnectingState { } impl ConnectingState { + pub fn enter( + shared_values: &mut SharedTunnelStateValues, + retry_attempt: u32, + ) -> (Box, TunnelStateTransition) { + if shared_values.is_offline { + // FIXME: Temporary: Nudge route manager to update the default interface + #[cfg(target_os = "macos")] + if let Ok(handle) = shared_values.route_manager.handle() { + log::debug!("Poking route manager to update default routes"); + let _ = handle.refresh_routes(); + } + return ErrorState::enter(shared_values, ErrorStateCause::IsOffline); + } + match shared_values.runtime.block_on( + shared_values + .tunnel_parameters_generator + .generate(retry_attempt), + ) { + Err(err) => { + ErrorState::enter(shared_values, ErrorStateCause::TunnelParameterError(err)) + } + Ok(tunnel_parameters) => { + #[cfg(windows)] + if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to reset addresses in split tunnel driver" + ) + ); + + return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError); + } + + if let Err(error) = Self::set_firewall_policy( + shared_values, + &tunnel_parameters, + &None, + AllowedTunnelTraffic::None, + ) { + ErrorState::enter( + shared_values, + ErrorStateCause::SetFirewallPolicyError(error), + ) + } else { + #[cfg(target_os = "android")] + { + if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 { + if let Err(error) = + { shared_values.tun_provider.lock().unwrap().create_tun() } + { + log::error!( + "{}", + error.display_chain_with_msg("Failed to recreate tun device") + ); + } + } + } + + let connecting_state = Self::start_tunnel( + shared_values.runtime.clone(), + tunnel_parameters, + &shared_values.log_dir, + &shared_values.resource_dir, + shared_values.tun_provider.clone(), + &shared_values.route_manager, + retry_attempt, + ); + let params = connecting_state.tunnel_parameters.clone(); + ( + Box::new(connecting_state), + TunnelStateTransition::Connecting(params.get_tunnel_endpoint()), + ) + } + } + } + } + fn set_firewall_policy( shared_values: &mut SharedTunnelStateValues, params: &TunnelParameters, @@ -249,16 +327,6 @@ impl ConnectingState { } } - fn into_connected_state_bootstrap(self, metadata: TunnelMetadata) -> ConnectedStateBootstrap { - ConnectedStateBootstrap { - metadata, - tunnel_events: self.tunnel_events, - tunnel_parameters: self.tunnel_parameters, - tunnel_close_event: self.tunnel_close_event, - tunnel_close_tx: self.tunnel_close_tx, - } - } - fn reset_routes( #[cfg(target_os = "windows")] shared_values: &SharedTunnelStateValues, #[cfg(not(target_os = "windows"))] shared_values: &mut SharedTunnelStateValues, @@ -286,16 +354,16 @@ impl ConnectingState { Self::reset_routes(shared_values); EventConsequence::NewState(DisconnectingState::enter( - shared_values, - ( - self.tunnel_close_tx, - self.tunnel_close_event, - after_disconnect, - ), + self.tunnel_close_tx, + self.tunnel_close_event, + after_disconnect, )) } - fn reset_firewall(self, shared_values: &mut SharedTunnelStateValues) -> EventConsequence { + fn reset_firewall( + self: Box, + shared_values: &mut SharedTunnelStateValues, + ) -> EventConsequence { match Self::set_firewall_policy( shared_values, &self.tunnel_parameters, @@ -306,7 +374,7 @@ impl ConnectingState { if cfg!(target_os = "android") { self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } else { - EventConsequence::SameState(self.into()) + EventConsequence::SameState(self) } } Err(error) => self.disconnect( @@ -317,7 +385,7 @@ impl ConnectingState { } fn handle_commands( - self, + self: Box, command: Option, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { @@ -348,17 +416,17 @@ impl ConnectingState { } } let _ = tx.send(()); - SameState(self.into()) + SameState(self) } Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) { #[cfg(target_os = "android")] Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - Ok(_) => SameState(self.into()), + Ok(_) => SameState(self), Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)), }, Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; - SameState(self.into()) + SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; @@ -368,7 +436,7 @@ impl ConnectingState { AfterDisconnect::Block(ErrorStateCause::IsOffline), ) } else { - SameState(self.into()) + SameState(self) } } Some(TunnelCommand::Connect) => { @@ -383,18 +451,18 @@ impl ConnectingState { #[cfg(target_os = "android")] Some(TunnelCommand::BypassSocket(fd, done_tx)) => { shared_values.bypass_socket(fd, done_tx); - SameState(self.into()) + SameState(self) } #[cfg(windows)] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { shared_values.split_tunnel.set_paths(&paths, result_tx); - SameState(self.into()) + SameState(self) } } } fn handle_tunnel_events( - mut self, + mut self: Box, event: Option<(tunnel::TunnelEvent, oneshot::Sender<()>)>, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { @@ -432,7 +500,7 @@ impl ConnectingState { &self.tunnel_metadata, self.allowed_tunnel_traffic.clone(), ) { - Ok(()) => SameState(self.into()), + Ok(()) => SameState(self), Err(error) => self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), @@ -441,7 +509,11 @@ impl ConnectingState { } Some((TunnelEvent::Up(metadata), _)) => NewState(ConnectedState::enter( shared_values, - self.into_connected_state_bootstrap(metadata), + metadata, + self.tunnel_events, + self.tunnel_parameters, + self.tunnel_close_event, + self.tunnel_close_tx, )), Some((TunnelEvent::Down, _)) => { // It is important to reset this before the tunnel device is down, @@ -450,7 +522,7 @@ impl ConnectingState { self.allowed_tunnel_traffic = INITIAL_ALLOWED_TUNNEL_TRAFFIC; self.tunnel_metadata = None; - SameState(self.into()) + SameState(self) } None => { // The channel was closed @@ -532,88 +604,8 @@ fn is_recoverable_routing_error(error: &talpid_routing::Error) -> bool { } impl TunnelState for ConnectingState { - type Bootstrap = u32; - - fn enter( - shared_values: &mut SharedTunnelStateValues, - retry_attempt: u32, - ) -> (TunnelStateWrapper, TunnelStateTransition) { - if shared_values.is_offline { - // FIXME: Temporary: Nudge route manager to update the default interface - #[cfg(target_os = "macos")] - if let Ok(handle) = shared_values.route_manager.handle() { - log::debug!("Poking route manager to update default routes"); - let _ = handle.refresh_routes(); - } - return ErrorState::enter(shared_values, ErrorStateCause::IsOffline); - } - match shared_values.runtime.block_on( - shared_values - .tunnel_parameters_generator - .generate(retry_attempt), - ) { - Err(err) => { - ErrorState::enter(shared_values, ErrorStateCause::TunnelParameterError(err)) - } - Ok(tunnel_parameters) => { - #[cfg(windows)] - if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to reset addresses in split tunnel driver" - ) - ); - - return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError); - } - - if let Err(error) = Self::set_firewall_policy( - shared_values, - &tunnel_parameters, - &None, - AllowedTunnelTraffic::None, - ) { - ErrorState::enter( - shared_values, - ErrorStateCause::SetFirewallPolicyError(error), - ) - } else { - #[cfg(target_os = "android")] - { - if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 { - if let Err(error) = - { shared_values.tun_provider.lock().unwrap().create_tun() } - { - log::error!( - "{}", - error.display_chain_with_msg("Failed to recreate tun device") - ); - } - } - } - - let connecting_state = Self::start_tunnel( - shared_values.runtime.clone(), - tunnel_parameters, - &shared_values.log_dir, - &shared_values.resource_dir, - shared_values.tun_provider.clone(), - &shared_values.route_manager, - retry_attempt, - ); - let params = connecting_state.tunnel_parameters.clone(); - ( - TunnelStateWrapper::from(connecting_state), - TunnelStateTransition::Connecting(params.get_tunnel_endpoint()), - ) - } - } - } - } - fn handle_event( - mut self, + mut self: Box, runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 92a05862451b..1dab1e5f923c 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -1,6 +1,6 @@ use super::{ ConnectingState, ErrorState, EventConsequence, SharedTunnelStateValues, TunnelCommand, - TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, + TunnelCommandReceiver, TunnelState, TunnelStateTransition, }; #[cfg(target_os = "macos")] use crate::dns; @@ -13,9 +13,42 @@ use talpid_types::tunnel::ErrorStateCause; use talpid_types::ErrorExt; /// No tunnel is running. -pub struct DisconnectedState; +pub struct DisconnectedState(()); impl DisconnectedState { + pub fn enter( + shared_values: &mut SharedTunnelStateValues, + should_reset_firewall: bool, + ) -> (Box, TunnelStateTransition) { + #[cfg(target_os = "macos")] + if shared_values.block_when_disconnected { + if let Err(err) = Self::setup_local_dns_config(shared_values) { + log::error!( + "{}", + err.display_chain_with_msg("Failed to start filtering resolver:") + ); + } + } else if let Err(error) = shared_values.dns_monitor.reset() { + log::error!( + "{}", + error.display_chain_with_msg("Unable to disable filtering resolver") + ); + } + + #[cfg(windows)] + Self::register_split_tunnel_addresses(shared_values, should_reset_firewall); + Self::set_firewall_policy(shared_values, should_reset_firewall); + #[cfg(target_os = "linux")] + shared_values.reset_connectivity_check(); + #[cfg(target_os = "android")] + shared_values.tun_provider.lock().unwrap().close_tun(); + + ( + Box::new(DisconnectedState(())), + TunnelStateTransition::Disconnected, + ) + } + fn set_firewall_policy( shared_values: &mut SharedTunnelStateValues, should_reset_firewall: bool, @@ -86,43 +119,8 @@ impl DisconnectedState { } impl TunnelState for DisconnectedState { - type Bootstrap = bool; - - fn enter( - shared_values: &mut SharedTunnelStateValues, - should_reset_firewall: Self::Bootstrap, - ) -> (TunnelStateWrapper, TunnelStateTransition) { - #[cfg(target_os = "macos")] - if shared_values.block_when_disconnected { - if let Err(err) = Self::setup_local_dns_config(shared_values) { - log::error!( - "{}", - err.display_chain_with_msg("Failed to start filtering resolver:") - ); - } - } else if let Err(error) = shared_values.dns_monitor.reset() { - log::error!( - "{}", - error.display_chain_with_msg("Unable to disable filtering resolver") - ); - } - - #[cfg(windows)] - Self::register_split_tunnel_addresses(shared_values, should_reset_firewall); - Self::set_firewall_policy(shared_values, should_reset_firewall); - #[cfg(target_os = "linux")] - shared_values.reset_connectivity_check(); - #[cfg(target_os = "android")] - shared_values.tun_provider.lock().unwrap().close_tun(); - - ( - TunnelStateWrapper::from(DisconnectedState), - TunnelStateTransition::Disconnected, - ) - } - fn handle_event( - self, + self: Box, runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, @@ -140,7 +138,7 @@ impl TunnelState for DisconnectedState { Self::set_firewall_policy(shared_values, false); } - SameState(self.into()) + SameState(self) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { if shared_values.allowed_endpoint != endpoint { @@ -148,7 +146,7 @@ impl TunnelState for DisconnectedState { Self::set_firewall_policy(shared_values, false); } let _ = tx.send(()); - SameState(self.into()) + SameState(self) } Some(TunnelCommand::Dns(servers)) => { // Same situation as allow LAN above. @@ -156,7 +154,7 @@ impl TunnelState for DisconnectedState { .set_dns_servers(servers) .expect("Failed to reconnect after changing custom DNS servers"); - SameState(self.into()) + SameState(self) } Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { if shared_values.block_when_disconnected != block_when_disconnected { @@ -180,11 +178,11 @@ impl TunnelState for DisconnectedState { Self::reset_dns(shared_values); } } - SameState(self.into()) + SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; - SameState(self.into()) + SameState(self) } Some(TunnelCommand::Connect) => NewState(ConnectingState::enter(shared_values, 0)), Some(TunnelCommand::Block(reason)) => { @@ -194,18 +192,18 @@ impl TunnelState for DisconnectedState { #[cfg(target_os = "android")] Some(TunnelCommand::BypassSocket(fd, done_tx)) => { shared_values.bypass_socket(fd, done_tx); - SameState(self.into()) + SameState(self) } #[cfg(windows)] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { shared_values.split_tunnel.set_paths(&paths, result_tx); - SameState(self.into()) + SameState(self) } None => { Self::reset_dns(shared_values); Finished } - Some(_) => SameState(self.into()), + Some(_) => SameState(self), } } } diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 954de76034c9..85f70c5842f5 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -1,7 +1,7 @@ use super::{ connecting_state::TunnelCloseEvent, ConnectingState, DisconnectedState, ErrorState, EventConsequence, EventResult, SharedTunnelStateValues, TunnelCommand, TunnelCommandReceiver, - TunnelState, TunnelStateTransition, TunnelStateWrapper, + TunnelState, TunnelStateTransition, }; use futures::{channel::oneshot, future::FusedFuture, StreamExt}; use talpid_types::tunnel::{ActionAfterDisconnect, ErrorStateCause}; @@ -14,8 +14,25 @@ pub struct DisconnectingState { } impl DisconnectingState { + pub fn enter( + tunnel_close_tx: oneshot::Sender<()>, + tunnel_close_event: TunnelCloseEvent, + after_disconnect: AfterDisconnect, + ) -> (Box, TunnelStateTransition) { + let _ = tunnel_close_tx.send(()); + let action_after_disconnect = after_disconnect.action(); + + ( + Box::new(DisconnectingState { + tunnel_close_event, + after_disconnect, + }), + TunnelStateTransition::Disconnecting(action_after_disconnect), + ) + } + fn handle_commands( - mut self, + mut self: Box, command: Option, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence { @@ -141,14 +158,14 @@ impl DisconnectingState { }, }; - EventConsequence::SameState(self.into()) + EventConsequence::SameState(self) } fn after_disconnect( self, block_reason: Option, shared_values: &mut SharedTunnelStateValues, - ) -> (TunnelStateWrapper, TunnelStateTransition) { + ) -> (Box, TunnelStateTransition) { if let Some(reason) = block_reason { return ErrorState::enter(shared_values, reason); } @@ -164,26 +181,8 @@ impl DisconnectingState { } impl TunnelState for DisconnectingState { - type Bootstrap = (oneshot::Sender<()>, TunnelCloseEvent, AfterDisconnect); - - fn enter( - _: &mut SharedTunnelStateValues, - (tunnel_close_tx, tunnel_close_event, after_disconnect): Self::Bootstrap, - ) -> (TunnelStateWrapper, TunnelStateTransition) { - let _ = tunnel_close_tx.send(()); - let action_after_disconnect = after_disconnect.action(); - - ( - TunnelStateWrapper::from(DisconnectingState { - tunnel_close_event, - after_disconnect, - }), - TunnelStateTransition::Disconnecting(action_after_disconnect), - ) - } - fn handle_event( - mut self, + mut self: Box, runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 7fe95c9f673e..30b933653df5 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -1,6 +1,6 @@ use super::{ ConnectingState, DisconnectedState, EventConsequence, SharedTunnelStateValues, TunnelCommand, - TunnelCommandReceiver, TunnelState, TunnelStateTransition, TunnelStateWrapper, + TunnelCommandReceiver, TunnelState, TunnelStateTransition, }; use crate::firewall::FirewallPolicy; use futures::StreamExt; @@ -17,6 +17,56 @@ pub struct ErrorState { } impl ErrorState { + pub fn enter( + shared_values: &mut SharedTunnelStateValues, + block_reason: ErrorStateCause, + ) -> (Box, TunnelStateTransition) { + #[cfg(windows)] + if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to register addresses with split tunnel driver" + ) + ); + } + + #[cfg(target_os = "macos")] + if !block_reason.prevents_filtering_resolver() { + if let Err(err) = shared_values + .dns_monitor + .set("lo", &[Ipv4Addr::LOCALHOST.into()]) + { + log::error!( + "{}", + err.display_chain_with_msg( + "Failed to configure system to use filtering resolver" + ) + ); + return Self::enter(shared_values, ErrorStateCause::SetDnsError); + } + }; + + #[cfg(not(target_os = "android"))] + let block_failure = Self::set_firewall_policy(shared_values).err(); + + #[cfg(target_os = "android")] + let block_failure = if !Self::create_blocking_tun(shared_values) { + Some(FirewallPolicyError::Generic) + } else { + None + }; + ( + Box::new(ErrorState { + block_reason: block_reason.clone(), + }), + TunnelStateTransition::Error(talpid_tunnel::ErrorState::new( + block_reason, + block_failure, + )), + ) + } + fn set_firewall_policy( shared_values: &mut SharedTunnelStateValues, ) -> Result<(), FirewallPolicyError> { @@ -78,61 +128,9 @@ impl ErrorState { } impl TunnelState for ErrorState { - type Bootstrap = ErrorStateCause; - - fn enter( - shared_values: &mut SharedTunnelStateValues, - block_reason: Self::Bootstrap, - ) -> (TunnelStateWrapper, TunnelStateTransition) { - #[cfg(windows)] - if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to register addresses with split tunnel driver" - ) - ); - } - - #[cfg(target_os = "macos")] - if !block_reason.prevents_filtering_resolver() { - if let Err(err) = shared_values - .dns_monitor - .set("lo", &[Ipv4Addr::LOCALHOST.into()]) - { - log::error!( - "{}", - err.display_chain_with_msg( - "Failed to configure system to use filtering resolver" - ) - ); - return Self::enter(shared_values, ErrorStateCause::SetDnsError); - } - }; - - #[cfg(not(target_os = "android"))] - let block_failure = Self::set_firewall_policy(shared_values).err(); - - #[cfg(target_os = "android")] - let block_failure = if !Self::create_blocking_tun(shared_values) { - Some(FirewallPolicyError::Generic) - } else { - None - }; - ( - TunnelStateWrapper::from(ErrorState { - block_reason: block_reason.clone(), - }), - TunnelStateTransition::Error(talpid_tunnel::ErrorState::new( - block_reason, - block_failure, - )), - ) - } - #[cfg_attr(not(target_os = "macos"), allow(unused_mut))] fn handle_event( - self, + self: Box, runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, @@ -145,7 +143,7 @@ impl TunnelState for ErrorState { NewState(Self::enter(shared_values, error_state_cause)) } else { let _ = Self::set_firewall_policy(shared_values); - SameState(self.into()) + SameState(self) } } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -163,18 +161,18 @@ impl TunnelState for ErrorState { } } let _ = tx.send(()); - SameState(self.into()) + SameState(self) } Some(TunnelCommand::Dns(servers)) => { if let Err(error_state_cause) = shared_values.set_dns_servers(servers) { NewState(Self::enter(shared_values, error_state_cause)) } else { - SameState(self.into()) + SameState(self) } } Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { shared_values.block_when_disconnected = block_when_disconnected; - SameState(self.into()) + SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { shared_values.is_offline = is_offline; @@ -182,7 +180,7 @@ impl TunnelState for ErrorState { Self::reset_dns(shared_values); NewState(ConnectingState::enter(shared_values, 0)) } else { - SameState(self.into()) + SameState(self) } } Some(TunnelCommand::Connect) => { @@ -202,12 +200,12 @@ impl TunnelState for ErrorState { #[cfg(target_os = "android")] Some(TunnelCommand::BypassSocket(fd, done_tx)) => { shared_values.bypass_socket(fd, done_tx); - SameState(self.into()) + SameState(self) } #[cfg(windows)] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { shared_values.split_tunnel.set_paths(&paths, result_tx); - SameState(self.into()) + SameState(self) } } } diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index f4c58b849c96..12bc4cfc86fc 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -5,7 +5,7 @@ mod disconnecting_state; mod error_state; use self::{ - connected_state::{ConnectedState, ConnectedStateBootstrap}, + connected_state::ConnectedState, connecting_state::ConnectingState, disconnected_state::DisconnectedState, disconnecting_state::{AfterDisconnect, DisconnectingState}, @@ -232,7 +232,7 @@ enum EventResult { /// to. Every time it successfully advances the state machine a `TunnelStateTransition` is emitted /// by the stream. struct TunnelStateMachine { - current_state: Option, + current_state: Option>, commands: TunnelCommandReceiver, shared_values: SharedTunnelStateValues, } @@ -389,9 +389,8 @@ impl TunnelStateMachine { let runtime = self.shared_values.runtime.clone(); - while let Some(state_wrapper) = self.current_state.take() { - match state_wrapper.handle_event(&runtime, &mut self.commands, &mut self.shared_values) - { + while let Some(state) = self.current_state.take() { + match state.handle_event(&runtime, &mut self.commands, &mut self.shared_values) { NewState((state, transition)) => { self.current_state = Some(state); @@ -557,28 +556,16 @@ impl SharedTunnelStateValues { /// Asynchronous result of an attempt to progress a state. enum EventConsequence { /// Transition to a new state. - NewState((TunnelStateWrapper, TunnelStateTransition)), + NewState((Box, TunnelStateTransition)), /// An event was received, but it was ignored by the state so no transition is performed. - SameState(TunnelStateWrapper), + SameState(Box), /// The state machine has finished its execution. Finished, } /// Trait that contains the method all states should implement to handle an event and advance the /// state machine. -trait TunnelState: Into + Sized { - /// Type representing extra information required for entering the state. - type Bootstrap; - - /// Constructor function. - /// - /// This is the state entry point. It attempts to enter the state, and may fail by entering an - /// error or fallback state instead. - fn enter( - shared_values: &mut SharedTunnelStateValues, - bootstrap: Self::Bootstrap, - ) -> (TunnelStateWrapper, TunnelStateTransition); - +trait TunnelState: Send { /// Main state function. /// /// This is state exit point. It consumes itself and returns the next state to advance to when @@ -590,56 +577,13 @@ trait TunnelState: Into + Sized { /// /// [`EventConsequence`]: enum.EventConsequence.html fn handle_event( - self, + self: Box, runtime: &tokio::runtime::Handle, commands: &mut TunnelCommandReceiver, shared_values: &mut SharedTunnelStateValues, ) -> EventConsequence; } -macro_rules! state_wrapper { - (enum $wrapper_name:ident { $($state_variant:ident($state_type:ident)),* $(,)* }) => { - /// Valid states of the tunnel. - /// - /// All implementations must implement `TunnelState` so that they can handle events and - /// commands in order to advance the state machine. - enum $wrapper_name { - $($state_variant($state_type),)* - } - - $(impl From<$state_type> for $wrapper_name { - fn from(state: $state_type) -> Self { - $wrapper_name::$state_variant(state) - } - })* - - impl $wrapper_name { - fn handle_event( - self, - runtime: &tokio::runtime::Handle, - commands: &mut TunnelCommandReceiver, - shared_values: &mut SharedTunnelStateValues, - ) -> EventConsequence { - match self { - $($wrapper_name::$state_variant(state) => { - state.handle_event(runtime, commands, shared_values) - })* - } - } - } - } -} - -state_wrapper! { - enum TunnelStateWrapper { - Disconnected(DisconnectedState), - Connecting(ConnectingState), - Connected(ConnectedState), - Disconnecting(DisconnectingState), - Error(ErrorState), - } -} - /// Handle used to control the tunnel state machine. pub struct TunnelStateMachineHandle { command_tx: Arc>,