use crate::{ config::{ keys::OPTION_RELAY_SERVER, use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT, }, protobuf::Message, socket_client::split_host_port, sodiumoxide::crypto::secretbox::Key, tcp::Encrypt, tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType}, ResultType, }; use anyhow::bail; use async_recursion::async_recursion; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use std::{ io::{Error, ErrorKind}, net::SocketAddr, sync::Arc, time::Duration, }; use tokio::{net::TcpStream, time::timeout}; use tokio_native_tls::native_tls::TlsConnector; use tokio_tungstenite::{ connect_async_tls_with_config, tungstenite::protocol::Message as WsMessage, Connector, MaybeTlsStream, WebSocketStream, }; use tungstenite::client::IntoClientRequest; use tungstenite::protocol::Role; pub struct WsFramedStream { stream: WebSocketStream>, addr: SocketAddr, encrypt: Option, send_timeout: u64, } impl WsFramedStream { #[inline] fn get_connector( tls_type: &TlsType, danger_accept_invalid_certs: bool, ) -> ResultType> { match tls_type { TlsType::Plain => Ok(Some(Connector::Plain)), TlsType::NativeTls => { let connector = TlsConnector::builder() .danger_accept_invalid_certs(danger_accept_invalid_certs) .build()?; Ok(Some(Connector::NativeTls(connector))) } TlsType::Rustls => { let connector = match crate::verifier::client_config(danger_accept_invalid_certs) { Ok(client_config) => Some(Connector::Rustls(Arc::new(client_config))), Err(e) => { log::warn!( "Failed to get client config: {:?}, fallback to default connector", e ); None } }; Ok(connector) } } } async fn connect( url: &str, ms_timeout: u64, ) -> ResultType>> { // to-do: websocket proxy. let tls_type = get_cached_tls_type(url); let is_tls_type_cached = tls_type.is_some(); let tls_type = tls_type.unwrap_or(TlsType::Rustls); let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url); Self::try_connect( url, ms_timeout, tls_type, is_tls_type_cached, danger_accept_invalid_cert, danger_accept_invalid_cert, ) .await } #[async_recursion] async fn try_connect( url: &str, ms_timeout: u64, tls_type: TlsType, is_tls_type_cached: bool, danger_accept_invalid_cert: Option, original_danger_accept_invalid_certs: Option, ) -> ResultType>> { let ws_config = None; let disable_nagle = false; let request = url .into_client_request() .map_err(|e| Error::new(ErrorKind::Other, e))?; let connector = Self::get_connector(&tls_type, danger_accept_invalid_cert.unwrap_or(false))?; match timeout( Duration::from_millis(ms_timeout), connect_async_tls_with_config(request, ws_config, disable_nagle, connector), ) .await? { Ok((ws_stream, _)) => { upsert_tls_cache(url, tls_type, danger_accept_invalid_cert.unwrap_or(false)); Ok(ws_stream) } Err(e) => match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) { (TlsType::Rustls, _, None) => { log::warn!( "WebSocket connection with rustls-tls failed, try accept invalid certs: {}, {:?}", url, e ); Self::try_connect( url, ms_timeout, tls_type, is_tls_type_cached, Some(true), original_danger_accept_invalid_certs, ) .await } (TlsType::Rustls, false, Some(_)) => { log::warn!( "WebSocket connection with rustls-tls failed, try native-tls: {}, {:?}", url, e ); Self::try_connect( url, ms_timeout, TlsType::NativeTls, is_tls_type_cached, original_danger_accept_invalid_certs, original_danger_accept_invalid_certs, ) .await } (TlsType::NativeTls, _, None) => { log::warn!( "WebSocket connection with native-tls failed, try accept invalid certs: {}, {:?}", url, e ); Self::try_connect( url, ms_timeout, tls_type, is_tls_type_cached, Some(true), original_danger_accept_invalid_certs, ) .await } _ => { log::error!( "WebSocket connection failed with tls_type {:?}: {}, {:?}", tls_type, url, e ); bail!(e) } }, } } pub async fn new>( url: T, _local_addr: Option, _proxy_conf: Option<&Socks5Server>, ms_timeout: u64, ) -> ResultType { let stream = Self::connect(url.as_ref(), ms_timeout).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, MaybeTlsStream::NativeTls(tls) => tls.get_ref().get_ref().get_ref().peer_addr()?, MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; let ws = Self { stream, addr, encrypt: None, send_timeout: ms_timeout, }; Ok(ws) } #[inline] pub fn set_raw(&mut self) { self.encrypt = None; } #[inline] pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; Ok(Self { stream: ws_stream, addr, encrypt: None, send_timeout: 0, }) } #[inline] pub fn local_addr(&self) -> SocketAddr { self.addr } #[inline] pub fn set_send_timeout(&mut self, ms: u64) { self.send_timeout = ms; } #[inline] pub fn set_key(&mut self, key: Key) { self.encrypt = Some(Encrypt::new(key)); } #[inline] pub fn is_secured(&self) -> bool { self.encrypt.is_some() } #[inline] pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { self.send_raw(msg.write_to_bytes()?).await } #[inline] pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { let mut msg = msg; if let Some(key) = self.encrypt.as_mut() { msg = key.enc(&msg); } self.send_bytes(Bytes::from(msg)).await } pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { let msg = WsMessage::Binary(bytes); if self.send_timeout > 0 { timeout( Duration::from_millis(self.send_timeout), self.stream.send(msg), ) .await?? } else { self.stream.send(msg).await? }; Ok(()) } #[inline] pub async fn next(&mut self) -> Option> { while let Some(msg) = self.stream.next().await { let msg = match msg { Ok(msg) => msg, Err(e) => { log::error!("{}", e); return Some(Err(Error::new( ErrorKind::Other, format!("WebSocket protocol error: {}", e), ))); } }; match msg { WsMessage::Binary(data) => { let mut bytes = BytesMut::from(&data[..]); if let Some(key) = self.encrypt.as_mut() { if let Err(err) = key.dec(&mut bytes) { return Some(Err(err)); } } return Some(Ok(bytes)); } WsMessage::Text(text) => { let bytes = BytesMut::from(text.as_bytes()); return Some(Ok(bytes)); } WsMessage::Close(_) => { return None; } _ => { continue; } } } None } #[inline] pub async fn next_timeout(&mut self, ms: u64) -> Option> { match timeout(Duration::from_millis(ms), self.next()).await { Ok(res) => res, Err(_) => None, } } } pub fn is_ws_endpoint(endpoint: &str) -> bool { endpoint.starts_with("ws://") || endpoint.starts_with("wss://") } /** * Core function to convert an endpoint to WebSocket format * * Converts between different address formats: * 1. IPv4 address with/without port -> ws://ipv4:port * 2. IPv6 address with/without port -> ws://[ipv6]:port * 3. Domain with/without port -> ws(s)://domain/ws/path * * @param endpoint The endpoint to convert * @return The converted WebSocket endpoint */ pub fn check_ws(endpoint: &str) -> String { if !use_ws() { return endpoint.to_string(); } if endpoint.is_empty() { return endpoint.to_string(); } if is_ws_endpoint(endpoint) { return endpoint.to_string(); } let Some((endpoint_host, endpoint_port)) = split_host_port(endpoint) else { debug_assert!(false, "endpoint doesn't have port"); return endpoint.to_string(); }; let custom_rendezvous_server = Config::get_rendezvous_server(); let relay_server = Config::get_option(OPTION_RELAY_SERVER); let rendezvous_port = split_host_port(&custom_rendezvous_server) .map(|(_, p)| p) .unwrap_or(RENDEZVOUS_PORT); let relay_port = split_host_port(&relay_server) .map(|(_, p)| p) .unwrap_or(RELAY_PORT); let (relay, dst_port) = if endpoint_port == rendezvous_port { // rendezvous (false, endpoint_port + 2) } else if endpoint_port == rendezvous_port - 1 { // online (false, endpoint_port + 3) } else if endpoint_port == relay_port || endpoint_port == rendezvous_port + 1 { // relay // https://github.com/rustdesk/rustdesk/blob/6ffbcd1375771f2482ec4810680623a269be70f1/src/rendezvous_mediator.rs#L615 // https://github.com/rustdesk/rustdesk-server/blob/235a3c326ceb665e941edb50ab79faa1208f7507/src/relay_server.rs#L83, based on relay port. (true, endpoint_port + 2) } else { // fallback relay // for controlling side, relay server is passed from the controlled side, not related to local config. (true, endpoint_port + 2) }; let (address, is_domain) = if crate::is_ip_str(endpoint) { (format!("{}:{}", endpoint_host, dst_port), false) } else { let domain_path = if relay { "/ws/relay" } else { "/ws/id" }; (format!("{}{}", endpoint_host, domain_path), true) }; let protocol = if is_domain { let api_server = Config::get_option("api-server"); if api_server.starts_with("https") { "wss" } else { "ws" } } else { "ws" }; format!("{}://{}", protocol, address) } #[cfg(test)] mod tests { use super::*; use crate::config::{keys, Config}; #[test] fn test_check_ws() { // enable websocket Config::set_option(keys::OPTION_ALLOW_WEBSOCKET.to_string(), "Y".to_string()); // not set custom-rendezvous-server Config::set_option("custom-rendezvous-server".to_string(), "".to_string()); Config::set_option("relay-server".to_string(), "".to_string()); Config::set_option("api-server".to_string(), "".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); assert_eq!(check_ws("rustdesk.com:21115"), "ws://rustdesk.com/ws/id"); assert_eq!(check_ws("rustdesk.com:21116"), "ws://rustdesk.com/ws/id"); assert_eq!(check_ws("rustdesk.com:21117"), "ws://rustdesk.com/ws/relay"); // set relay-server without port Config::set_option("relay-server".to_string(), "127.0.0.1".to_string()); Config::set_option( "api-server".to_string(), "https://api.rustdesk.com".to_string(), ); assert_eq!( check_ws("[0:0:0:0:0:0:0:1]:21115"), "ws://[0:0:0:0:0:0:0:1]:21118" ); assert_eq!( check_ws("[0:0:0:0:0:0:0:1]:21116"), "ws://[0:0:0:0:0:0:0:1]:21118" ); assert_eq!( check_ws("[0:0:0:0:0:0:0:1]:21117"), "ws://[0:0:0:0:0:0:0:1]:21119" ); assert_eq!(check_ws("rustdesk.com:21115"), "wss://rustdesk.com/ws/id"); assert_eq!(check_ws("rustdesk.com:21116"), "wss://rustdesk.com/ws/id"); assert_eq!( check_ws("rustdesk.com:21117"), "wss://rustdesk.com/ws/relay" ); // set relay-server with default port Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with custom port Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string()); assert_eq!(check_ws("rustdesk.com:21115"), "wss://rustdesk.com/ws/id"); assert_eq!(check_ws("rustdesk.com:21116"), "wss://rustdesk.com/ws/id"); assert_eq!( check_ws("rustdesk.com:34567"), "wss://rustdesk.com/ws/relay" ); // set custom-rendezvous-server without port Config::set_option( "custom-rendezvous-server".to_string(), "127.0.0.1".to_string(), ); Config::set_option("relay-server".to_string(), "".to_string()); Config::set_option("api-server".to_string(), "".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server without port Config::set_option("relay-server".to_string(), "127.0.0.1".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with default port Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with custom port Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569"); // set custom-rendezvous-server without default port Config::set_option( "custom-rendezvous-server".to_string(), "127.0.0.1".to_string(), ); Config::set_option("relay-server".to_string(), "".to_string()); Config::set_option("api-server".to_string(), "".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server without port Config::set_option("relay-server".to_string(), "127.0.0.1".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with default port Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with custom port Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string()); assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118"); assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569"); // set custom-rendezvous-server with custom port Config::set_option( "custom-rendezvous-server".to_string(), "127.0.0.1:23456".to_string(), ); Config::set_option("relay-server".to_string(), "".to_string()); Config::set_option("api-server".to_string(), "".to_string()); assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:23457"), "ws://127.0.0.1:23459"); // set relay-server without port Config::set_option("relay-server".to_string(), "127.0.0.1".to_string()); assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with default port Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string()); assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119"); // set relay-server with custom port Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string()); assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458"); assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569"); } }