From 54c4d869ed2d4b00be7fbeaac12aa2982776e777 Mon Sep 17 00:00:00 2001 From: fufesou Date: Thu, 30 Oct 2025 15:34:43 +0800 Subject: [PATCH] refact: tls native-tls fallback rustls-tls Signed-off-by: fufesou --- Cargo.toml | 24 ++--- src/config.rs | 12 ++- src/lib.rs | 3 +- src/proxy.rs | 241 +++++++++++++++++++++++++++++++++++++---------- src/tls.rs | 121 ++++++++++++++++++++++++ src/verifier.rs | 73 +++++++++++++- src/websocket.rs | 187 +++++++++++++++++++++++++++--------- 7 files changed, 545 insertions(+), 116 deletions(-) create mode 100644 src/tls.rs diff --git a/Cargo.toml b/Cargo.toml index ea42287..030d13d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,30 +48,24 @@ url = "2.5" sha2 = "0.10" whoami = "1.5" -[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] -mac_address = "1.1" -default_net = { git = "https://github.com/rustdesk-org/default_net" } -machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" } - -[target.'cfg(not(any(target_os = "macos", target_os = "windows")))'.dependencies] tokio-rustls = { version = "0.26", features = [ "logging", "tls12", "ring", ], default-features = false } +tokio-native-tls = "0.3" +tokio-tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] } +tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] } +rustls-platform-verifier = "0.6" rustls-pki-types = "1.11" -tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots", "rustls-tls-webpki-roots"] } -tungstenite = { version = "0.26", features = ["rustls-tls-native-roots", "rustls-tls-webpki-roots"] } rustls-native-certs = "0.8" webpki-roots = "1.0" +async-recursion = "1.1" -[target.'cfg(any(target_os = "android", target_os = "ios"))'.dependencies] -rustls-platform-verifier = "0.6" - -[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies] -tokio-native-tls = "0.3" -tokio-tungstenite = { version = "0.26", features = ["native-tls"] } -tungstenite = { version = "0.26", features = ["native-tls"] } +[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] +mac_address = "1.1" +default_net = { git = "https://github.com/rustdesk-org/default_net" } +machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" } [build-dependencies] protobuf-codegen = { version = "3.7" } diff --git a/src/config.rs b/src/config.rs index dd9ed3f..c8673d8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2417,6 +2417,11 @@ pub fn use_ws() -> bool { option2bool(option, &Config::get_option(option)) } +pub fn allow_insecure_tls_fallback() -> bool { + let option = keys::OPTION_ALLOW_INSECURE_TLS_FALLBACK; + option2bool(option, &Config::get_option(option)) +} + pub mod keys { pub const OPTION_VIEW_ONLY: &str = "view_only"; pub const OPTION_SHOW_MONITORS_TOOLBAR: &str = "show_monitors_toolbar"; @@ -2513,14 +2518,16 @@ pub mod keys { pub const OPTION_TRACKPAD_SPEED: &str = "trackpad-speed"; pub const OPTION_REGISTER_DEVICE: &str = "register-device"; pub const OPTION_RELAY_SERVER: &str = "relay-server"; + pub const OPTION_DISABLE_UDP: &str = "disable-udp"; + pub const OPTION_ALLOW_INSECURE_TLS_FALLBACK: &str = "allow-insecure-tls-fallback"; pub const OPTION_SHOW_VIRTUAL_MOUSE: &str = "show-virtual-mouse"; // joystick is the virtual mouse. // So `OPTION_SHOW_VIRTUAL_MOUSE` should also be set if `OPTION_SHOW_VIRTUAL_JOYSTICK` is set. pub const OPTION_SHOW_VIRTUAL_JOYSTICK: &str = "show-virtual-joystick"; + pub const OPTION_ENABLE_FLUTTER_HTTP_ON_RUST: &str = "enable-flutter-http-on-rust"; // built-in options pub const OPTION_DISPLAY_NAME: &str = "display-name"; - pub const OPTION_DISABLE_UDP: &str = "disable-udp"; pub const OPTION_PRESET_DEVICE_GROUP_NAME: &str = "preset-device-group-name"; pub const OPTION_PRESET_USERNAME: &str = "preset-user-name"; pub const OPTION_PRESET_STRATEGY_NAME: &str = "preset-strategy-name"; @@ -2651,6 +2658,7 @@ pub mod keys { OPTION_TOUCH_MODE, OPTION_SHOW_VIRTUAL_MOUSE, OPTION_SHOW_VIRTUAL_JOYSTICK, + OPTION_ENABLE_FLUTTER_HTTP_ON_RUST, ]; // DEFAULT_SETTINGS, OVERWRITE_SETTINGS pub const KEYS_SETTINGS: &[&str] = &[ @@ -2703,12 +2711,12 @@ pub mod keys { OPTION_ENABLE_ANDROID_SOFTWARE_ENCODING_HALF_SCALE, OPTION_ENABLE_TRUSTED_DEVICES, OPTION_RELAY_SERVER, + OPTION_DISABLE_UDP, ]; // BUILDIN_SETTINGS pub const KEYS_BUILDIN_SETTINGS: &[&str] = &[ OPTION_DISPLAY_NAME, - OPTION_DISABLE_UDP, OPTION_PRESET_DEVICE_GROUP_NAME, OPTION_PRESET_USERNAME, OPTION_PRESET_STRATEGY_NAME, diff --git a/src/lib.rs b/src/lib.rs index 1f2d53d..5bd9da1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,8 +63,9 @@ pub mod websocket; pub use rustls_platform_verifier; pub use stream::Stream; pub use whoami; -#[cfg(not(any(target_os = "macos", target_os = "windows")))] +pub mod tls; pub mod verifier; +pub use async_recursion; pub type SessionID = uuid::Uuid; diff --git a/src/proxy.rs b/src/proxy.rs index e3ca2bc..a69c110 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -3,16 +3,15 @@ use std::{ net::{SocketAddr, ToSocketAddrs}, }; +use anyhow::bail; +use async_recursion::async_recursion; use base64::{engine::general_purpose, Engine}; use httparse::{Error as HttpParseError, Response, EMPTY_HEADER}; -use log::info; use thiserror::Error as ThisError; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}; -#[cfg(any(target_os = "windows", target_os = "macos"))] use tokio_native_tls::{native_tls, TlsConnector, TlsStream}; -#[cfg(not(any(target_os = "windows", target_os = "macos")))] -use tokio_rustls::{client::TlsStream, TlsConnector}; -use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr}; +use tokio_rustls::{client::TlsStream as RustlsTlsStream, TlsConnector as RustlsTlsConnector}; +use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, TargetAddr}; use tokio_util::codec::Framed; use url::Url; @@ -20,6 +19,7 @@ use crate::{ bytes_codec::BytesCodec, config::Socks5Server, tcp::{DynTcpStream, FramedStream}, + tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType}, ResultType, }; @@ -45,7 +45,6 @@ pub enum ProxyError { HttpCode200(u16), #[error("The proxy address resolution failed: {0}")] AddressResolutionFailed(String), - #[cfg(any(target_os = "windows", target_os = "macos"))] #[error("The native tls error: {0}")] NativeTlsError(#[from] tokio_native_tls::native_tls::Error), } @@ -226,7 +225,7 @@ impl ProxyScheme { Ok(scheme) } pub async fn socket_addrs(&self) -> Result { - info!("Resolving socket address"); + log::trace!("Resolving socket address"); match self { ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await, ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await, @@ -356,37 +355,50 @@ impl Proxy { self } + async fn new_stream( + &self, + local: SocketAddr, + proxy: SocketAddr, + ) -> ResultType { + let stream = super::timeout( + self.ms_timeout, + crate::tcp::new_socket(local, true)?.connect(proxy), + ) + .await??; + stream.set_nodelay(true).ok(); + Ok(stream) + } + pub async fn connect<'t, T>( - self, + &self, target: T, local_addr: Option, ) -> ResultType where T: IntoTargetAddr<'t>, { - info!("Connect to proxy server"); + log::trace!("Connect to proxy server"); let proxy = self.proxy_addrs().await?; + let target_addr = target + .into_target_addr() + .map_err(|e| ProxyError::TargetParseError(e.to_string()))?; + let local = if let Some(addr) = local_addr { addr } else { crate::config::Config::get_any_listen_addr(proxy.is_ipv4()) }; - let stream = super::timeout( - self.ms_timeout, - crate::tcp::new_socket(local, true)?.connect(proxy), - ) - .await??; - stream.set_nodelay(true).ok(); - + let stream = self.new_stream(local, proxy).await?; let addr = stream.local_addr()?; return match self.intercept { ProxyScheme::Http { .. } => { - info!("Connect to remote http proxy server: {}", proxy); + log::trace!("Connect to remote http proxy server: {}", proxy); let stream = - super::timeout(self.ms_timeout, self.http_connect(stream, target)).await??; + super::timeout(self.ms_timeout, self.http_connect(stream, &target_addr)) + .await??; Ok(FramedStream( Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), addr, @@ -395,24 +407,54 @@ impl Proxy { )) } ProxyScheme::Https { .. } => { - info!("Connect to remote https proxy server: {}", proxy); - let stream = - super::timeout(self.ms_timeout, self.https_connect(stream, target)).await??; + log::trace!("Connect to remote https proxy server: {}", proxy); + let url = format!("https://{}", self.intercept.get_host_and_port()?); + let tls_type = get_cached_tls_type(&url); + let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url); + let stream = match tls_type.unwrap_or(TlsType::NativeTls) { + TlsType::NativeTls => { + self.https_connect_nativetls_wrap_danger( + &url, + local, + proxy, + Some(stream), + &target_addr, + tls_type.is_some(), + danger_accept_invalid_cert, + danger_accept_invalid_cert, + ) + .await? + } + TlsType::Rustls => { + self.https_connect_rustls_wrap_danger( + &url, + local, + proxy, + &target_addr, + danger_accept_invalid_cert, + ) + .await? + } + _ => { + // Unreachable + crate::bail!("Unreachable, TlsType::Plain in HTTPS proxy"); + } + }; Ok(FramedStream( - Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + Framed::new(stream, BytesCodec::new()), addr, None, 0, )) } ProxyScheme::Socks5 { .. } => { - info!("Connect to remote socket5 proxy server: {}", proxy); + log::trace!("Connect to remote socket5 proxy server: {}", proxy); let stream = if let Some(auth) = self.intercept.maybe_auth() { super::timeout( self.ms_timeout, Socks5Stream::connect_with_password_and_socket( stream, - target, + target_addr, &auth.user_name, &auth.password, ), @@ -421,7 +463,7 @@ impl Proxy { } else { super::timeout( self.ms_timeout, - Socks5Stream::connect_with_socket(stream, target), + Socks5Stream::connect_with_socket(stream, target_addr), ) .await?? }; @@ -435,32 +477,133 @@ impl Proxy { }; } - #[cfg(any(target_os = "windows", target_os = "macos"))] - pub async fn https_connect<'a, Input, T>( - self, + #[async_recursion] + async fn https_connect_nativetls_wrap_danger<'a>( + &self, + url: &str, + local: SocketAddr, + proxy: SocketAddr, + stream: Option, + target_addr: &TargetAddr<'a>, + is_tls_type_cached: bool, + danger_accept_invalid_cert: Option, + origin_danger_accept_invalid_cert: Option, + ) -> ResultType { + let stream = stream.unwrap_or(self.new_stream(local, proxy).await?); + match super::timeout( + self.ms_timeout, + self.https_connect_nativetls( + stream, + target_addr, + danger_accept_invalid_cert.unwrap_or(false), + ), + ) + .await? + { + Ok(s) => { + upsert_tls_cache( + &url, + TlsType::NativeTls, + danger_accept_invalid_cert.unwrap_or(false), + ); + Ok(DynTcpStream(Box::new(s))) + } + Err(ProxyError::NativeTlsError(e)) => { + let s = if danger_accept_invalid_cert.is_none() { + log::warn!( + "Falling back to native-tls (accept invalid cert) for HTTPS proxy server." + ); + self.https_connect_nativetls_wrap_danger( + &url, + local, + proxy, + None, + target_addr, + is_tls_type_cached, + Some(true), + origin_danger_accept_invalid_cert, + ) + .await? + } else if !is_tls_type_cached { + log::warn!("Falling back to rustls for HTTPS proxy server."); + self.https_connect_rustls_wrap_danger( + &url, + local, + proxy, + &target_addr, + origin_danger_accept_invalid_cert, + ) + .await? + } else { + log::error!( + "Failed to connect to HTTPS proxy server with native-tls: {:?}.", + e + ); + bail!(e) + }; + Ok(s) + } + Err(e) => { + log::error!("Failed to connect to HTTPS proxy server: {:?}.", e); + bail!(e) + } + } + } + + pub async fn https_connect_nativetls<'a, Input>( + &self, io: Input, - target: T, + target_addr: &TargetAddr<'a>, + danger_accept_invalid_cert: bool, ) -> Result>, ProxyError> where Input: AsyncRead + AsyncWrite + Unpin, - T: IntoTargetAddr<'a>, { - let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?); + let mut tls_connector_builder = native_tls::TlsConnector::builder(); + if danger_accept_invalid_cert { + tls_connector_builder.danger_accept_invalid_certs(true); + } + let tls_connector = TlsConnector::from(tls_connector_builder.build()?); let stream = tls_connector .connect(&self.intercept.get_domain()?, io) .await?; - self.http_connect(stream, target).await + self.http_connect(stream, target_addr).await } - #[cfg(not(any(target_os = "windows", target_os = "macos")))] - pub async fn https_connect<'a, Input, T>( - self, + async fn https_connect_rustls_wrap_danger<'a>( + &self, + url: &str, + local: SocketAddr, + proxy: SocketAddr, + target_addr: &TargetAddr<'a>, + danger_accept_invalid_cert: Option, + ) -> ResultType { + let stream = self.new_stream(local, proxy).await?; + let s = super::timeout( + self.ms_timeout, + self.https_connect_rustls( + stream, + &target_addr, + danger_accept_invalid_cert.unwrap_or(false), + ), + ) + .await??; + upsert_tls_cache( + url, + TlsType::Rustls, + danger_accept_invalid_cert.unwrap_or(false), + ); + Ok(DynTcpStream(Box::new(s))) + } + + pub async fn https_connect_rustls<'a, Input>( + &self, io: Input, - target: T, - ) -> Result>, ProxyError> + target_addr: &TargetAddr<'a>, + danger_accept_invalid_cert: bool, + ) -> Result>, ProxyError> where Input: AsyncRead + AsyncWrite + Unpin, - T: IntoTargetAddr<'a>, { use std::convert::TryFrom; @@ -468,24 +611,23 @@ impl Proxy { let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str()) .map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))? .to_owned(); - let client_config = crate::verifier::client_config() + let client_config = crate::verifier::client_config(danger_accept_invalid_cert) .map_err(|e| ProxyError::IoError(std::io::Error::other(e)))?; - let tls_connector = TlsConnector::from(std::sync::Arc::new(client_config)); + let tls_connector = RustlsTlsConnector::from(std::sync::Arc::new(client_config)); let stream = tls_connector.connect(domain, io).await?; - self.http_connect(stream, target).await + self.http_connect(stream, target_addr).await } - pub async fn http_connect<'a, Input, T>( - self, + pub async fn http_connect<'a, Input>( + &self, io: Input, - target: T, + target_addr: &TargetAddr<'a>, ) -> Result, ProxyError> where Input: AsyncRead + AsyncWrite + Unpin, - T: IntoTargetAddr<'a>, { let mut stream = BufStream::new(io); - let (domain, port) = get_domain_and_port(target)?; + let (domain, port) = get_domain_and_port(target_addr)?; let request = self.make_request(&domain, port); stream.write_all(request.as_bytes()).await?; @@ -510,13 +652,10 @@ impl Proxy { } } -fn get_domain_and_port<'a, T: IntoTargetAddr<'a>>(target: T) -> Result<(String, u16), ProxyError> { - let target_addr = target - .into_target_addr() - .map_err(|e| ProxyError::TargetParseError(e.to_string()))?; +fn get_domain_and_port<'a>(target_addr: &TargetAddr<'a>) -> Result<(String, u16), ProxyError> { match target_addr { tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())), - tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), port)), + tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), *port)), } } diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..b086236 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,121 @@ +use std::{collections::HashMap, sync::RwLock}; + +use crate::config::allow_insecure_tls_fallback; + +#[derive(Debug, Clone, Copy)] +pub enum TlsType { + Plain, + NativeTls, + Rustls, +} + +lazy_static::lazy_static! { + static ref URL_TLS_TYPE: RwLock> = RwLock::new(HashMap::new()); + static ref URL_TLS_DANGER_ACCEPT_INVALID_CERTS: RwLock> = RwLock::new(HashMap::new()); +} + +#[inline] +pub fn is_plain(url: &str) -> bool { + url.starts_with("ws://") || url.starts_with("http://") +} + +// Extract domain from URL. +// e.g., "https://example.com/path" -> "example.com" +// "https://example.com:8080/path" -> "example.com:8080" +// See the tests for more examples. +#[inline] +fn get_domain_and_port_from_url(url: &str) -> &str { + // Remove scheme (e.g., http://, https://, ws://, wss://) + let scheme_end = url.find("://").map(|pos| pos + 3).unwrap_or(0); + let url2 = &url[scheme_end..]; + // If userinfo is present, domain is after last '@' + let after_at = match url2.rfind('@') { + Some(pos) => &url2[pos + 1..], + None => url2, + }; + // Find the end of domain (before '/' or '?') + let domain_end = after_at.find(&['/', '?'][..]).unwrap_or(after_at.len()); + &after_at[..domain_end] +} + +#[inline] +pub fn upsert_tls_cache(url: &str, tls_type: TlsType, danger_accept_invalid_cert: bool) { + if is_plain(url) { + return; + } + + let domain_port = get_domain_and_port_from_url(url); + // Use curly braces to ensure the lock is released immediately. + { + URL_TLS_TYPE + .write() + .unwrap() + .insert(domain_port.to_string(), tls_type); + } + { + URL_TLS_DANGER_ACCEPT_INVALID_CERTS + .write() + .unwrap() + .insert(domain_port.to_string(), danger_accept_invalid_cert); + } +} + +#[inline] +pub fn reset_tls_cache() { + // Use curly braces to ensure the lock is released immediately. + { + URL_TLS_TYPE.write().unwrap().clear(); + } + { + URL_TLS_DANGER_ACCEPT_INVALID_CERTS.write().unwrap().clear(); + } +} + +#[inline] +pub fn get_cached_tls_type(url: &str) -> Option { + if is_plain(url) { + return Some(TlsType::Plain); + } + let domain_port = get_domain_and_port_from_url(url); + URL_TLS_TYPE.read().unwrap().get(domain_port).cloned() +} + +#[inline] +pub fn get_cached_tls_accept_invalid_cert(url: &str) -> Option { + if !allow_insecure_tls_fallback() { + return Some(false); + } + + if is_plain(url) { + return Some(false); + } + let domain_port = get_domain_and_port_from_url(url); + URL_TLS_DANGER_ACCEPT_INVALID_CERTS + .read() + .unwrap() + .get(domain_port) + .cloned() +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_domain_and_port_from_url() { + for (url, expected_domain_port) in vec![ + ("http://example.com", "example.com"), + ("https://example.com", "example.com"), + ("ws://example.com/path", "example.com"), + ("wss://example.com:8080/path", "example.com:8080"), + ("https://user:pass@example.com", "example.com"), + ("https://example.com?query=param", "example.com"), + ("https://example.com:8443?query=param", "example.com:8443"), + ("ftp://example.com/resource", "example.com"), // ftp scheme + ("example.com/path", "example.com"), // no scheme + ("example.com:8080/path", "example.com:8080"), + ] { + let domain_port = get_domain_and_port_from_url(url); + assert_eq!(domain_port, expected_domain_port); + } + } +} diff --git a/src/verifier.rs b/src/verifier.rs index 55b7fd9..9f62854 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -1,14 +1,65 @@ use crate::ResultType; -#[cfg(any(target_os = "android", target_os = "ios"))] use rustls_pki_types::{ServerName, UnixTime}; use std::sync::Arc; use tokio_rustls::rustls::{self, client::WebPkiServerVerifier, ClientConfig}; -#[cfg(any(target_os = "android", target_os = "ios"))] use tokio_rustls::rustls::{ client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, DigitallySignedStruct, Error as TLSError, SignatureScheme, }; +// https://github.com/seanmonstar/reqwest/blob/fd61bc93e6f936454ce0b978c6f282f06eee9287/src/tls.rs#L608 +#[derive(Debug)] +pub(crate) struct NoVerifier; + +impl ServerCertVerifier for NoVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls_pki_types::CertificateDer, + _intermediates: &[rustls_pki_types::CertificateDer], + _server_name: &ServerName, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::RSA_PKCS1_SHA1, + SignatureScheme::ECDSA_SHA1_Legacy, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, + SignatureScheme::ECDSA_NISTP521_SHA512, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::ED25519, + SignatureScheme::ED448, + ] + } +} + /// A certificate verifier that tries a primary verifier first, /// and falls back to a platform verifier if the primary fails. #[cfg(any(target_os = "android", target_os = "ios"))] @@ -149,7 +200,15 @@ fn webpki_server_verifier( Ok(verifier) } -pub fn client_config() -> ResultType { +pub fn client_config(danger_accept_invalid_cert: bool) -> ResultType { + if danger_accept_invalid_cert { + client_config_danger() + } else { + client_config_safe() + } +} + +pub fn client_config_safe() -> ResultType { // Use the default builder which uses the default protocol versions and crypto provider. // The with_protocol_versions API has been removed in rustls master branch: // https://github.com/rustls/rustls/pull/2599 @@ -188,3 +247,11 @@ pub fn client_config() -> ResultType { Ok(config) } } + +pub fn client_config_danger() -> ResultType { + let config = ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoVerifier)) + .with_no_client_auth(); + Ok(config) +} diff --git a/src/websocket.rs b/src/websocket.rs index 011a584..dc1e981 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,22 +1,29 @@ use crate::{ - config::keys::OPTION_RELAY_SERVER, - config::{use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT}, + config::{ + keys::OPTION_RELAY_SERVER, use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT, + }, protobuf::Message, socket_client::split_host_port, sodiumoxide::crypto::secretbox::Key, tcp::Encrypt, + tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType}, ResultType, }; +use anyhow::bail; +use async_recursion::async_recursion; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use std::{ io::{Error, ErrorKind}, net::SocketAddr, + sync::Arc, time::Duration, }; use tokio::{net::TcpStream, time::timeout}; +use tokio_native_tls::native_tls::TlsConnector; use tokio_tungstenite::{ - connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, + connect_async_tls_with_config, tungstenite::protocol::Message as WsMessage, Connector, + MaybeTlsStream, WebSocketStream, }; use tungstenite::client::IntoClientRequest; use tungstenite::protocol::Role; @@ -29,29 +36,21 @@ pub struct WsFramedStream { } impl WsFramedStream { - pub async fn new>( - url: T, - _local_addr: Option, - _proxy_conf: Option<&Socks5Server>, - ms_timeout: u64, - ) -> ResultType { - let url_str = url.as_ref(); - - // to-do: websocket proxy. - - let request = url_str - .into_client_request() - .map_err(|e| Error::new(ErrorKind::Other, e))?; - - let stream; - #[cfg(any(target_os = "android", target_os = "ios"))] - { - let is_wss = url_str.starts_with("wss://"); - if is_wss { - use std::sync::Arc; - use tokio_tungstenite::{connect_async_tls_with_config, Connector}; - - let connector = match crate::verifier::client_config() { + #[inline] + fn get_connector( + tls_type: &TlsType, + danger_accept_invalid_certs: bool, + ) -> ResultType> { + match tls_type { + TlsType::Plain => Ok(Some(Connector::Plain)), + TlsType::NativeTls => { + let connector = TlsConnector::builder() + .danger_accept_invalid_certs(danger_accept_invalid_certs) + .build()?; + Ok(Some(Connector::NativeTls(connector))) + } + TlsType::Rustls => { + let connector = match crate::verifier::client_config(danger_accept_invalid_certs) { Ok(client_config) => Some(Connector::Rustls(Arc::new(client_config))), Err(e) => { log::warn!( @@ -61,30 +60,130 @@ impl WsFramedStream { None } }; - let (s, _) = timeout( - Duration::from_millis(ms_timeout), - connect_async_tls_with_config(request, None, false, connector), - ) - .await??; - stream = s; - } else { - let (s, _) = - timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??; - stream = s; + Ok(connector) } } - #[cfg(not(any(target_os = "android", target_os = "ios")))] - { - let (s, _) = - timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??; - stream = s; - } + } + async fn connect( + url: &str, + ms_timeout: u64, + ) -> ResultType>> { + // to-do: websocket proxy. + + let tls_type = get_cached_tls_type(url); + let is_tls_type_cached = tls_type.is_some(); + let tls_type = tls_type.unwrap_or(TlsType::NativeTls); + let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url); + Self::try_connect( + url, + ms_timeout, + tls_type, + is_tls_type_cached, + danger_accept_invalid_cert, + danger_accept_invalid_cert, + ) + .await + } + + #[async_recursion] + async fn try_connect( + url: &str, + ms_timeout: u64, + tls_type: TlsType, + is_tls_type_cached: bool, + danger_accept_invalid_cert: Option, + original_danger_accept_invalid_certs: Option, + ) -> ResultType>> { + let ws_config = None; + let disable_nagle = false; + let request = url + .into_client_request() + .map_err(|e| Error::new(ErrorKind::Other, e))?; + let connector = + Self::get_connector(&tls_type, danger_accept_invalid_cert.unwrap_or(false))?; + match timeout( + Duration::from_millis(ms_timeout), + connect_async_tls_with_config(request, ws_config, disable_nagle, connector), + ) + .await? + { + Ok((ws_stream, _)) => { + upsert_tls_cache(url, tls_type, danger_accept_invalid_cert.unwrap_or(false)); + Ok(ws_stream) + } + Err(e) => match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) { + (TlsType::NativeTls, _, None) => { + log::warn!( + "WebSocket connection with native-tls failed, try accept invalid certs: {}, {:?}", + url, + e + ); + Self::try_connect( + url, + ms_timeout, + tls_type, + is_tls_type_cached, + Some(true), + original_danger_accept_invalid_certs, + ) + .await + } + (TlsType::NativeTls, false, Some(_)) => { + log::warn!( + "WebSocket connection with native-tls failed, try rustls: {}, {:?}", + url, + e + ); + Self::try_connect( + url, + ms_timeout, + TlsType::Rustls, + is_tls_type_cached, + original_danger_accept_invalid_certs, + original_danger_accept_invalid_certs, + ) + .await + } + (TlsType::Rustls, _, None) => { + log::warn!( + "WebSocket connection with rustls failed, try accept invalid certs: {}, {:?}", + url, + e + ); + Self::try_connect( + url, + ms_timeout, + tls_type, + is_tls_type_cached, + Some(true), + original_danger_accept_invalid_certs, + ) + .await + } + _ => { + log::error!( + "WebSocket connection failed with tls_type {:?}: {}, {:?}", + tls_type, + url, + e + ); + bail!(e) + } + }, + } + } + + pub async fn new>( + url: T, + _local_addr: Option, + _proxy_conf: Option<&Socks5Server>, + ms_timeout: u64, + ) -> ResultType { + let stream = Self::connect(url.as_ref(), ms_timeout).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - #[cfg(any(target_os = "macos", target_os = "windows"))] MaybeTlsStream::NativeTls(tls) => tls.get_ref().get_ref().get_ref().peer_addr()?, - #[cfg(not(any(target_os = "macos", target_os = "windows")))] MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), };