From 58efb1004f8ca762fe6aa95f541c6dc329ed790e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Em=C4=ABls?= Date: Fri, 13 Dec 2024 11:42:22 +0100 Subject: [PATCH 1/3] Add mockito --- Cargo.lock | 53 +++++++++++++++++++++++++++++++++++++++++- mullvad-api/Cargo.toml | 6 +++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index c85c6b6f08e7..599621943cda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,6 +172,16 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -592,6 +602,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "colored" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys 0.59.0", +] + [[package]] name = "combine" version = "4.6.7" @@ -1067,7 +1087,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9f0c14694cbd524c8720dd69b0e3179344f04ebb5f90f2e4a440c6ea3b2f1ee" dependencies = [ - "colored", + "colored 1.9.4", "log", ] @@ -2321,6 +2341,30 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "mockito" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "652cd6d169a36eaf9d1e6bce1a221130439a966d7f27858af66a33a66e9c4ee2" +dependencies = [ + "assert-json-diff", + "bytes", + "colored 2.2.0", + "futures-util", + "http 1.1.0", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "log", + "rand 0.8.5", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "mullvad-api" version = "0.0.0" @@ -2336,6 +2380,7 @@ dependencies = [ "ipnetwork", "libc", "log", + "mockito", "mullvad-encrypted-dns-proxy", "mullvad-fs", "mullvad-types", @@ -4114,6 +4159,12 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "similar" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" + [[package]] name = "simple-signal" version = "1.1.1" diff --git a/mullvad-api/Cargo.toml b/mullvad-api/Cargo.toml index a822593600e0..fc9d7d899b37 100644 --- a/mullvad-api/Cargo.toml +++ b/mullvad-api/Cargo.toml @@ -45,6 +45,7 @@ shadowsocks = { workspace = true, features = [ "stream-cipher" ] } [dev-dependencies] talpid-time = { path = "../talpid-time", features = ["test"] } tokio = { workspace = true, features = ["test-util", "time"] } +mockito = "1.6.1" [build-dependencies] cbindgen = { version = "0.24.3", default-features = false } @@ -55,3 +56,8 @@ uuid = { version = "1.4.1", features = ["v4"] } [lib] crate-type = [ "rlib", "staticlib" ] bench = false + +[[test]] +name = "ffi" +# required-features = [ "api-override" ] +features = [ "api-override" ] From 3093408a057020fcd912976b892fbb6bc26e6293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Em=C4=ABls?= Date: Mon, 9 Dec 2024 10:59:51 +0100 Subject: [PATCH 2/3] Remove global API endpoint --- .../dataproxy/MullvadProblemReport.kt | 20 ++- .../net/mullvad/mullvadvpn/di/UiModule.kt | 3 +- .../lib/endpoint/ApiEndpointOverride.kt | 1 - .../mullvadvpn/test/mockapi/MockApiTest.kt | 7 +- ios/MullvadVPNUITests/MullvadApi.swift | 3 +- mullvad-api/Cargo.toml | 9 +- mullvad-api/include/mullvad-api.h | 11 +- mullvad-api/src/address_cache.rs | 29 ++-- mullvad-api/src/bin/relay_list.rs | 10 +- mullvad-api/src/ffi/error.rs | 8 + mullvad-api/src/ffi/mod.rs | 106 +++++++++--- mullvad-api/src/https_client_with_sni.rs | 28 +++- mullvad-api/src/lib.rs | 152 ++++++++---------- mullvad-api/src/rest.rs | 4 +- mullvad-daemon/src/api.rs | 13 +- mullvad-daemon/src/api_address_updater.rs | 10 +- mullvad-daemon/src/lib.rs | 7 + mullvad-daemon/src/main.rs | 2 + mullvad-jni/src/api.rs | 13 +- mullvad-jni/src/lib.rs | 13 +- mullvad-jni/src/problem_report.rs | 5 + mullvad-problem-report/src/lib.rs | 6 +- mullvad-problem-report/src/main.rs | 15 +- mullvad-setup/src/main.rs | 9 +- test/test-manager/src/tests/account.rs | 15 +- 25 files changed, 311 insertions(+), 188 deletions(-) diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt index f4a0777e3f0c..3b4a460fea75 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/dataproxy/MullvadProblemReport.kt @@ -5,6 +5,9 @@ import java.io.File import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointFromIntentHolder +import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride +import net.mullvad.mullvadvpn.service.BuildConfig const val PROBLEM_REPORT_LOGS_FILE = "problem_report.txt" @@ -21,7 +24,12 @@ sealed interface SendProblemReportResult { data class UserReport(val email: String?, val description: String) -class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher = Dispatchers.IO) { +class MullvadProblemReport( + context: Context, + private val apiEndpointOverride: ApiEndpointOverride?, + private val apiEndpointFromIntentHolder: ApiEndpointFromIntentHolder, + val dispatcher: CoroutineDispatcher = Dispatchers.IO, +) { private val cacheDirectory = File(context.cacheDir.toURI()) private val logDirectory = File(context.filesDir.toURI()) @@ -47,11 +55,20 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher val sentSuccessfully = withContext(dispatcher) { + val intentApiOverride = apiEndpointFromIntentHolder.apiEndpointOverride + val apiOverride = + if (BuildConfig.DEBUG && intentApiOverride != null) { + intentApiOverride + } else { + apiEndpointOverride + } + sendProblemReport( userReport.email ?: "", userReport.description, logsPath.absolutePath, cacheDirectory.absolutePath, + apiOverride, ) } @@ -89,5 +106,6 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher userMessage: String, reportPath: String, cacheDirectory: String, + apiEndpointOverride: ApiEndpointOverride?, ): Boolean } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt index 4acf52c7b05b..bc236cc7928b 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt @@ -33,6 +33,7 @@ import net.mullvad.mullvadvpn.repository.UserPreferencesMigration import net.mullvad.mullvadvpn.repository.UserPreferencesRepository import net.mullvad.mullvadvpn.repository.UserPreferencesSerializer import net.mullvad.mullvadvpn.repository.WireguardConstraintsRepository +import net.mullvad.mullvadvpn.service.DaemonConfig import net.mullvad.mullvadvpn.ui.MainActivity import net.mullvad.mullvadvpn.ui.serviceconnection.AppVersionInfoRepository import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager @@ -129,7 +130,7 @@ val uiModule = module { single { ChangelogRepository(get()) } single { UserPreferencesRepository(get()) } single { SettingsRepository(get()) } - single { MullvadProblemReport(get()) } + single { MullvadProblemReport(get(), get().apiEndpointOverride, get()) } single { RelayOverridesRepository(get()) } single { CustomListsRepository(get()) } single { RelayListRepository(get(), get()) } diff --git a/android/lib/endpoint/src/main/kotlin/net/mullvad/mullvadvpn/lib/endpoint/ApiEndpointOverride.kt b/android/lib/endpoint/src/main/kotlin/net/mullvad/mullvadvpn/lib/endpoint/ApiEndpointOverride.kt index 5201e8638678..7350aa0d2634 100644 --- a/android/lib/endpoint/src/main/kotlin/net/mullvad/mullvadvpn/lib/endpoint/ApiEndpointOverride.kt +++ b/android/lib/endpoint/src/main/kotlin/net/mullvad/mullvadvpn/lib/endpoint/ApiEndpointOverride.kt @@ -7,7 +7,6 @@ import kotlinx.parcelize.Parcelize data class ApiEndpointOverride( val hostname: String, val port: Int = CUSTOM_ENDPOINT_HTTPS_PORT, - val disableAddressCache: Boolean = true, val disableTls: Boolean = false, val forceDirectConnection: Boolean = true, ) : Parcelable { diff --git a/android/test/mockapi/src/main/kotlin/net/mullvad/mullvadvpn/test/mockapi/MockApiTest.kt b/android/test/mockapi/src/main/kotlin/net/mullvad/mullvadvpn/test/mockapi/MockApiTest.kt index d3590fc056e3..c7cb8f0377de 100644 --- a/android/test/mockapi/src/main/kotlin/net/mullvad/mullvadvpn/test/mockapi/MockApiTest.kt +++ b/android/test/mockapi/src/main/kotlin/net/mullvad/mullvadvpn/test/mockapi/MockApiTest.kt @@ -55,11 +55,6 @@ abstract class MockApiTest { } private fun createEndpoint(port: Int): ApiEndpointOverride { - return ApiEndpointOverride( - InetAddress.getLocalHost().hostName, - port, - disableAddressCache = true, - disableTls = true, - ) + return ApiEndpointOverride(InetAddress.getLocalHost().hostName, port, disableTls = true) } } diff --git a/ios/MullvadVPNUITests/MullvadApi.swift b/ios/MullvadVPNUITests/MullvadApi.swift index 6f84ac1976b6..18755b641ba2 100644 --- a/ios/MullvadVPNUITests/MullvadApi.swift +++ b/ios/MullvadVPNUITests/MullvadApi.swift @@ -56,7 +56,8 @@ class MullvadApi { let result = mullvad_api_client_initialize( &clientContext, apiAddress, - hostname + hostname, + false ) try ApiError(result).throwIfErr() } diff --git a/mullvad-api/Cargo.toml b/mullvad-api/Cargo.toml index fc9d7d899b37..005d1a950507 100644 --- a/mullvad-api/Cargo.toml +++ b/mullvad-api/Cargo.toml @@ -33,6 +33,7 @@ tokio = { workspace = true, features = ["macros", "time", "rt-multi-thread", "ne tokio-rustls = { version = "0.26.0", features = ["logging", "tls12", "ring"], default-features = false} tokio-socks = "0.5.1" rustls-pemfile = "2.1.3" +uuid = { version = "1.4.1", features = ["v4"] } mullvad-encrypted-dns-proxy = { path = "../mullvad-encrypted-dns-proxy" } mullvad-fs = { path = "../mullvad-fs" } @@ -50,14 +51,6 @@ mockito = "1.6.1" [build-dependencies] cbindgen = { version = "0.24.3", default-features = false } -[target.'cfg(target_os = "ios")'.dependencies] -uuid = { version = "1.4.1", features = ["v4"] } - [lib] crate-type = [ "rlib", "staticlib" ] bench = false - -[[test]] -name = "ffi" -# required-features = [ "api-override" ] -features = [ "api-override" ] diff --git a/mullvad-api/include/mullvad-api.h b/mullvad-api/include/mullvad-api.h index e0295b20aa52..4e5f78aef28b 100644 --- a/mullvad-api/include/mullvad-api.h +++ b/mullvad-api/include/mullvad-api.h @@ -49,15 +49,16 @@ typedef struct MullvadApiDevice { * struct. * * * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address - * representation - * ("143.32.4.32:9090"), the port is mandatory. + * representation ("143.32.4.32:9090"), the port is mandatory. * * * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be * used for TLS validation. + * * `disable_tls`: only valid when built for tests, can be ignored when consumed by Swift. */ struct MullvadApiError mullvad_api_client_initialize(struct MullvadApiClient *client_ptr, const char *api_address_ptr, - const char *hostname); + const char *hostname, + bool disable_tls); /** * Removes all devices from a given account @@ -98,8 +99,8 @@ struct MullvadApiError mullvad_api_get_expiry(struct MullvadApiClient client_ptr * * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the * account that will have all of it's devices removed. * - * * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function - * doesn't return an error, the pointer will be initialized with a valid instance of + * * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't + * return an error, the pointer will be initialized with a valid instance of * `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices. */ struct MullvadApiError mullvad_api_list_devices(struct MullvadApiClient client_ptr, diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index 0898f8da1f96..a6a60146b4fe 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -1,7 +1,6 @@ //! This module keeps track of the last known good API IP address and reads and stores it on disk. -use super::API; -use crate::DnsResolver; +use crate::{ApiEndpoint, DnsResolver}; use async_trait::async_trait; use std::{io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ @@ -38,42 +37,42 @@ impl DnsResolver for AddressCache { #[derive(Clone)] pub struct AddressCache { + hostname: String, inner: Arc>, write_path: Option>, } impl AddressCache { /// Initialize cache using the hardcoded address, and write changes to `write_path`. - pub fn new(write_path: Option>) -> Self { - Self::new_inner(API.address(), write_path) - } - - pub fn with_static_addr(address: SocketAddr) -> Self { - Self::new_inner(address, None) + pub fn new(endpoint: &ApiEndpoint, write_path: Option>) -> Self { + Self::new_inner(endpoint.address(), endpoint.host().to_owned(), write_path) } /// Initialize cache using `read_path`, and write changes to `write_path`. - pub async fn from_file(read_path: &Path, write_path: Option>) -> Result { + pub async fn from_file( + read_path: &Path, + write_path: Option>, + hostname: String, + ) -> Result { log::debug!("Loading API addresses from {}", read_path.display()); - Ok(Self::new_inner( - read_address_file(read_path).await?, - write_path, - )) + let address = read_address_file(read_path).await?; + Ok(Self::new_inner(address, hostname, write_path)) } - fn new_inner(address: SocketAddr, write_path: Option>) -> Self { + fn new_inner(address: SocketAddr, hostname: String, write_path: Option>) -> Self { let cache = AddressCacheInner::from_address(address); log::debug!("Using API address: {}", cache.address); Self { inner: Arc::new(Mutex::new(cache)), write_path: write_path.map(Arc::from), + hostname, } } /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`. async fn resolve_hostname(&self, hostname: &str) -> Option { - if hostname.eq_ignore_ascii_case(API.host()) { + if hostname.eq_ignore_ascii_case(&self.hostname) { Some(self.get_address().await) } else { None diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index def32303eaef..3ea771cc81ee 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,14 +2,18 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; +use mullvad_api::{ + proxy::ApiConnectionMode, rest::Error as RestError, ApiEndpoint, RelayListProxy, +}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) - .expect("Failed to load runtime"); + let runtime = mullvad_api::Runtime::new( + tokio::runtime::Handle::current(), + &ApiEndpoint::from_env_vars(), + ); let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider())) diff --git a/mullvad-api/src/ffi/error.rs b/mullvad-api/src/ffi/error.rs index 539a6c23a052..66ffc0122064 100644 --- a/mullvad-api/src/ffi/error.rs +++ b/mullvad-api/src/ffi/error.rs @@ -13,6 +13,7 @@ pub enum MullvadApiErrorKind { /// MullvadApiErrorKind contains a description and an error kind. If the error kind is /// `MullvadApiErrorKind` is NoError, the pointer will be nil. +#[derive(Debug)] #[repr(C)] pub struct MullvadApiError { description: *mut libc::c_char, @@ -47,6 +48,13 @@ impl MullvadApiError { } } + pub fn unwrap(&self) { + if !matches!(self.kind, MullvadApiErrorKind::NoError) { + let desc = unsafe { std::ffi::CStr::from_ptr(self.description) }; + panic!("API ERROR - {:?} - {}", self.kind, desc.to_str().unwrap()); + } + } + pub fn drop(self) { if self.description.is_null() { return; diff --git a/mullvad-api/src/ffi/mod.rs b/mullvad-api/src/ffi/mod.rs index a68ea40ed6cb..967748825777 100644 --- a/mullvad-api/src/ffi/mod.rs +++ b/mullvad-api/src/ffi/mod.rs @@ -1,3 +1,4 @@ +#![cfg(not(target_os = "android"))] use std::{ ffi::{CStr, CString}, net::SocketAddr, @@ -6,8 +7,9 @@ use std::{ }; use crate::{ + proxy::ApiConnectionMode, rest::{self, MullvadRestHandle}, - AccountsProxy, DevicesProxy, + AccountsProxy, ApiEndpoint, DevicesProxy, }; mod device; @@ -48,13 +50,13 @@ impl MullvadApiClient { struct FfiClient { tokio_runtime: tokio::runtime::Runtime, api_runtime: crate::Runtime, - api_hostname: String, } impl FfiClient { unsafe fn new( api_address_ptr: *const libc::c_char, hostname: *const libc::c_char, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> Result { // SAFETY: addr_str must be a valid pointer to a null-terminated string. let addr_str = unsafe { string_from_raw_ptr(api_address_ptr)? }; @@ -68,12 +70,15 @@ impl FfiClient { ) })?; - // The call site guarantees that - // api_hostname and api_address will never change after the first call to new. - std::env::set_var(crate::env::API_HOST_VAR, &api_hostname); - std::env::set_var(crate::env::API_ADDR_VAR, &addr_str); - std::env::set_var(crate::env::API_FORCE_DIRECT_VAR, "0"); - std::env::set_var(crate::env::DISABLE_TLS_VAR, "0"); + let endpoint = ApiEndpoint { + host: Some(api_hostname.clone()), + address: Some(api_address), + #[cfg(feature = "api-override")] + force_direct: false, + #[cfg(any(feature = "api-override", test))] + disable_tls, + }; + let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); runtime_builder.worker_threads(2).enable_all(); @@ -83,14 +88,12 @@ impl FfiClient { // It is imperative that the REST runtime is created within an async context, otherwise // ApiAvailability panics. - let api_runtime = tokio_runtime.block_on(async { - crate::Runtime::with_static_addr(tokio_runtime.handle().clone(), api_address) - }); + let api_runtime = tokio_runtime + .block_on(async { crate::Runtime::new(tokio_runtime.handle().clone(), &endpoint) }); let context = FfiClient { tokio_runtime, api_runtime, - api_hostname, }; Ok(context) @@ -204,7 +207,7 @@ impl FfiClient { fn rest_handle(&self) -> MullvadRestHandle { self.tokio_handle().block_on(async { self.api_runtime - .static_mullvad_rest_handle(self.api_hostname.clone()) + .mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()) }) } @@ -229,18 +232,31 @@ impl FfiClient { /// struct. /// /// * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address -/// representation -/// ("143.32.4.32:9090"), the port is mandatory. +/// representation ("143.32.4.32:9090"), the port is mandatory. /// /// * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be /// used for TLS validation. +/// * `disable_tls`: only valid when built for tests, can be ignored when consumed by Swift. #[no_mangle] pub unsafe extern "C" fn mullvad_api_client_initialize( client_ptr: *mut MullvadApiClient, api_address_ptr: *const libc::c_char, hostname: *const libc::c_char, + disable_tls: bool, ) -> MullvadApiError { - match unsafe { FfiClient::new(api_address_ptr, hostname) } { + #[cfg(not(any(feature = "api-override", test)))] + if disable_tls { + log::error!("disable_tls has no effect when mullvad-api is built without api-override"); + } + + match unsafe { + FfiClient::new( + api_address_ptr, + hostname, + #[cfg(any(feature = "api-override", test))] + disable_tls, + ) + } { Ok(client) => { unsafe { std::ptr::write(client_ptr, MullvadApiClient::new(client)); @@ -306,8 +322,8 @@ pub unsafe extern "C" fn mullvad_api_get_expiry( /// * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the /// account that will have all of it's devices removed. /// -/// * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function -/// doesn't return an error, the pointer will be initialized with a valid instance of +/// * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't +/// return an error, the pointer will be initialized with a valid instance of /// `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices. #[no_mangle] pub unsafe extern "C" fn mullvad_api_list_devices( @@ -443,3 +459,57 @@ unsafe fn string_from_raw_ptr(ptr: *const libc::c_char) -> Result MullvadApiClient { + let mut client = MaybeUninit::::uninit(); + let cstr_address = CString::new(addr.to_string()).unwrap(); + unsafe { + mullvad_api_client_initialize( + client.as_mut_ptr(), + cstr_address.as_ptr().cast(), + STAGING_HOSTNAME.as_ptr().cast(), + true, + ) + .unwrap(); + }; + unsafe { client.assume_init() } + } + + #[test] + fn test_create_delete_account() { + let server = test_server(); + let client = create_client(&server.socket_address()); + + let mut account_buf = vec![0 as libc::c_char; 100]; + unsafe { mullvad_api_create_account(client, account_buf.as_mut_ptr().cast()).unwrap() }; + } + + fn test_server() -> ServerGuard { + let mut server = Server::new(); + let expected_create_account_response = br#"{"id":"085df870-0fc2-47cb-9e8c-cb43c1bdaac0","expiry":"2024-12-11T12:56:32+00:00","max_ports":0,"can_add_ports":false,"max_devices":5,"can_add_devices":true,"number":"6705749539195318"}"#; + server + .mock( + "POST", + &*("/".to_string() + crate::ACCOUNTS_URL_PREFIX + "/accounts"), + ) + .with_header("content-type", "application/json") + .with_status(201) + .with_body(expected_create_account_response) + .create(); + + server + } +} diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 3dfd168a9281..f86c538a672a 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -41,8 +41,8 @@ use tokio::{ }; use tower::Service; -#[cfg(feature = "api-override")] -use crate::{proxy::ConnectionDecorator, API}; +#[cfg(any(feature = "api-override", test))] +use crate::proxy::ConnectionDecorator; const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); @@ -89,6 +89,7 @@ impl InnerConnectionMode { hostname: &str, addr: &SocketAddr, #[cfg(target_os = "android")] socket_bypass_tx: Option>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> Result { match self { // Set up a TCP-socket connection. @@ -101,6 +102,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -121,6 +124,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -153,6 +158,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -168,6 +175,8 @@ impl InnerConnectionMode { make_proxy_stream, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) .await } @@ -191,6 +200,7 @@ impl InnerConnectionMode { hostname: &str, make_proxy_stream: ProxyFactory, #[cfg(target_os = "android")] socket_bypass_tx: Option>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> Result where ProxyFactory: FnOnce(TcpStream) -> ProxyFuture, @@ -206,8 +216,8 @@ impl InnerConnectionMode { let proxy = make_proxy_stream(socket).await?; - #[cfg(feature = "api-override")] - if API.disable_tls { + #[cfg(any(feature = "api-override", test))] + if disable_tls { return Ok(ApiConnection::new(Box::new(ConnectionDecorator(proxy)))); } @@ -290,6 +300,8 @@ pub struct HttpsConnectorWithSni { dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, + #[cfg(any(feature = "api-override", test))] + disable_tls: bool, } struct HttpsConnectorWithSniInner { @@ -304,6 +316,7 @@ impl HttpsConnectorWithSni { pub fn new( dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> (Self, HttpsConnectorWithSniHandle) { let (tx, mut rx) = mpsc::unbounded(); let abort_notify = Arc::new(tokio::sync::Notify::new()); @@ -352,6 +365,8 @@ impl HttpsConnectorWithSni { dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, }, HttpsConnectorWithSniHandle { tx }, ) @@ -435,6 +450,9 @@ impl Service for HttpsConnectorWithSni { let socket_bypass_tx = self.socket_bypass_tx.clone(); let dns_resolver = self.dns_resolver.clone(); + #[cfg(any(feature = "api-override", test))] + let disable_tls = self.disable_tls; + let fut = async move { if uri.scheme() != Some(&Scheme::HTTPS) { return Err(io::Error::new( @@ -460,6 +478,8 @@ impl Service for HttpsConnectorWithSni { &addr, #[cfg(target_os = "android")] socket_bypass_tx.clone(), + #[cfg(any(feature = "api-override", test))] + disable_tls, ); pin_mut!(stream_fut); diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 3b02e4fe98ed..a47c708b2ef4 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -10,14 +10,12 @@ use mullvad_types::{ }; use proxy::{ApiConnectionMode, ConnectionModeProvider}; use std::{ - cell::Cell, collections::BTreeMap, future::Future, io, net::{IpAddr, Ipv4Addr, SocketAddr}, - ops::Deref, path::Path, - sync::{Arc, OnceLock}, + sync::Arc, }; use talpid_types::ErrorExt; @@ -37,7 +35,6 @@ mod address_cache; pub mod device; mod relay_list; -#[cfg(target_os = "ios")] pub mod ffi; pub use address_cache::AddressCache; @@ -70,41 +67,6 @@ const APP_URL_PREFIX: &str = "app/v1"; #[cfg(target_os = "android")] const GOOGLE_PAYMENTS_URL_PREFIX: &str = "payments/google-play/v1"; -pub static API: LazyManual = LazyManual::new(ApiEndpoint::from_env_vars); - -unsafe impl Sync for LazyManual where OnceLock: Sync {} - -/// A value that is either initialized on access or explicitly. -pub struct LazyManual T> { - cell: OnceLock, - lazy_fn: Cell>, -} - -impl LazyManual { - const fn new(lazy_fn: F) -> Self { - Self { - cell: OnceLock::new(), - lazy_fn: Cell::new(Some(lazy_fn)), - } - } - - /// Tries to initialize the object. An error is returned if it is - /// already initialized. - #[cfg(feature = "api-override")] - pub fn override_init(&self, val: T) -> Result<(), T> { - let _ = self.lazy_fn.take(); - self.cell.set(val) - } -} - -impl Deref for LazyManual { - type Target = T; - - fn deref(&self) -> &Self::Target { - self.cell.get_or_init(|| (self.lazy_fn.take().unwrap())()) - } -} - pub mod env { pub const API_HOST_VAR: &str = "MULLVAD_API_HOST"; pub const API_ADDR_VAR: &str = "MULLVAD_API_ADDR"; @@ -113,7 +75,7 @@ pub mod env { } /// A hostname and socketaddr to reach the Mullvad REST API over. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ApiEndpoint { /// An overriden API hostname. Initialized with the value of the environment /// variable `MULLVAD_API_HOST` if it has been set. @@ -132,9 +94,7 @@ pub struct ApiEndpoint { /// If [`Self::address`] is populated with [`Some(SocketAddr)`], it should /// always be respected when establishing API connections. pub address: Option, - #[cfg(feature = "api-override")] - pub disable_address_cache: bool, - #[cfg(feature = "api-override")] + #[cfg(any(feature = "api-override", test))] pub disable_tls: bool, #[cfg(feature = "api-override")] /// Whether bridges/proxies can be used to access the API or not. This is @@ -175,7 +135,6 @@ impl ApiEndpoint { let mut api = ApiEndpoint { host: None, address: None, - disable_address_cache: host_var.is_some() || address_var.is_some(), disable_tls: false, force_direct: force_direct .map(|force_direct| force_direct != "0") @@ -244,6 +203,11 @@ impl ApiEndpoint { api } + #[cfg(feature = "api-override")] + pub fn should_disable_address_cache(&self) -> bool { + self.host.is_some() || self.address.is_some() + } + /// Returns the endpoint to connect to the API over. /// /// # Panics @@ -269,9 +233,31 @@ impl ApiEndpoint { ApiEndpoint { host: None, address: None, + #[cfg(test)] + disable_tls: false, } } + /// Returns a new API endpoint with the given host and socket address. + pub fn new( + host: String, + address: SocketAddr, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, + ) -> Self { + Self { + host: Some(host), + address: Some(address), + #[cfg(any(feature = "api-override", test))] + disable_tls, + #[cfg(feature = "api-override")] + force_direct: false, + } + } + + pub fn set_addr(&mut self, address: SocketAddr) { + self.address = Some(address); + } + /// Read the [`Self::host`] value, falling back to /// [`Self::API_HOST_DEFAULT`] as default value if it does not exist. pub fn host(&self) -> &str { @@ -342,6 +328,7 @@ pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, + endpoint: ApiEndpoint, #[cfg(target_os = "android")] socket_bypass_tx: Option>, } @@ -362,40 +349,27 @@ pub enum Error { } impl Runtime { - /// Create a new `Runtime`. - pub fn new(handle: tokio::runtime::Handle) -> Result { - Self::new_inner( - handle, - #[cfg(target_os = "android")] - None, - ) - } - - #[cfg(target_os = "ios")] - pub fn with_static_addr(handle: tokio::runtime::Handle, address: SocketAddr) -> Self { - Runtime { - handle, - address_cache: AddressCache::with_static_addr(address), - api_availability: ApiAvailability::default(), - } - } - - fn new_inner( + /// Will create a new Runtime without a cache with the provided API endpoint. + pub fn new( handle: tokio::runtime::Handle, + endpoint: &ApiEndpoint, #[cfg(target_os = "android")] socket_bypass_tx: Option>, - ) -> Result { - Ok(Runtime { + ) -> Self { + Runtime { handle, - address_cache: AddressCache::new(None), + address_cache: AddressCache::new(endpoint, None), api_availability: ApiAvailability::default(), + endpoint: endpoint.clone(), #[cfg(target_os = "android")] socket_bypass_tx, - }) + } } /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. + /// Will try to construct an API endpoint from the environment. pub async fn with_cache( + endpoint: &ApiEndpoint, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option>, @@ -403,12 +377,13 @@ impl Runtime { let handle = tokio::runtime::Handle::current(); #[cfg(feature = "api-override")] - if API.disable_address_cache { - return Self::new_inner( + if endpoint.should_disable_address_cache() { + return Ok(Self::new( handle, + endpoint, #[cfg(target_os = "android")] socket_bypass_tx, - ); + )); } let cache_file = cache_dir.join(API_IP_CACHE_FILENAME); @@ -418,7 +393,13 @@ impl Runtime { None }; - let address_cache = match AddressCache::from_file(&cache_file, write_file.clone()).await { + let address_cache = match AddressCache::from_file( + &cache_file, + write_file.clone(), + endpoint.host().to_owned(), + ) + .await + { Ok(cache) => cache, Err(error) => { if cache_file.exists() { @@ -429,7 +410,7 @@ impl Runtime { ) ); } - AddressCache::new(write_file) + AddressCache::new(endpoint, write_file) } }; @@ -439,12 +420,14 @@ impl Runtime { handle, address_cache, api_availability, + endpoint: endpoint.clone(), #[cfg(target_os = "android")] socket_bypass_tx, }) } - /// Returns a request factory initialized to create requests for the master API + /// Returns a request factory initialized to create requests for the master API Assumes an API + /// endpoint that is constructed from env vars, or uses default values. pub fn mullvad_rest_handle( &self, connection_mode_provider: T, @@ -454,21 +437,10 @@ impl Runtime { Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), + #[cfg(any(feature = "api-override", test))] + self.endpoint.disable_tls, ); - let token_store = access::AccessTokenStore::new(service.clone(), API.host()); - let factory = rest::RequestFactory::new(API.host().to_owned(), Some(token_store)); - - rest::MullvadRestHandle::new(service, factory, self.availability_handle()) - } - - /// This is only to be used in test code - pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle { - let service = self.new_request_service( - ApiConnectionMode::Direct.into_provider(), - Arc::new(self.address_cache.clone()), - #[cfg(target_os = "android")] - self.socket_bypass_tx.clone(), - ); + let hostname = self.endpoint.host().to_owned(); let token_store = access::AccessTokenStore::new(service.clone(), hostname.clone()); let factory = rest::RequestFactory::new(hostname, Some(token_store)); @@ -482,6 +454,8 @@ impl Runtime { Arc::new(dns_resolver), #[cfg(target_os = "android")] None, + #[cfg(any(feature = "api-override", test))] + false, ) } @@ -491,6 +465,7 @@ impl Runtime { connection_mode_provider: T, dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> rest::RequestServiceHandle { rest::RequestService::spawn( self.api_availability.clone(), @@ -498,6 +473,8 @@ impl Runtime { dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, + #[cfg(any(feature = "api-override", test))] + disable_tls, ) } @@ -582,7 +559,6 @@ impl AccountsProxy { } } - #[cfg(target_os = "ios")] pub fn delete_account( &self, account: AccountNumber, diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 5b93eea31142..cab3bb7e0f17 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -154,11 +154,14 @@ impl RequestService { connection_mode_provider: T, dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, + #[cfg(any(feature = "api-override", test))] disable_tls: bool, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx.clone(), + #[cfg(any(feature = "api-override", test))] + disable_tls, ); connector_handle.set_connection_mode(connection_mode_provider.initial()); @@ -461,7 +464,6 @@ where } // Parse unexpected responses and errors - let response = response?; if !self.expected_status.contains(&response.status()) { diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index db952941ebbe..97799d244e76 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -10,6 +10,8 @@ use futures::{ channel::{mpsc, oneshot}, StreamExt, }; +#[cfg(feature = "api-override")] +use mullvad_api::ApiEndpoint; use mullvad_api::{ availability::ApiAvailability, proxy::{ApiConnectionMode, ConnectionModeProvider, ProxyConfig}, @@ -250,6 +252,8 @@ impl ConnectionModeProvider for AccessModeConnectionModeProvider { /// or via any supported custom proxy protocol /// ([`talpid_types::net::proxy::CustomProxy`]). pub struct AccessModeSelector { + #[cfg(feature = "api-override")] + api_endpoint: ApiEndpoint, cmd_rx: mpsc::UnboundedReceiver, cache_dir: PathBuf, /// Used for selecting a Bridge when the `Mullvad Bridges` access method is used. @@ -271,6 +275,7 @@ impl AccessModeSelector { relay_selector: RelaySelector, #[cfg_attr(not(feature = "api-override"), allow(unused_mut))] mut access_method_settings: Settings, + #[cfg(feature = "api-override")] api_endpoint: ApiEndpoint, access_method_event_sender: DaemonEventSender<(AccessMethodEvent, oneshot::Sender<()>)>, address_cache: AddressCache, ) -> Result<(AccessModeSelectorHandle, AccessModeConnectionModeProvider)> { @@ -278,7 +283,7 @@ impl AccessModeSelector { #[cfg(feature = "api-override")] { - if mullvad_api::API.force_direct { + if api_endpoint.force_direct { access_method_settings .update(|setting| setting.is_direct(), |setting| setting.enable()); } @@ -312,6 +317,8 @@ impl AccessModeSelector { connection_mode_provider_sender: change_tx, current: initial_connection_mode, index, + #[cfg(feature = "api-override")] + api_endpoint, }; tokio::spawn(selector.into_future()); @@ -365,7 +372,7 @@ impl AccessModeSelector { async fn use_access_method(&mut self, id: Id) { #[cfg(feature = "api-override")] { - if mullvad_api::API.force_direct { + if self.api_endpoint.force_direct { log::debug!("API proxies are disabled"); return; } @@ -392,7 +399,7 @@ impl AccessModeSelector { async fn next_connection_mode(&mut self) -> Result { #[cfg(feature = "api-override")] { - if mullvad_api::API.force_direct { + if self.api_endpoint.force_direct { log::debug!("API proxies are disabled"); return Ok(ApiConnectionMode::Direct); } diff --git a/mullvad-daemon/src/api_address_updater.rs b/mullvad-daemon/src/api_address_updater.rs index f85d6259b803..de347daccbb7 100644 --- a/mullvad-daemon/src/api_address_updater.rs +++ b/mullvad-daemon/src/api_address_updater.rs @@ -1,5 +1,7 @@ //! A small updater that keeps the API IP address cache up to date by fetching changes from the //! Mullvad API. +#[cfg(feature = "api-override")] +use mullvad_api::ApiEndpoint; use mullvad_api::{rest::MullvadRestHandle, AddressCache, ApiProxy}; use std::time::Duration; @@ -7,9 +9,13 @@ const API_IP_CHECK_INITIAL: Duration = Duration::from_secs(15 * 60); const API_IP_CHECK_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); const API_IP_CHECK_ERROR_INTERVAL: Duration = Duration::from_secs(15 * 60); -pub async fn run_api_address_fetcher(address_cache: AddressCache, handle: MullvadRestHandle) { +pub async fn run_api_address_fetcher( + address_cache: AddressCache, + handle: MullvadRestHandle, + #[cfg(feature = "api-override")] endpoint: ApiEndpoint, +) { #[cfg(feature = "api-override")] - if mullvad_api::API.disable_address_cache { + if endpoint.should_disable_address_cache() { return; } diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 4f98c73d0189..e155ba792247 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -39,6 +39,7 @@ use futures::{ }; use geoip::GeoIpHandler; use management_interface::ManagementInterfaceServer; +use mullvad_api::ApiEndpoint; use mullvad_relay_selector::{RelaySelector, SelectorConfig}; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; @@ -596,6 +597,7 @@ impl Daemon { cache_dir: PathBuf, rpc_socket_path: PathBuf, daemon_command_channel: DaemonCommandChannel, + endpoint: ApiEndpoint, #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result { #[cfg(target_os = "macos")] @@ -620,6 +622,7 @@ impl Daemon { mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; let api_runtime = mullvad_api::Runtime::with_cache( + &endpoint, &cache_dir, true, #[cfg(target_os = "android")] @@ -667,6 +670,8 @@ impl Daemon { cache_dir.clone(), relay_selector.clone(), settings.api_access_methods.clone(), + #[cfg(feature = "api-override")] + endpoint.clone(), internal_event_tx.to_specialized_sender(), api_runtime.address_cache().clone(), ) @@ -679,6 +684,8 @@ impl Daemon { tokio::spawn(api_address_updater::run_api_address_fetcher( api_runtime.address_cache().clone(), api_handle.clone(), + #[cfg(feature = "api-override")] + endpoint, )); let access_method_handle = access_mode_handler.clone(); diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs index 3569646ccfbf..98910f76f864 100644 --- a/mullvad-daemon/src/main.rs +++ b/mullvad-daemon/src/main.rs @@ -1,5 +1,6 @@ use std::{path::PathBuf, thread, time::Duration}; +use mullvad_api::ApiEndpoint; #[cfg(not(windows))] use mullvad_daemon::cleanup_old_rpc_socket; use mullvad_daemon::{ @@ -221,6 +222,7 @@ async fn create_daemon(log_dir: Option) -> Result { cache_dir, rpc_socket_path, daemon_command_channel, + ApiEndpoint::from_env_vars(), ) .await .map_err(|e| e.display_chain_with_msg("Unable to initialize daemon")) diff --git a/mullvad-jni/src/api.rs b/mullvad-jni/src/api.rs index b7dd5d674787..81043871f842 100644 --- a/mullvad-jni/src/api.rs +++ b/mullvad-jni/src/api.rs @@ -20,7 +20,6 @@ pub fn api_endpoint_from_java( Some(mullvad_api::ApiEndpoint { host: Some(hostname), address, - disable_address_cache: disable_address_cache_from_java(env, endpoint_override), disable_tls: disable_tls_from_java(env, endpoint_override), force_direct: force_direct_from_java(env, endpoint_override), }) @@ -70,17 +69,9 @@ fn port_from_java(env: &JnixEnv<'_>, endpoint_override: JObject<'_>) -> u16 { u16::try_from(port).expect("invalid port") } -#[cfg(feature = "api-override")] -fn disable_address_cache_from_java(env: &JnixEnv<'_>, endpoint_override: JObject<'_>) -> bool { - env.call_method(endpoint_override, "component3", "()Z", &[]) - .expect("missing ApiEndpointOverride.disableAddressCache") - .z() - .expect("ApiEndpointOverride.disableAddressCache is not a bool") -} - #[cfg(feature = "api-override")] fn disable_tls_from_java(env: &JnixEnv<'_>, endpoint_override: JObject<'_>) -> bool { - env.call_method(endpoint_override, "component4", "()Z", &[]) + env.call_method(endpoint_override, "component3", "()Z", &[]) .expect("missing ApiEndpointOverride.disableTls") .z() .expect("ApiEndpointOverride.disableTls is not a bool") @@ -88,7 +79,7 @@ fn disable_tls_from_java(env: &JnixEnv<'_>, endpoint_override: JObject<'_>) -> b #[cfg(feature = "api-override")] fn force_direct_from_java(env: &JnixEnv<'_>, endpoint_override: JObject<'_>) -> bool { - env.call_method(endpoint_override, "component5", "()Z", &[]) + env.call_method(endpoint_override, "component4", "()Z", &[]) .expect("missing ApiEndpointOverride.forceDirectConnection") .z() .expect("ApiEndpointOverride.forceDirectConnection is not a bool") diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 1dd7c8694263..2d2561403f00 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -12,6 +12,7 @@ use jnix::{ }, FromJava, JnixEnv, }; +use mullvad_api::ApiEndpoint; use mullvad_daemon::{ cleanup_old_rpc_socket, exception_logging, logging, runtime::new_multi_thread, version, Daemon, DaemonCommandChannel, DaemonCommandSender, @@ -150,7 +151,13 @@ fn start( log::warn!("api_endpoint will be ignored since 'api-override' is not enabled"); } - spawn_daemon(android_context, rpc_socket, files_dir, cache_dir) + spawn_daemon( + android_context, + rpc_socket, + files_dir, + cache_dir, + api_endpoint.unwrap_or(ApiEndpoint::from_env_vars()), + ) } fn spawn_daemon( @@ -158,6 +165,7 @@ fn spawn_daemon( rpc_socket: PathBuf, files_dir: PathBuf, cache_dir: PathBuf, + endpoint: ApiEndpoint, ) -> Result { let daemon_command_channel = DaemonCommandChannel::new(); let daemon_command_tx = daemon_command_channel.sender(); @@ -170,6 +178,7 @@ fn spawn_daemon( cache_dir, daemon_command_channel, android_context, + endpoint, ))?; Ok(DaemonContext { @@ -185,6 +194,7 @@ async fn spawn_daemon_inner( cache_dir: PathBuf, daemon_command_channel: DaemonCommandChannel, android_context: AndroidContext, + endpoint: ApiEndpoint, ) -> Result, Error> { cleanup_old_rpc_socket(&rpc_socket).await; @@ -195,6 +205,7 @@ async fn spawn_daemon_inner( cache_dir, rpc_socket, daemon_command_channel, + endpoint, android_context, ) .await diff --git a/mullvad-jni/src/problem_report.rs b/mullvad-jni/src/problem_report.rs index 9943ec4c59ba..dc3693cfff38 100644 --- a/mullvad-jni/src/problem_report.rs +++ b/mullvad-jni/src/problem_report.rs @@ -6,6 +6,7 @@ use jnix::{ }, FromJava, JnixEnv, }; +use mullvad_api::ApiEndpoint; use std::path::Path; use talpid_types::ErrorExt; @@ -44,6 +45,7 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_dataproxy_MullvadProblemRepor userMessage: JString<'_>, outputPath: JString<'_>, cacheDirectory: JString<'_>, + endpoint: JObject<'_>, ) -> jboolean { let env = JnixEnv::from(env); let user_email = String::from_java(&env, userEmail); @@ -52,12 +54,15 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_dataproxy_MullvadProblemRepor let output_path = Path::new(&output_path_string); let cache_directory_string = String::from_java(&env, cacheDirectory); let cache_directory = Path::new(&cache_directory_string); + let api_endpoint = + crate::api::api_endpoint_from_java(&env, endpoint).unwrap_or(ApiEndpoint::from_env_vars()); let send_result = mullvad_problem_report::send_problem_report( &user_email, &user_message, output_path, cache_directory, + api_endpoint, ); match send_result { diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 270de55f9589..b2a4e22166fc 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -1,4 +1,4 @@ -use mullvad_api::proxy::ApiConnectionMode; +use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint}; use regex::Regex; use std::{ borrow::Cow, @@ -261,6 +261,7 @@ pub fn send_problem_report( user_message: &str, report_path: &Path, cache_dir: &Path, + endpoint: ApiEndpoint, ) -> Result<(), Error> { let report_content = normalize_newlines( read_file_lossy(report_path, REPORT_MAX_SIZE).map_err(|source| { @@ -281,6 +282,7 @@ pub fn send_problem_report( user_message, &report_content, cache_dir, + &endpoint, )) } @@ -289,9 +291,11 @@ async fn send_problem_report_inner( user_message: &str, report_content: &str, cache_dir: &Path, + endpoint: &ApiEndpoint, ) -> Result<(), Error> { let metadata = ProblemReport::parse_metadata(report_content).unwrap_or_else(metadata::collect); let api_runtime = mullvad_api::Runtime::with_cache( + endpoint, cache_dir, false, #[cfg(target_os = "android")] diff --git a/mullvad-problem-report/src/main.rs b/mullvad-problem-report/src/main.rs index bc3e680cdffe..5a073098a3a6 100644 --- a/mullvad-problem-report/src/main.rs +++ b/mullvad-problem-report/src/main.rs @@ -1,4 +1,5 @@ use clap::Parser; +use mullvad_api::ApiEndpoint; use mullvad_problem_report::{collect_report, Error}; use std::{ env, @@ -89,10 +90,16 @@ fn send_problem_report( report_path: &Path, ) -> Result<(), Error> { let cache_dir = mullvad_paths::get_cache_dir().map_err(Error::ObtainCacheDirectory)?; - mullvad_problem_report::send_problem_report(user_email, user_message, report_path, &cache_dir) - .inspect_err(|error| { - eprintln!("{}", error.display_chain()); - })?; + mullvad_problem_report::send_problem_report( + user_email, + user_message, + report_path, + &cache_dir, + ApiEndpoint::from_env_vars(), + ) + .inspect_err(|error| { + eprintln!("{}", error.display_chain()); + })?; println!("Problem report sent"); Ok(()) diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index d3dfd6de8ac4..e525d5cb888d 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -1,7 +1,7 @@ use clap::Parser; use std::{path::PathBuf, process, str::FromStr, sync::LazyLock, time::Duration}; -use mullvad_api::{proxy::ApiConnectionMode, DEVICE_NOT_FOUND}; +use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint, DEVICE_NOT_FOUND}; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::version::ParsedAppVersion; use talpid_core::firewall::{self, Firewall}; @@ -152,9 +152,10 @@ async fn remove_device() -> Result<(), Error> { .await .map_err(Error::ReadDeviceCacheError)?; if let Some(device) = state.into_device() { - let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false) - .await - .map_err(Error::RpcInitializationError)?; + let api_runtime = + mullvad_api::Runtime::with_cache(&ApiEndpoint::from_env_vars(), &cache_path, false) + .await + .map_err(Error::RpcInitializationError)?; let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await; let proxy = mullvad_api::DevicesProxy::new( diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs index 7fe14ae58ee1..29227cc82bed 100644 --- a/test/test-manager/src/tests/account.rs +++ b/test/test-manager/src/tests/account.rs @@ -278,9 +278,8 @@ pub async fn clear_devices(device_client: &DevicesProxy) -> anyhow::Result<()> { } pub async fn new_device_client() -> anyhow::Result { - use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint, API}; + use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint}; - let api_endpoint = ApiEndpoint::from_env_vars(); let api_host = format!("api.{}", TEST_CONFIG.mullvad_host); let api_host_with_port = format!("{api_host}:443"); @@ -289,14 +288,10 @@ pub async fn new_device_client() -> anyhow::Result { .context("failed to resolve API host")?; // Override the API endpoint to use the one specified in the test config - let _ = API.override_init(ApiEndpoint { - host: Some(api_host), - address: Some(api_address), - ..api_endpoint - }); - - let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) - .expect("failed to create api runtime"); + let endpoint = ApiEndpoint::new(api_host, api_address, false); + + let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), &endpoint); + let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()); Ok(DevicesProxy::new(rest_handle)) } From 732397b7abe1f29c5b53d22a5fd5876817b4855d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Em=C4=ABls?= Date: Wed, 18 Dec 2024 18:01:18 +0100 Subject: [PATCH 3/3] Change how daemon is started --- mullvad-daemon/src/lib.rs | 60 ++++++++++++++++++++------------------ mullvad-daemon/src/main.rs | 21 +++++++------ mullvad-jni/src/lib.rs | 46 ++++++++++------------------- 3 files changed, 57 insertions(+), 70 deletions(-) diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index e155ba792247..b313e274bc2c 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -588,30 +588,34 @@ pub struct Daemon { volume_update_tx: mpsc::UnboundedSender<()>, location_handler: GeoIpHandler, } +pub struct DaemonConfig { + pub log_dir: Option, + pub resource_dir: PathBuf, + pub settings_dir: PathBuf, + pub cache_dir: PathBuf, + pub rpc_socket_path: PathBuf, + pub endpoint: ApiEndpoint, + #[cfg(target_os = "android")] + pub android_context: AndroidContext, +} impl Daemon { pub async fn start( - log_dir: Option, - resource_dir: PathBuf, - settings_dir: PathBuf, - cache_dir: PathBuf, - rpc_socket_path: PathBuf, + config: DaemonConfig, daemon_command_channel: DaemonCommandChannel, - endpoint: ApiEndpoint, - #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result { #[cfg(target_os = "macos")] macos::bump_filehandle_limit(); let command_sender = daemon_command_channel.sender(); let management_interface = - ManagementInterfaceServer::start(command_sender, rpc_socket_path) + ManagementInterfaceServer::start(command_sender, config.rpc_socket_path) .map_err(Error::ManagementInterfaceError)?; let (internal_event_tx, internal_event_rx) = daemon_command_channel.destructure(); #[cfg(target_os = "android")] - let connectivity_listener = ConnectivityListener::new(android_context.clone()) + let connectivity_listener = ConnectivityListener::new(config.android_context.clone()) .inspect_err(|error| { log::error!( "{}", @@ -620,10 +624,10 @@ impl Daemon { }) .map_err(|_| Error::DaemonUnavailable)?; - mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; + mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&config.cache_dir).await; let api_runtime = mullvad_api::Runtime::with_cache( - &endpoint, - &cache_dir, + &config.endpoint, + &config.cache_dir, true, #[cfg(target_os = "android")] api::create_bypass_tx(&internal_event_tx), @@ -634,7 +638,7 @@ impl Daemon { let api_availability = api_runtime.availability_handle(); api_availability.suspend(); - let migration_data = migrations::migrate_all(&cache_dir, &settings_dir) + let migration_data = migrations::migrate_all(&config.cache_dir, &config.settings_dir) .await .unwrap_or_else(|error| { log::error!( @@ -645,7 +649,7 @@ impl Daemon { }); let settings_event_listener = management_interface.notifier().clone(); - let mut settings = SettingsPersister::load(&settings_dir).await; + let mut settings = SettingsPersister::load(&config.settings_dir).await; settings.register_change_listener(move |settings| { // Notify management interface server of changes to the settings settings_event_listener.notify_settings(settings.to_owned()); @@ -654,8 +658,8 @@ impl Daemon { let initial_selector_config = SelectorConfig::from_settings(&settings); let relay_selector = RelaySelector::new( initial_selector_config, - resource_dir.join(RELAYS_FILENAME), - cache_dir.join(RELAYS_FILENAME), + config.resource_dir.join(RELAYS_FILENAME), + config.cache_dir.join(RELAYS_FILENAME), ); let settings_relay_selector = relay_selector.clone(); @@ -667,11 +671,11 @@ impl Daemon { }); let (access_mode_handler, access_mode_provider) = api::AccessModeSelector::spawn( - cache_dir.clone(), + config.cache_dir.clone(), relay_selector.clone(), settings.api_access_methods.clone(), #[cfg(feature = "api-override")] - endpoint.clone(), + config.endpoint.clone(), internal_event_tx.to_specialized_sender(), api_runtime.address_cache().clone(), ) @@ -685,7 +689,7 @@ impl Daemon { api_runtime.address_cache().clone(), api_handle.clone(), #[cfg(feature = "api-override")] - endpoint, + config.endpoint.clone(), )); let access_method_handle = access_mode_handler.clone(); @@ -709,7 +713,7 @@ impl Daemon { let (account_manager, data) = device::AccountManager::spawn( api_handle.clone(), - &settings_dir, + &config.settings_dir, settings .tunnel_options .wireguard @@ -721,7 +725,7 @@ impl Daemon { .map_err(Error::LoadAccountManager)?; let account_history = account_history::AccountHistory::new( - &settings_dir, + &config.settings_dir, data.device().map(|device| device.account_number.clone()), ) .await @@ -729,9 +733,9 @@ impl Daemon { let target_state = if settings.auto_connect { log::info!("Automatically connecting since auto-connect is turned on"); - PersistentTargetState::new_secured(&cache_dir).await + PersistentTargetState::new_secured(&config.cache_dir).await } else { - PersistentTargetState::new(&cache_dir).await + PersistentTargetState::new(&config.cache_dir).await }; #[cfg(any(windows, target_os = "android", target_os = "macos"))] @@ -790,14 +794,14 @@ impl Daemon { exclude_paths, }, parameters_generator.clone(), - log_dir, - resource_dir.clone(), + config.log_dir, + config.resource_dir.clone(), internal_event_tx.to_specialized_sender(), offline_state_tx, #[cfg(target_os = "windows")] volume_update_rx, #[cfg(target_os = "android")] - android_context, + config.android_context, #[cfg(target_os = "android")] connectivity_listener.clone(), #[cfg(target_os = "linux")] @@ -819,14 +823,14 @@ impl Daemon { let mut relay_list_updater = RelayListUpdater::spawn( relay_selector.clone(), api_handle.clone(), - &cache_dir, + &config.cache_dir, on_relay_list_update, ); let version_updater_handle = version_check::VersionUpdater::spawn( api_handle.clone(), api_availability.clone(), - cache_dir.clone(), + config.cache_dir.clone(), internal_event_tx.to_specialized_sender(), settings.show_beta_releases, ) diff --git a/mullvad-daemon/src/main.rs b/mullvad-daemon/src/main.rs index 98910f76f864..c2146752c7d5 100644 --- a/mullvad-daemon/src/main.rs +++ b/mullvad-daemon/src/main.rs @@ -1,11 +1,10 @@ use std::{path::PathBuf, thread, time::Duration}; -use mullvad_api::ApiEndpoint; #[cfg(not(windows))] use mullvad_daemon::cleanup_old_rpc_socket; use mullvad_daemon::{ exception_logging, logging, rpc_uniqueness_check, runtime, version, Daemon, - DaemonCommandChannel, + DaemonCommandChannel, DaemonConfig, }; use talpid_types::ErrorExt; @@ -213,16 +212,16 @@ async fn create_daemon(log_dir: Option) -> Result { let cache_dir = mullvad_paths::cache_dir() .map_err(|e| e.display_chain_with_msg("Unable to get cache dir"))?; - let daemon_command_channel = DaemonCommandChannel::new(); - Daemon::start( - log_dir, - resource_dir, - settings_dir, - cache_dir, - rpc_socket_path, - daemon_command_channel, - ApiEndpoint::from_env_vars(), + DaemonConfig { + log_dir, + resource_dir, + settings_dir, + cache_dir, + rpc_socket_path, + endpoint: mullvad_api::ApiEndpoint::from_env_vars(), + }, + DaemonCommandChannel::new(), ) .await .map_err(|e| e.display_chain_with_msg("Unable to initialize daemon")) diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 2d2561403f00..fd35396fd00f 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -15,7 +15,7 @@ use jnix::{ use mullvad_api::ApiEndpoint; use mullvad_daemon::{ cleanup_old_rpc_socket, exception_logging, logging, runtime::new_multi_thread, version, Daemon, - DaemonCommandChannel, DaemonCommandSender, + DaemonCommandChannel, DaemonCommandSender, DaemonConfig, }; use std::{ io, @@ -139,13 +139,6 @@ fn start( start_logging(&files_dir).map_err(Error::InitializeLogging)?; version::log_version(); - #[cfg(feature = "api-override")] - if let Some(api_endpoint) = api_endpoint { - log::debug!("Overriding API endpoint: {api_endpoint:?}"); - if mullvad_api::API.override_init(api_endpoint).is_err() { - log::warn!("Ignoring API settings (already initialized)"); - } - } #[cfg(not(feature = "api-override"))] if api_endpoint.is_some() { log::warn!("api_endpoint will be ignored since 'api-override' is not enabled"); @@ -172,14 +165,18 @@ fn spawn_daemon( let runtime = new_multi_thread().build().map_err(Error::InitTokio)?; - let running_daemon = runtime.block_on(spawn_daemon_inner( - rpc_socket, - files_dir, + let daemon_config = DaemonConfig { + rpc_socket_path: rpc_socket, + log_dir: Some(files_dir.clone()), + resource_dir: files_dir.clone(), + settings_dir: files_dir, cache_dir, - daemon_command_channel, android_context, endpoint, - ))?; + }; + + let running_daemon = + runtime.block_on(spawn_daemon_inner(daemon_config, daemon_command_channel))?; Ok(DaemonContext { runtime, @@ -189,27 +186,14 @@ fn spawn_daemon( } async fn spawn_daemon_inner( - rpc_socket: PathBuf, - files_dir: PathBuf, - cache_dir: PathBuf, + daemon_config: DaemonConfig, daemon_command_channel: DaemonCommandChannel, - android_context: AndroidContext, - endpoint: ApiEndpoint, ) -> Result, Error> { - cleanup_old_rpc_socket(&rpc_socket).await; + cleanup_old_rpc_socket(&daemon_config.rpc_socket_path).await; - let daemon = Daemon::start( - Some(files_dir.clone()), - files_dir.clone(), - files_dir, - cache_dir, - rpc_socket, - daemon_command_channel, - endpoint, - android_context, - ) - .await - .map_err(Error::InitializeDaemon)?; + let daemon = Daemon::start(daemon_config, daemon_command_channel) + .await + .map_err(Error::InitializeDaemon)?; let running_daemon = tokio::spawn(async move { match daemon.run().await {