Compare commits

..

1 Commits

Author SHA1 Message Date
mike 0c49f9a29c proto: backport HttpProxyRequest/Response (tags 27, 28)
Backports the HeaderEntry / HttpProxyRequest / HttpProxyResponse messages
and the corresponding union tags from upstream rustdesk/hbb_common
@87b11a7 onto the OSS server's pinned commit @83419b6 — wire-compatible
with the client-side encoder in rustdesk's src/common.rs::tcp_proxy_request.

Cherry-pick over a full submodule bump because the upstream commit pulls
newer transitive deps (notably tokio) that risk breaking axum 0.5 in the
hbbs HTTP layer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-01 19:27:57 +02:00
18 changed files with 260 additions and 4730 deletions
+17 -50
View File
@@ -6,20 +6,15 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = []
webrtc = ["dep:webrtc"]
[dependencies]
# new flexi_logger failed on rustc 1.75
flexi_logger = { version = "0.27", features = ["async"] }
protobuf = { version = "3.7", features = ["with-bytes"] }
tokio = { version = "1.44", features = ["full"] }
protobuf = { version = "3.4", features = ["with-bytes"] }
tokio = { version = "1.38", features = ["full"] }
tokio-util = { version = "0.7", features = ["full"] }
futures = "0.3"
bytes = { version = "1.10", features = ["serde"] }
bytes = { version = "1.6", features = ["serde"] }
log = "0.4"
env_logger = "0.11"
env_logger = "0.10"
socket2 = { version = "0.3", features = ["reuseport"] }
zstd = "0.13"
anyhow = "1.0"
@@ -29,72 +24,44 @@ rand = "0.8"
serde_derive = "1.0"
serde = "1.0"
serde_json = "1.0"
lazy_static = "1.5"
lazy_static = "1.4"
confy = { git = "https://github.com/rustdesk-org/confy" }
dirs-next = "2.0"
filetime = "0.2"
sodiumoxide = "0.2"
regex = "1.11"
regex = "1.8"
tokio-socks = { git = "https://github.com/rustdesk-org/tokio-socks" }
chrono = "0.4"
backtrace = "0.3"
libc = "0.2"
dlopen = "0.1"
toml = "0.7"
uuid = { version = "1.16", features = ["v4"] }
uuid = { version = "1.3", features = ["v4"] }
# new sysinfo issue: https://github.com/rustdesk/rustdesk/pull/6330#issuecomment-2270871442
sysinfo = { git = "https://github.com/rustdesk-org/sysinfo", branch = "rlim_max" }
# new flexi_logger failed on nightly rustc 1.75 for x86
thiserror = "1.0"
httparse = "1.10"
httparse = "1.5"
base64 = "0.22"
url = "2.5"
url = "2.2"
sha2 = "0.10"
whoami = "1.5"
tokio-rustls = { version = "0.26", features = [
"logging",
"tls12",
"ring",
], default-features = false }
tokio-native-tls = "0.3"
tokio-tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
rustls-platform-verifier = "0.6"
rustls-pki-types = "1.11"
rustls-native-certs = "0.8"
webpki-roots = "1.0.4"
async-recursion = "1.1"
webrtc = { version = "0.14.0", optional = true }
libloading = "0.8"
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
mac_address = "1.1"
default_net = { git = "https://github.com/rustdesk-org/default_net" }
machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" }
[target.'cfg(not(any(target_os = "macos", target_os = "windows")))'.dependencies]
tokio-rustls = { version = "0.26", features = ["logging", "tls12", "ring"], default-features = false }
rustls-platform-verifier = "0.3.1"
rustls-pki-types = "1.4"
[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
tokio-native-tls ="0.3"
[build-dependencies]
protobuf-codegen = { version = "3.7" }
[dev-dependencies]
clap = "4.5.51"
webrtc = "0.14.0"
protobuf-codegen = { version = "3.4" }
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = [
"winuser",
"synchapi",
"pdh",
"memoryapi",
"sysinfoapi",
] }
winapi = { version = "0.3", features = ["winuser", "synchapi", "pdh", "memoryapi", "sysinfoapi"] }
[target.'cfg(target_os = "macos")'.dependencies]
osascript = "0.3"
[target.'cfg(target_os = "linux")'.dependencies]
sctk = { package = "smithay-client-toolkit", version = "0.20.0", default-features = false, features = [
"calloop",
] }
users = { version = "0.11" }
x11 = "2.21"
-154
View File
@@ -1,154 +0,0 @@
extern crate hbb_common;
#[cfg(feature = "webrtc")]
use hbb_common::webrtc::WebRTCStream;
use std::io::Write;
use anyhow::Result;
use bytes::Bytes;
use clap::{Arg, Command};
use tokio::time::Duration;
#[cfg(not(feature = "webrtc"))]
#[tokio::main]
async fn main() -> Result<()> {
println!(
"The webrtc feature is not enabled. \
Please enable the webrtc feature to run this example."
);
Ok(())
}
#[cfg(feature = "webrtc")]
#[tokio::main]
async fn main() -> Result<()> {
let app = Command::new("webrtc-stream")
.about("An example of webrtc stream using hbb_common and webrtc-rs")
.arg(
Arg::new("debug")
.long("debug")
.short('d')
.action(clap::ArgAction::SetTrue)
.help("Prints debug log information"),
)
.arg(
Arg::new("offer")
.long("offer")
.short('o')
.help("set offer from other endpoint"),
);
let matches = app.clone().get_matches();
let debug = matches.contains_id("debug");
if debug {
println!("Debug log enabled");
env_logger::Builder::new()
.format(|buf, record| {
writeln!(
buf,
"{}:{} [{}] {} - {}",
record.file().unwrap_or("unknown"),
record.line().unwrap_or(0),
record.level(),
chrono::Local::now().format("%H:%M:%S.%6f"),
record.args()
)
})
.filter(Some("hbb_common"), log::LevelFilter::Debug)
.init();
}
let remote_endpoint = if let Some(endpoint) = matches.get_one::<String>("offer") {
endpoint.to_string()
} else {
"".to_string()
};
let webrtc_stream = WebRTCStream::new(&remote_endpoint, false, 30000).await?;
// Print the offer to be sent to the other peer
let local_endpoint = webrtc_stream.get_local_endpoint().await?;
if remote_endpoint.is_empty() {
println!();
// Wait for the answer to be pasted
println!(
"Start new terminal run: \n{} \ncopy remote endpoint and paste here",
format!(
"cargo r --features webrtc --example webrtc -- --offer {}",
local_endpoint
)
);
// readline blocking
let line = std::io::stdin()
.lines()
.next()
.ok_or_else(|| anyhow::anyhow!("No input received"))??;
webrtc_stream.set_remote_endpoint(&line).await?;
} else {
println!(
"Copy local endpoint and paste to the other peer: \n{}",
local_endpoint
);
}
let s1 = webrtc_stream.clone();
tokio::spawn(async move {
let _ = read_loop(s1).await;
});
let s2 = webrtc_stream.clone();
tokio::spawn(async move {
let _ = write_loop(s2).await;
});
println!("Press ctrl-c to stop");
tokio::select! {
_ = tokio::signal::ctrl_c() => {
println!();
}
};
Ok(())
}
// read_loop shows how to read from the datachannel directly
#[cfg(feature = "webrtc")]
async fn read_loop(mut stream: WebRTCStream) -> Result<()> {
loop {
let Some(res) = stream.next().await else {
println!("WebRTC stream closed; Exit the read_loop");
return Ok(());
};
match res {
Err(e) => {
println!("WebRTC stream read error: {}; Exit the read_loop", e);
return Ok(());
}
Ok(data) => {
println!("Message from stream: {}", String::from_utf8(data.to_vec())?);
}
}
}
}
// write_loop shows how to write to the webrtc stream directly
#[cfg(feature = "webrtc")]
async fn write_loop(mut stream: WebRTCStream) -> Result<()> {
let mut result = Result::<()>::Ok(());
while result.is_ok() {
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
tokio::select! {
_ = timeout.as_mut() =>{
let message = webrtc::peer_connection::math_rand_alpha(15);
result = stream.send_bytes(Bytes::from(message.clone())).await;
println!("Sent '{message}' {}", result.is_ok());
}
};
}
println!("WebRTC stream write failed; Exit the write_loop");
Ok(())
}
+1 -110
View File
@@ -79,7 +79,6 @@ message LoginRequest {
FileTransfer file_transfer = 7;
PortForward port_forward = 8;
ViewCamera view_camera = 15;
Terminal terminal = 16;
}
bool video_ack_required = 9;
uint64 session_id = 10;
@@ -87,11 +86,6 @@ message LoginRequest {
OSLogin os_login = 12;
string my_platform = 13;
bytes hwid = 14;
string avatar = 17;
}
message Terminal {
string service_id = 1; // Service ID for reconnecting to existing session
}
message Auth2FA {
@@ -103,7 +97,6 @@ message ChatMessage { string text = 1; }
message Features {
bool privacy_mode = 1;
bool terminal = 2;
}
message CodecAbility {
@@ -438,11 +431,6 @@ message FileTransferDigest {
uint64 file_size = 4;
bool is_upload = 5;
bool is_identical = 6;
uint64 transferred_size = 7; // For resume. Indicates the size of the file already transferred
bool is_resume = 8; // For resume. Indicates if the transfer is a resume.
// `is_resume` can let the controlled side know whether to check the `.digest` file.
// When `is_resume` is false, `.digest` exists, the same file does not exist,
// the controlled side should not check `.digest`, it should confirm with a new transfer request.
}
message FileTransferBlock {
@@ -464,12 +452,6 @@ message FileTransferSendRequest {
string path = 2;
bool include_hidden = 3;
int32 file_num = 4;
enum FileType {
Generic = 0;
Printer = 1;
}
FileType file_type = 5;
}
message FileTransferSendConfirmRequest {
@@ -562,16 +544,6 @@ message CliprdrFileContentsResponse {
message CliprdrTryEmpty {
}
// Clipobard file message for audit.
message CliprdrFile {
string name = 1;
uint64 size = 2;
}
message CliprdrFiles {
repeated CliprdrFile files = 1;
}
message Cliprdr {
oneof union {
CliprdrMonitorReady ready = 1;
@@ -582,7 +554,6 @@ message Cliprdr {
CliprdrFileContentsRequest file_contents_request = 6;
CliprdrFileContentsResponse file_contents_response = 7;
CliprdrTryEmpty try_empty = 8;
CliprdrFiles files = 9;
}
}
@@ -635,7 +606,7 @@ message PermissionInfo {
Restart = 5;
Recording = 6;
BlockInput = 7;
PrivacyMode = 8;
Camera = 8;
}
Permission permission = 1;
@@ -694,8 +665,6 @@ message OptionMessage {
BoolOption follow_remote_cursor = 15;
BoolOption follow_remote_window = 16;
BoolOption disable_camera = 17;
BoolOption terminal_persistent = 18;
BoolOption show_my_cursor = 19;
}
message TestDelay {
@@ -874,80 +843,6 @@ message VoiceCallResponse {
int64 ack_timestamp = 3;
}
message ScreenshotRequest {
int32 display = 1;
// sid is the session id on the controlling side
// It is used to forward the message to the correct remote (session) window.
string sid = 2;
}
message ScreenshotResponse {
string sid = 1;
// empty if success
string msg = 2;
bytes data = 3;
}
// Terminal messages - standalone feature like FileAction
message OpenTerminal {
int32 terminal_id = 1; // 0 for default terminal
uint32 rows = 2;
uint32 cols = 3;
}
message ResizeTerminal {
int32 terminal_id = 1;
uint32 rows = 2;
uint32 cols = 3;
}
message TerminalData {
int32 terminal_id = 1;
bytes data = 2;
bool compressed = 3;
}
message CloseTerminal {
int32 terminal_id = 1;
}
message TerminalAction {
oneof union {
OpenTerminal open = 1;
TerminalData data = 2;
ResizeTerminal resize = 3;
CloseTerminal close = 4;
}
}
message TerminalOpened {
int32 terminal_id = 1;
bool success = 2;
string message = 3;
uint32 pid = 4;
string service_id = 5; // Service ID for persistent sessions
repeated int32 persistent_sessions = 6; // Used to restore the persistent sessions.
}
message TerminalClosed {
int32 terminal_id = 1;
int32 exit_code = 2;
}
message TerminalError {
int32 terminal_id = 1;
string message = 2;
}
message TerminalResponse {
oneof union {
TerminalOpened opened = 1;
TerminalData data = 2;
TerminalClosed closed = 3;
TerminalError error = 4;
}
}
message Message {
oneof union {
SignedId signed_id = 3;
@@ -976,9 +871,5 @@ message Message {
PointerDeviceEvent pointer_device_event = 26;
Auth2FA auth_2fa = 27;
MultiClipboards multi_clipboards = 28;
ScreenshotRequest screenshot_request = 29;
ScreenshotResponse screenshot_response= 30;
TerminalAction terminal_action = 31;
TerminalResponse terminal_response = 32;
}
}
+7 -44
View File
@@ -12,7 +12,6 @@ enum ConnType {
PORT_FORWARD = 2;
RDP = 3;
VIEW_CAMERA = 4;
TERMINAL = 5;
}
message RegisterPeerResponse { bool request_pk = 2; }
@@ -24,40 +23,12 @@ message PunchHoleRequest {
ConnType conn_type = 4;
string token = 5;
string version = 6;
int32 udp_port = 7;
bool force_relay = 8;
int32 upnp_port = 9;
bytes socket_addr_v6 = 10;
}
message ControlPermissions {
enum Permission {
keyboard = 0;
remote_printer = 1;
clipboard = 2;
file = 3;
audio = 4;
camera = 5;
terminal = 6;
tunnel = 7;
restart = 8;
recording = 9;
block_input = 10;
remote_modify = 11;
privacy_mode = 12;
}
uint64 permissions = 1;
}
message PunchHole {
message PunchHole {
bytes socket_addr = 1;
string relay_server = 2;
NatType nat_type = 3;
int32 udp_port = 4;
bool force_relay = 5;
int32 upnp_port = 6;
bytes socket_addr_v6 = 7;
ControlPermissions control_permissions = 8;
}
message TestNatRequest {
@@ -82,8 +53,6 @@ message PunchHoleSent {
string relay_server = 3;
NatType nat_type = 4;
string version = 5;
int32 upnp_port = 6;
bytes socket_addr_v6 = 7;
}
message RegisterPk {
@@ -91,7 +60,6 @@ message RegisterPk {
bytes uuid = 2;
bytes pk = 3;
string old_id = 4;
bool no_register_device = 5;
}
message RegisterPkResponse {
@@ -125,9 +93,6 @@ message PunchHoleResponse {
}
string other_failure = 7;
int32 feedback = 8;
bool is_udp = 9;
int32 upnp_port = 10;
bytes socket_addr_v6 = 11;
}
message ConfigUpdate {
@@ -144,7 +109,6 @@ message RequestRelay {
string licence_key = 6;
ConnType conn_type = 7;
string token = 8;
ControlPermissions control_permissions = 9;
}
message RelayResponse {
@@ -158,8 +122,6 @@ message RelayResponse {
string refuse_reason = 6;
string version = 7;
int32 feedback = 9;
bytes socket_addr_v6 = 10;
int32 upnp_port = 11;
}
message SoftwareUpdate { string url = 1; }
@@ -168,11 +130,9 @@ message SoftwareUpdate { string url = 1; }
// even some router has below connection error if we connect itself,
// { kind: Other, error: "could not resolve to any address" },
// so we request local address to connect.
message FetchLocalAddr {
bytes socket_addr = 1;
message FetchLocalAddr {
bytes socket_addr = 1;
string relay_server = 2;
bytes socket_addr_v6 = 3;
ControlPermissions control_permissions = 4;
}
message LocalAddr {
@@ -181,7 +141,6 @@ message LocalAddr {
string relay_server = 3;
string id = 4;
string version = 5;
bytes socket_addr_v6 = 6;
}
message PeerDiscovery {
@@ -211,6 +170,10 @@ message HealthCheck {
string token = 1;
}
// Backported from upstream rustdesk/hbb_common @ 87b11a7 so the OSS hbbs
// can implement the HTTP-over-rendezvous fallback the client uses when
// OPTION_USE_RAW_TCP_FOR_API=Y. Wire-compatible with the client; only the
// three message types and tags 27/28 are added.
message HeaderEntry {
string name = 1;
string value = 2;
+35 -743
View File
File diff suppressed because it is too large Load Diff
+106 -959
View File
File diff suppressed because it is too large Load Diff
+5 -89
View File
@@ -57,23 +57,8 @@ pub use toml;
pub use uuid;
pub mod fingerprint;
pub use flexi_logger;
pub mod stream;
pub mod websocket;
#[cfg(feature = "webrtc")]
pub mod webrtc;
#[cfg(any(target_os = "android", target_os = "ios"))]
pub use rustls_platform_verifier;
pub use stream::Stream;
pub use whoami;
pub mod tls;
pub mod verifier;
pub use async_recursion;
#[cfg(target_os = "linux")]
pub use users;
pub use libloading;
#[cfg(target_os = "linux")]
pub use x11;
pub type Stream = tcp::FramedStream;
pub type SessionID = uuid::Uuid;
#[inline]
@@ -312,65 +297,10 @@ pub fn get_exe_time() -> SystemTime {
})
}
/// Known cases where machine_uid::get() may fail:
/// - Windows shutdown: "The media is write protected. (os error 19)"
/// - macOS (hard to reproduce, reproduced at login screen): "No matching IOPlatformUUID in `ioreg -rd1 -c IOPlatformExpertDevice` command"
pub fn get_uuid() -> Vec<u8> {
#[cfg(not(any(target_os = "android", target_os = "ios")))]
{
use std::sync::atomic::{AtomicUsize, Ordering};
static CACHED_MACHINE_UID: std::sync::OnceLock<Vec<u8>> = std::sync::OnceLock::new();
// Throttle only applies to the fallback machine_uid::get() log below, not the Once::call_once retry logs.
static LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
// Only macOS needs retry logic here because:
// - macOS: in testing, only one failure occurred when reading at 50ms intervals, so retry helps
// - Windows: failures during shutdown are persistent, retrying is pointless
#[cfg(target_os = "macos")]
{
static INIT: std::sync::Once = std::sync::Once::new();
INIT.call_once(|| {
// Keep in sync with upstream handling:
// https://github.com/rustdesk/rustdesk/blob/85db6779828349b23ca3eba91cc7cd36c5337797/src/common.rs#L822
let username = whoami::username().trim_end_matches('\0').to_owned();
let max_retries = if username == "root" { 16 } else { 8 };
for i in 0..max_retries {
match machine_uid::get() {
Ok(id) => {
let _ = CACHED_MACHINE_UID.set(id.into());
return;
}
Err(e) => {
log::error!("Failed to get machine uid in macOS retry #{i}: {e}");
}
}
std::thread::sleep(std::time::Duration::from_millis(50));
}
});
}
if let Some(uid) = CACHED_MACHINE_UID.get() {
return uid.clone();
}
match machine_uid::get() {
Ok(id) => {
let uid: Vec<u8> = id.into();
let _ = CACHED_MACHINE_UID.set(uid.clone());
return uid;
}
Err(e) => {
if LOG_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
(count < 30).then_some(count + 1)
})
.is_ok()
{
log::error!("Failed to get machine uid: {e}");
}
}
}
if let Ok(id) = machine_uid::get() {
return id.into();
}
Config::get_key_pair().1
}
@@ -432,7 +362,7 @@ pub fn init_log(_is_async: bool, _name: &str) -> Option<flexi_logger::LoggerHand
#[cfg(debug_assertions)]
{
use env_logger::*;
init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info,reqwest=warn,rustls=warn,webrtc-sctp=warn,webrtc=warn"));
init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info"));
}
#[cfg(not(debug_assertions))]
{
@@ -447,7 +377,7 @@ pub fn init_log(_is_async: bool, _name: &str) -> Option<flexi_logger::LoggerHand
path.push(_name);
}
use flexi_logger::*;
if let Ok(x) = Logger::try_with_env_or_str("debug,reqwest=warn,rustls=warn,webrtc-sctp=warn,webrtc=warn") {
if let Ok(x) = Logger::try_with_env_or_str("debug") {
logger_holder = x
.log_to_file(FileSpec::default().directory(path))
.write_mode(if _is_async {
@@ -514,20 +444,6 @@ pub fn version_check_request(typ: String) -> (VersionCheckRequest, String) {
)
}
pub fn time_based_rand() -> u32 {
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let mut x = nanos as u64;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
(x % 32768) as u32
}
#[cfg(test)]
mod test {
use super::*;
+9 -188
View File
@@ -3,7 +3,7 @@ use sodiumoxide::base64;
use std::sync::{Arc, RwLock};
lazy_static::lazy_static! {
pub static ref TEMPORARY_PASSWORD:Arc<RwLock<String>> = Arc::new(RwLock::new(get_auto_password()));
pub static ref TEMPORARY_PASSWORD:Arc<RwLock<String>> = Arc::new(RwLock::new(Config::get_auto_password(temporary_password_length())));
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -20,18 +20,9 @@ pub enum ApproveMode {
Click,
}
fn get_auto_password() -> String {
let len = temporary_password_length();
if Config::get_bool_option(crate::config::keys::OPTION_ALLOW_NUMERNIC_ONE_TIME_PASSWORD) {
Config::get_auto_numeric_password(len)
} else {
Config::get_auto_password(len)
}
}
// Should only be called in server
pub fn update_temporary_password() {
*TEMPORARY_PASSWORD.write().unwrap() = get_auto_password();
*TEMPORARY_PASSWORD.write().unwrap() = Config::get_auto_password(temporary_password_length());
}
// Should only be called in server
@@ -71,7 +62,7 @@ pub fn permanent_enabled() -> bool {
pub fn has_valid_password() -> bool {
temporary_enabled() && !temporary_password().is_empty()
|| permanent_enabled() && Config::has_permanent_password()
|| permanent_enabled() && !Config::get_permanent_password().is_empty()
}
pub fn approve_mode() -> ApproveMode {
@@ -93,27 +84,8 @@ pub fn hide_cm() -> bool {
const VERSION_LEN: usize = 2;
// Check if data is already encrypted by verifying:
// 1) version prefix "00"
// 2) valid base64 payload
// 3) decoded payload length >= secretbox::MACBYTES
//
// We intentionally avoid trying to decrypt here because key mismatch would cause
// false negatives.
// Reference: secretbox::seal returns ciphertext length = plaintext length + MACBYTES
// https://github.com/sodiumoxide/sodiumoxide/blob/3057acb1a030ad86ed8892a223d64036ab5e8523/src/crypto/secretbox/xsalsa20poly1305.rs#L67
fn is_encrypted(v: &[u8]) -> bool {
if v.len() <= VERSION_LEN || !v.starts_with(b"00") {
return false;
}
match base64::decode(&v[VERSION_LEN..], base64::Variant::Original) {
Ok(decoded) => decoded.len() >= sodiumoxide::crypto::secretbox::MACBYTES,
Err(_) => false,
}
}
pub fn encrypt_str_or_original(s: &str, version: &str, max_len: usize) -> String {
if is_encrypted(s.as_bytes()) {
if decrypt_str_or_original(s, version).1 {
log::error!("Duplicate encryption!");
return s.to_owned();
}
@@ -146,17 +118,11 @@ pub fn decrypt_str_or_original(s: &str, current_version: &str) -> (String, bool,
}
}
// For values that already look encrypted (version prefix + base64), avoid
// repeated store on each load when decryption fails.
(
s.to_owned(),
false,
!s.is_empty() && !is_encrypted(s.as_bytes()),
)
(s.to_owned(), false, !s.is_empty())
}
pub fn encrypt_vec_or_original(v: &[u8], version: &str, max_len: usize) -> Vec<u8> {
if is_encrypted(v) {
if decrypt_vec_or_original(v, version).1 {
log::error!("Duplicate encryption!");
return v.to_owned();
}
@@ -186,9 +152,7 @@ pub fn decrypt_vec_or_original(v: &[u8], current_version: &str) -> (Vec<u8>, boo
}
}
// For values that already look encrypted (version prefix + base64), avoid
// repeated store on each load when decryption fails.
(v.to_owned(), false, !v.is_empty() && !is_encrypted(v))
(v.to_owned(), false, !v.is_empty())
}
fn encrypt(v: &[u8]) -> Result<String, ()> {
@@ -211,8 +175,7 @@ pub fn symmetric_crypt(data: &[u8], encrypt: bool) -> Result<Vec<u8>, ()> {
use sodiumoxide::crypto::secretbox;
use std::convert::TryInto;
let uuid = crate::get_uuid();
let mut keybuf = uuid.clone();
let mut keybuf = crate::get_uuid();
keybuf.resize(secretbox::KEYBYTES, 0);
let key = secretbox::Key(keybuf.try_into().map_err(|_| ())?);
let nonce = secretbox::Nonce([0; secretbox::NONCEBYTES]);
@@ -220,21 +183,7 @@ pub fn symmetric_crypt(data: &[u8], encrypt: bool) -> Result<Vec<u8>, ()> {
if encrypt {
Ok(secretbox::seal(data, &nonce, &key))
} else {
let res = secretbox::open(data, &nonce, &key);
#[cfg(not(any(target_os = "android", target_os = "ios")))]
if res.is_err() {
// Fallback: try pk if uuid decryption failed (in case encryption used pk due to machine_uid failure)
if let Some(key_pair) = Config::get_existing_key_pair() {
let pk = key_pair.1;
if pk != uuid {
let mut keybuf = pk;
keybuf.resize(secretbox::KEYBYTES, 0);
let pk_key = secretbox::Key(keybuf.try_into().map_err(|_| ())?);
return secretbox::open(data, &nonce, &pk_key);
}
}
}
res
secretbox::open(data, &nonce, &key)
}
}
@@ -310,33 +259,6 @@ mod test {
let data: Vec<u8> = "1ü1111".as_bytes().to_vec();
assert_eq!(decrypt_vec_or_original(&data, version).0, data);
// Base64-shaped "00" prefixed values shorter than MACBYTES are treated
// as original/plain values and should be stored.
let data = "00YWJjZA==";
let (decrypted, succ, store) = decrypt_str_or_original(data, version);
assert_eq!(decrypted, data);
assert!(!succ);
assert!(store);
let data = b"00YWJjZA==".to_vec();
let (decrypted, succ, store) = decrypt_vec_or_original(&data, version);
assert_eq!(decrypted, data);
assert!(!succ);
assert!(store);
// When decoded length reaches MACBYTES, it is treated as encrypted-like
// and should not trigger repeated store.
let exact_mac = vec![0u8; sodiumoxide::crypto::secretbox::MACBYTES];
let exact_mac_b64 =
sodiumoxide::base64::encode(&exact_mac, sodiumoxide::base64::Variant::Original);
let data = format!("00{exact_mac_b64}");
let (_, succ, store) = decrypt_str_or_original(&data, version);
assert!(!succ);
assert!(!store);
let data = data.into_bytes();
let (_, succ, store) = decrypt_vec_or_original(&data, version);
assert!(!succ);
assert!(!store);
println!("test speed");
let test_speed = |len: usize, name: &str| {
let mut data: Vec<u8> = vec![];
@@ -370,105 +292,4 @@ mod test {
test_speed(10 * 1024 * 1024, "10M");
test_speed(100 * 1024 * 1024, "100M");
}
#[test]
fn test_is_encrypted() {
use super::*;
use sodiumoxide::base64::{encode, Variant};
use sodiumoxide::crypto::secretbox;
// Empty data should not be considered encrypted
assert!(!is_encrypted(b""));
assert!(!is_encrypted(b"0"));
assert!(!is_encrypted(b"00"));
// Data without "00" prefix should not be considered encrypted
assert!(!is_encrypted(b"01abcd"));
assert!(!is_encrypted(b"99abcd"));
assert!(!is_encrypted(b"hello world"));
// Data with "00" prefix but invalid base64 should not be considered encrypted
assert!(!is_encrypted(b"00!!!invalid base64!!!"));
assert!(!is_encrypted(b"00@#$%"));
// Data with "00" prefix and valid base64 but shorter than MACBYTES is not encrypted
assert!(!is_encrypted(b"00YWJjZA==")); // "abcd" in base64
assert!(!is_encrypted(b"00SGVsbG8gV29ybGQ=")); // "Hello World" in base64
// Data with "00" prefix and valid base64 with decoded len == MACBYTES is considered encrypted
let exact_mac = vec![0u8; secretbox::MACBYTES];
let exact_mac_b64 = encode(&exact_mac, Variant::Original);
let exact_mac_candidate = format!("00{exact_mac_b64}");
assert!(is_encrypted(exact_mac_candidate.as_bytes()));
// Real encrypted data should be detected
let version = "00";
let max_len = 128;
let encrypted_str = encrypt_str_or_original("1", version, max_len);
assert!(is_encrypted(encrypted_str.as_bytes()));
let encrypted_vec = encrypt_vec_or_original(b"1", version, max_len);
assert!(is_encrypted(&encrypted_vec));
// Original unencrypted data should not be detected as encrypted
assert!(!is_encrypted(b"1"));
assert!(!is_encrypted("1".as_bytes()));
}
#[test]
fn test_encrypted_payload_min_len_macbytes() {
use super::*;
use sodiumoxide::base64::{decode, Variant};
use sodiumoxide::crypto::secretbox;
let version = "00";
let max_len = 128;
let encrypted_str = encrypt_str_or_original("1", version, max_len);
let decoded = decode(&encrypted_str.as_bytes()[VERSION_LEN..], Variant::Original).unwrap();
assert!(
decoded.len() >= secretbox::MACBYTES,
"decoded encrypted payload must be at least MACBYTES"
);
let encrypted_vec = encrypt_vec_or_original(b"1", version, max_len);
let decoded = decode(&encrypted_vec[VERSION_LEN..], Variant::Original).unwrap();
assert!(
decoded.len() >= secretbox::MACBYTES,
"decoded encrypted payload must be at least MACBYTES"
);
}
// Test decryption fallback when data was encrypted with key_pair but decryption tries machine_uid first
#[test]
#[cfg(not(any(target_os = "android", target_os = "ios")))]
fn test_decrypt_with_pk_fallback() {
use sodiumoxide::crypto::secretbox;
use std::convert::TryInto;
let uuid = crate::get_uuid();
let pk = crate::config::Config::get_key_pair().1;
// Ensure uuid != pk, otherwise fallback branch won't be tested
if uuid == pk {
eprintln!("skip: uuid == pk, fallback branch won't be tested");
return;
}
let data = b"test password 123";
let nonce = secretbox::Nonce([0; secretbox::NONCEBYTES]);
// Encrypt with pk (simulating machine_uid failure during encryption)
let mut pk_keybuf = pk;
pk_keybuf.resize(secretbox::KEYBYTES, 0);
let pk_key = secretbox::Key(pk_keybuf.try_into().unwrap());
let encrypted = secretbox::seal(data, &nonce, &pk_key);
// Decrypt using symmetric_crypt (should fallback to pk since uuid differs)
let decrypted = super::symmetric_crypt(&encrypted, false);
assert!(
decrypted.is_ok(),
"Decryption with pk fallback should succeed"
);
assert_eq!(decrypted.unwrap(), data);
}
}
+8 -280
View File
@@ -1,51 +1,10 @@
use crate::ResultType;
use std::{
collections::HashMap,
path::{Path, PathBuf},
process::Command,
};
use users::{get_current_uid, get_user_by_uid, os::unix::UserExt};
use sctk::{
output::OutputData,
output::{OutputHandler, OutputState},
reexports::client::protocol::wl_output::WlOutput,
reexports::client::{globals, Proxy},
reexports::client::{Connection, QueueHandle},
registry::{ProvidesRegistryState, RegistryState},
};
use std::{collections::HashMap, process::Command};
lazy_static::lazy_static! {
pub static ref DISTRO: Distro = Distro::new();
}
// to-do: There seems to be some runtime issue that causes the audit logs to be generated.
// We may need to fix this and remove this workaround in the future.
//
// We use the pre-search method to find the command path to avoid the audit logs on some systems.
// No idea why the audit logs happen.
// Though the audit logs may disappear after rebooting.
//
// See https://github.com/rustdesk/rustdesk/discussions/11959
//
// `ausearch -x /usr/share/rustdesk/rustdesk` will return
// ...
// time->Tue Jun 24 10:40:43 2025
// type=PROCTITLE msg=audit(1750776043.446:192757): proctitle=2F7573722F62696E2F727573746465736B002D2D73657276696365
// type=PATH msg=audit(1750776043.446:192757): item=0 name="/usr/local/bin/sh" nametype=UNKNOWN cap_fp=0 cap_fi=0 cap_fe=0 cap_fver=0 cap_frootid=0
// type=CWD msg=audit(1750776043.446:192757): cwd="/"
// type=SYSCALL msg=audit(1750776043.446:192757): arch=c000003e syscall=59 success=no exit=-2 a0=7fb7dbd22da0 a1=1d65f2c0 a2=7ffc25193360 a3=7ffc25194ec0 items=1 ppid=172208 pid=267565 auid=4294967295 uid=0 gid=0 euid=0 suid=0 fsuid=0 egid=0 sgid=0 fsgid=0 tty=(none) ses=4294967295 comm="rustdesk" exe="/usr/share/rustdesk/rustdesk" subj=unconfined key="processos_criados"
// ----
// time->Tue Jun 24 10:40:43 2025
// type=PROCTITLE msg=audit(1750776043.446:192758): proctitle=2F7573722F62696E2F727573746465736B002D2D73657276696365
// type=PATH msg=audit(1750776043.446:192758): item=0 name="/usr/sbin/sh" nametype=UNKNOWN cap_fp=0 cap_fi=0 cap_fe=0 cap_fver=0 cap_frootid=0
// ...
lazy_static::lazy_static! {
pub static ref CMD_LOGINCTL: String = find_cmd_path("loginctl");
pub static ref CMD_PS: String = find_cmd_path("ps");
pub static ref CMD_SH: String = find_cmd_path("sh");
}
pub const DISPLAY_SERVER_WAYLAND: &str = "wayland";
pub const DISPLAY_SERVER_X11: &str = "x11";
pub const DISPLAY_DESKTOP_KDE: &str = "KDE";
@@ -73,25 +32,6 @@ impl Distro {
}
}
fn find_cmd_path(cmd: &'static str) -> String {
let test_cmd = format!("/bin/{}", cmd);
if std::path::Path::new(&test_cmd).exists() {
return test_cmd;
}
let test_cmd = format!("/usr/bin/{}", cmd);
if std::path::Path::new(&test_cmd).exists() {
return test_cmd;
}
if let Ok(output) = Command::new("which").arg(cmd).output() {
if output.status.success() {
return String::from_utf8_lossy(&output.stdout).trim().to_string();
}
}
cmd.to_string()
}
// Deprecated. Use `hbb_common::platform::linux::is_kde_session()` instead for now.
// Or we need to set the correct environment variable in the server process.
#[inline]
pub fn is_kde() -> bool {
if let Ok(env) = std::env::var(XDG_CURRENT_DESKTOP) {
@@ -101,21 +41,9 @@ pub fn is_kde() -> bool {
}
}
// Don't use `hbb_common::platform::linux::is_kde()` here.
// It's not correct in the server process.
pub fn is_kde_session() -> bool {
std::process::Command::new(CMD_SH.as_str())
.arg("-c")
.arg("pgrep -f kded[0-9]+")
.stdout(std::process::Stdio::piped())
.output()
.map(|o| !o.stdout.is_empty())
.unwrap_or(false)
}
#[inline]
pub fn is_gdm_user(username: &str) -> bool {
username == "gdm" || username == "sddm"
username == "gdm"
// || username == "lightgdm"
}
@@ -176,7 +104,7 @@ pub fn get_display_server_of_session(session: &str) -> String {
} else {
"".to_owned()
};
if display_server.is_empty() || display_server == "tty" || display_server == "unspecified" {
if display_server.is_empty() || display_server == "tty" {
if let Ok(sestype) = std::env::var("XDG_SESSION_TYPE") {
if !sestype.is_empty() {
return sestype.to_lowercase();
@@ -247,7 +175,7 @@ fn _get_values_of_seat0(indices: &[usize], ignore_gdm_wayland: bool) -> Vec<Stri
continue;
}
}
if d == "tty" || d == "unspecified" {
if d == "tty" {
continue;
}
return line_values(indices, line);
@@ -276,26 +204,17 @@ pub fn is_active_and_seat0(sid: &str) -> bool {
}
}
// Check both "Lock" and "Switch user"
pub fn is_session_locked(sid: &str) -> bool {
if let Ok(output) = run_loginctl(Some(vec!["show-session", sid, "--property=LockedHint"])) {
String::from_utf8_lossy(&output.stdout).contains("LockedHint=yes")
} else {
false
}
}
// **Note** that the return value here, the last character is '\n'.
// Use `run_cmds_trim_newline()` if you want to remove '\n' at the end.
pub fn run_cmds(cmds: &str) -> ResultType<String> {
let output = std::process::Command::new(CMD_SH.as_str())
let output = std::process::Command::new("sh")
.args(vec!["-c", cmds])
.output()?;
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}
pub fn run_cmds_trim_newline(cmds: &str) -> ResultType<String> {
let output = std::process::Command::new(CMD_SH.as_str())
let output = std::process::Command::new("sh")
.args(vec!["-c", cmds])
.output()?;
let out = String::from_utf8_lossy(&output.stdout);
@@ -308,7 +227,7 @@ pub fn run_cmds_trim_newline(cmds: &str) -> ResultType<String> {
fn run_loginctl(args: Option<Vec<&str>>) -> std::io::Result<std::process::Output> {
if std::env::var("FLATPAK_ID").is_ok() {
let mut l_args = CMD_LOGINCTL.to_string();
let mut l_args = String::from("loginctl");
if let Some(a) = args.as_ref() {
l_args = format!("{} {}", l_args, a.join(" "));
}
@@ -319,7 +238,7 @@ fn run_loginctl(args: Option<Vec<&str>>) -> std::io::Result<std::process::Output
return res;
}
}
let mut cmd = std::process::Command::new(CMD_LOGINCTL.as_str());
let mut cmd = std::process::Command::new("loginctl");
if let Some(a) = args {
return cmd.args(a).output();
}
@@ -365,138 +284,6 @@ pub fn system_message(title: &str, msg: &str, forever: bool) -> ResultType<()> {
crate::bail!("failed to post system message");
}
#[derive(Debug, Clone)]
pub struct WaylandDisplayInfo {
pub name: String,
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub logical_size: Option<(i32, i32)>,
pub refresh_rate: i32,
}
// Retrieves information about all connected displays via the Wayland protocol.
pub fn get_wayland_displays() -> ResultType<Vec<WaylandDisplayInfo>> {
struct WaylandEnv {
registry_state: RegistryState,
output_state: OutputState,
}
impl OutputHandler for WaylandEnv {
fn output_state(&mut self) -> &mut OutputState {
&mut self.output_state
}
fn new_output(&mut self, _: &Connection, _: &QueueHandle<Self>, _: WlOutput) {}
fn update_output(&mut self, _: &Connection, _: &QueueHandle<Self>, _: WlOutput) {}
fn output_destroyed(&mut self, _: &Connection, _: &QueueHandle<Self>, _: WlOutput) {}
}
impl ProvidesRegistryState for WaylandEnv {
fn registry(&mut self) -> &mut RegistryState {
&mut self.registry_state
}
sctk::registry_handlers!();
}
sctk::delegate_output!(WaylandEnv);
sctk::delegate_registry!(WaylandEnv);
let conn = Connection::connect_to_env()?;
let (globals, mut event_queue) = globals::registry_queue_init(&conn)?;
let queue_handle = event_queue.handle();
let registry_state = RegistryState::new(&globals);
let output_state = OutputState::new(&globals, &queue_handle);
let mut environment = WaylandEnv {
registry_state,
output_state,
};
event_queue.roundtrip(&mut environment)?;
let outputs: Vec<_> = environment.output_state.outputs().collect();
let mut display_infos = Vec::new();
for output in outputs {
if let Some(output_data) = output.data::<OutputData>() {
output_data.with_output_info(|info| {
if let Some(mode) = info.modes.iter().find(|m| m.current) {
let (x, y) = info.location;
let (width, height) = mode.dimensions;
let refresh_rate = mode.refresh_rate;
let name = info.name.clone().unwrap_or_default();
let logical_size = info.logical_size;
display_infos.push(WaylandDisplayInfo {
name,
x,
y,
width,
height,
logical_size,
refresh_rate,
});
}
});
}
}
Ok(display_infos)
}
/// Escape a string for safe use in shell commands by wrapping in single quotes.
///
/// This function handles the edge case of single quotes within the string by:
/// 1. Ending the current single-quoted section
/// 2. Adding an escaped single quote
/// 3. Starting a new single-quoted section
///
/// Example: "it's here" -> "'it'\''s here'"
#[inline]
pub fn shell_quote(s: &str) -> String {
format!("'{}'", s.replace("'", "'\\''"))
}
/// Get the current user's home directory via getpwuid (trusted source).
///
/// This function uses the system's password database (via `getpwuid`) to retrieve
/// the home directory, avoiding the security risk of relying on the `HOME`
/// environment variable which can be manipulated by untrusted input.
///
/// # Returns
/// - `Some(PathBuf)` if the home directory was found and exists
/// - `None` if the user lookup failed or the directory doesn't exist
///
/// # Security
/// This function is designed to be safe against confused-deputy attacks where
/// an attacker might manipulate environment variables to influence privileged
/// operations.
pub fn get_home_dir_trusted() -> Option<PathBuf> {
let uid = get_current_uid();
match get_user_by_uid(uid) {
Some(user) => {
let home = user.home_dir();
if Path::is_dir(home) {
Some(PathBuf::from(home))
} else {
log::warn!(
"Home directory for uid {} does not exist or is not a directory: {:?}",
uid,
home
);
None
}
}
None => {
log::warn!("Failed to get user info for uid {}", uid);
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -510,63 +297,4 @@ mod tests {
run_cmds("whoami").unwrap()
);
}
/// Test get_home_dir_trusted: returns valid path and ignores HOME env var
#[test]
fn test_get_home_dir_trusted() {
let original_home = std::env::var("HOME").ok();
// Set HOME to a fake/malicious path
std::env::set_var("HOME", "/tmp/fake_malicious_home");
let result = get_home_dir_trusted();
// Restore original HOME
match original_home {
Some(home) => std::env::set_var("HOME", home),
None => std::env::remove_var("HOME"),
}
// Verify: returns valid path that is NOT the fake HOME
if let Some(path) = result {
assert!(path.is_absolute(), "Path should be absolute: {:?}", path);
assert!(path.is_dir(), "Path should be a directory: {:?}", path);
assert_ne!(
path.to_string_lossy(),
"/tmp/fake_malicious_home",
"Should not use HOME env var"
);
}
}
/// Test shell_quote with normal strings
#[test]
fn test_shell_quote_normal() {
assert_eq!(shell_quote("hello"), "'hello'");
assert_eq!(shell_quote("/home/user"), "'/home/user'");
}
/// Test shell_quote with spaces
#[test]
fn test_shell_quote_spaces() {
assert_eq!(shell_quote("/home/my user/file"), "'/home/my user/file'");
assert_eq!(shell_quote("path with spaces"), "'path with spaces'");
}
/// Test shell_quote with single quotes (the tricky case)
#[test]
fn test_shell_quote_single_quotes() {
assert_eq!(shell_quote("it's"), "'it'\\''s'");
assert_eq!(shell_quote("don't stop"), "'don'\\''t stop'");
}
/// Test shell_quote with shell metacharacters
#[test]
fn test_shell_quote_metacharacters() {
// These should all be safely quoted
assert_eq!(shell_quote("test;rm -rf /"), "'test;rm -rf /'");
assert_eq!(shell_quote("$(whoami)"), "'$(whoami)'");
assert_eq!(shell_quote("`id`"), "'`id`'");
assert_eq!(shell_quote("a && b"), "'a && b'");
assert_eq!(shell_quote("a | b"), "'a | b'");
}
}
-1
View File
@@ -62,7 +62,6 @@ extern "C" fn breakdown_signal_handler(sig: i32) {
.ok();
}
unsafe {
#[allow(static_mut_refs)]
if let Some(callback) = &GLOBAL_CALLBACK {
callback()
}
+54 -209
View File
@@ -3,15 +3,16 @@ use std::{
net::{SocketAddr, ToSocketAddrs},
};
use anyhow::bail;
use async_recursion::async_recursion;
use base64::{engine::general_purpose, Engine};
use httparse::{Error as HttpParseError, Response, EMPTY_HEADER};
use log::info;
use thiserror::Error as ThisError;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
#[cfg(any(target_os = "windows", target_os = "macos"))]
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
use tokio_rustls::{client::TlsStream as RustlsTlsStream, TlsConnector as RustlsTlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, TargetAddr};
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
use tokio_rustls::{client::TlsStream, TlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr};
use tokio_util::codec::Framed;
use url::Url;
@@ -19,7 +20,6 @@ use crate::{
bytes_codec::BytesCodec,
config::Socks5Server,
tcp::{DynTcpStream, FramedStream},
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
@@ -45,6 +45,7 @@ pub enum ProxyError {
HttpCode200(u16),
#[error("The proxy address resolution failed: {0}")]
AddressResolutionFailed(String),
#[cfg(any(target_os = "windows", target_os = "macos"))]
#[error("The native tls error: {0}")]
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
}
@@ -55,6 +56,7 @@ const MAXIMUM_RESPONSE_HEADERS: usize = 16;
const DEFINE_TIME_OUT: u64 = 600;
pub trait IntoUrl {
// Besides parsing as a valid `Url`, the `Url` must be a valid
// `http::Uri`, in that it makes sense to use in a network request.
fn into_url(self) -> Result<Url, ProxyError>;
@@ -126,14 +128,6 @@ impl Auth {
let authorization = format!("{}:{}", &self.user_name, &self.password);
general_purpose::STANDARD.encode(authorization.as_bytes())
}
pub fn username(&self) -> &str {
&self.user_name
}
pub fn password(&self) -> &str {
&self.password
}
}
#[derive(Clone)]
@@ -225,7 +219,7 @@ impl ProxyScheme {
Ok(scheme)
}
pub async fn socket_addrs(&self) -> Result<SocketAddr, ProxyError> {
log::trace!("Resolving socket address");
info!("Resolving socket address");
match self {
ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await,
ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await,
@@ -355,50 +349,37 @@ impl Proxy {
self
}
async fn new_stream(
&self,
local: SocketAddr,
proxy: SocketAddr,
) -> ResultType<tokio::net::TcpStream> {
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
Ok(stream)
}
pub async fn connect<'t, T>(
&self,
self,
target: T,
local_addr: Option<SocketAddr>,
) -> ResultType<FramedStream>
where
T: IntoTargetAddr<'t>,
{
log::trace!("Connect to proxy server");
info!("Connect to proxy server");
let proxy = self.proxy_addrs().await?;
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
let local = if let Some(addr) = local_addr {
addr
} else {
crate::config::Config::get_any_listen_addr(proxy.is_ipv4())
};
let stream = self.new_stream(local, proxy).await?;
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
let addr = stream.local_addr()?;
return match self.intercept {
ProxyScheme::Http { .. } => {
log::trace!("Connect to remote http proxy server: {}", proxy);
info!("Connect to remote http proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.http_connect(stream, &target_addr))
.await??;
super::timeout(self.ms_timeout, self.http_connect(stream, target)).await??;
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
@@ -407,54 +388,24 @@ impl Proxy {
))
}
ProxyScheme::Https { .. } => {
log::trace!("Connect to remote https proxy server: {}", proxy);
let url = format!("https://{}", self.intercept.get_host_and_port()?);
let tls_type = get_cached_tls_type(&url);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
let stream = match tls_type.unwrap_or(TlsType::Rustls) {
TlsType::Rustls => {
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
Some(stream),
&target_addr,
tls_type.is_some(),
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await?
}
TlsType::NativeTls => {
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
&target_addr,
danger_accept_invalid_cert,
)
.await?
}
_ => {
// Unreachable
crate::bail!("Unreachable, TlsType::Plain in HTTPS proxy");
}
};
info!("Connect to remote https proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.https_connect(stream, target)).await??;
Ok(FramedStream(
Framed::new(stream, BytesCodec::new()),
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
None,
0,
))
}
ProxyScheme::Socks5 { .. } => {
log::trace!("Connect to remote socket5 proxy server: {}", proxy);
info!("Connect to remote socket5 proxy server: {}", proxy);
let stream = if let Some(auth) = self.intercept.maybe_auth() {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_password_and_socket(
stream,
target_addr,
target,
&auth.user_name,
&auth.password,
),
@@ -463,7 +414,7 @@ impl Proxy {
} else {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_socket(stream, target_addr),
Socks5Stream::connect_with_socket(stream, target),
)
.await??
};
@@ -477,166 +428,57 @@ impl Proxy {
};
}
async fn https_connect_nativetls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = self.new_stream(local, proxy).await?;
let s = super::timeout(
self.ms_timeout,
self.https_connect_nativetls(
stream,
&target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await??;
upsert_tls_cache(
url,
TlsType::NativeTls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
pub async fn https_connect_nativetls<'a, Input>(
&self,
#[cfg(any(target_os = "windows", target_os = "macos"))]
pub async fn https_connect<'a, Input, T>(
self,
io: Input,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
target: T,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
let mut tls_connector_builder = native_tls::TlsConnector::builder();
if danger_accept_invalid_cert {
tls_connector_builder.danger_accept_invalid_certs(true);
}
let tls_connector = TlsConnector::from(tls_connector_builder.build()?);
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
let stream = tls_connector
.connect(&self.intercept.get_domain()?, io)
.await?;
self.http_connect(stream, target_addr).await
self.http_connect(stream, target).await
}
#[async_recursion]
async fn https_connect_rustls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
stream: Option<tokio::net::TcpStream>,
target_addr: &TargetAddr<'a>,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
origin_danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = stream.unwrap_or(self.new_stream(local, proxy).await?);
match super::timeout(
self.ms_timeout,
self.https_connect_rustls(
stream,
target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await?
{
Ok(s) => {
upsert_tls_cache(
&url,
TlsType::Rustls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
Err(e) => {
// NOTE: Maybe it's better to check if the error is related to TLS here. (ProxyError::IoError(e), or ProxyError::NativeTlsError(e))
// But we can only get the error when the TLS protocol is TLSv1.1.
// The error message of the following is unclear:
// https://github.com/rustdesk/rustdesk-server-pro/issues/189#issuecomment-1895701480
// So we just try to fallback unconditionally here.
//
// If the protocol is TLS 1.1, the error is:
// 1. "IO Error: received fatal alert: ProtocolVersion"
// 2. "IO Error: An existing connection was forcibly closed by the remote host. (os error 10054)" on Windows sometimes.
//
// If the cert verification fails, the error is:
// "IO Error: invalid peer certificate: UnknownIssuer"
let s = if danger_accept_invalid_cert.is_none() {
log::warn!(
"Falling back to rustls-tls (accept invalid cert) for HTTPS proxy server."
);
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
None,
target_addr,
is_tls_type_cached,
Some(true),
origin_danger_accept_invalid_cert,
)
.await?
} else if !is_tls_type_cached {
log::warn!("Falling back to native-tls for HTTPS proxy server.");
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
&target_addr,
origin_danger_accept_invalid_cert,
)
.await?
} else {
log::error!(
"Failed to connect to HTTPS proxy server with native-tls: {:?}.",
e
);
bail!(e)
};
Ok(s)
}
}
}
pub async fn https_connect_rustls<'a, Input>(
&self,
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
pub async fn https_connect<'a, Input, T>(
self,
io: Input,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<RustlsTlsStream<Input>>, ProxyError>
target: T,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
use std::convert::TryFrom;
let verifier = rustls_platform_verifier::tls_config();
let url_domain = self.intercept.get_domain()?;
let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str())
.map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))?
.to_owned();
let client_config = crate::verifier::client_config(danger_accept_invalid_cert)
.map_err(|e| ProxyError::IoError(std::io::Error::other(e)))?;
let tls_connector = RustlsTlsConnector::from(std::sync::Arc::new(client_config));
let tls_connector = TlsConnector::from(std::sync::Arc::new(verifier));
let stream = tls_connector.connect(domain, io).await?;
self.http_connect(stream, target_addr).await
self.http_connect(stream, target).await
}
pub async fn http_connect<'a, Input>(
&self,
pub async fn http_connect<'a, Input, T>(
self,
io: Input,
target_addr: &TargetAddr<'a>,
target: T,
) -> Result<BufStream<Input>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
let mut stream = BufStream::new(io);
let (domain, port) = get_domain_and_port(target_addr)?;
let (domain, port) = get_domain_and_port(target)?;
let request = self.make_request(&domain, port);
stream.write_all(request.as_bytes()).await?;
@@ -661,10 +503,13 @@ impl Proxy {
}
}
fn get_domain_and_port<'a>(target_addr: &TargetAddr<'a>) -> Result<(String, u16), ProxyError> {
fn get_domain_and_port<'a, T: IntoTargetAddr<'a>>(target: T) -> Result<(String, u16), ProxyError> {
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
match target_addr {
tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), *port)),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), port)),
}
}
+12 -69
View File
@@ -1,15 +1,12 @@
#[cfg(feature = "webrtc")]
use crate::webrtc::{self, is_webrtc_endpoint};
use crate::{
config::{Config, NetworkType},
tcp::FramedStream,
udp::FramedSocket,
websocket::{self, check_ws, is_ws_endpoint},
ResultType, Stream,
ResultType,
};
use anyhow::Context;
use std::{net::SocketAddr, sync::Arc};
use tokio::net::{ToSocketAddrs, UdpSocket};
use std::net::SocketAddr;
use tokio::net::ToSocketAddrs;
use tokio_socks::{IntoTargetAddr, TargetAddr};
#[inline]
@@ -52,30 +49,6 @@ pub fn increase_port<T: std::string::ToString>(host: T, offset: i32) -> String {
host
}
pub fn split_host_port<T: std::string::ToString>(host: T) -> Option<(String, i32)> {
let host = host.to_string();
if crate::is_ipv6_str(&host) {
if host.starts_with('[') {
let tmp: Vec<&str> = host.split("]:").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return Some((format!("{}]", tmp[0]), port));
}
}
}
} else if host.contains(':') {
let tmp: Vec<&str> = host.split(':').collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return Some((tmp[0].to_string(), port));
}
}
}
None
}
pub fn test_if_valid_server(host: &str, test_with_proxy: bool) -> String {
let host = check_port(host, 0);
use std::net::ToSocketAddrs;
@@ -122,7 +95,6 @@ impl IsResolvedSocketAddr for &str {
}
}
// This function checks if the target is a websocket endpoint and connects accordingly.
#[inline]
pub async fn connect_tcp<
't,
@@ -130,23 +102,10 @@ pub async fn connect_tcp<
>(
target: T,
ms_timeout: u64,
) -> ResultType<crate::Stream> {
#[cfg(feature = "webrtc")]
if is_webrtc_endpoint(&target.to_string()) {
return Ok(Stream::WebRTC(
webrtc::WebRTCStream::new(&target.to_string(), false, ms_timeout).await?,
));
}
let target_str = check_ws(&target.to_string());
if is_ws_endpoint(&target_str) {
return Ok(Stream::WebSocket(
websocket::WsFramedStream::new(target_str, None, None, ms_timeout).await?,
));
}
) -> ResultType<FramedStream> {
connect_tcp_local(target, None, ms_timeout).await
}
// This function connects directly to the target without checking for websocket endpoints.
pub async fn connect_tcp_local<
't,
T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display,
@@ -154,27 +113,19 @@ pub async fn connect_tcp_local<
target: T,
local: Option<SocketAddr>,
ms_timeout: u64,
) -> ResultType<Stream> {
) -> ResultType<FramedStream> {
if let Some(conf) = Config::get_socks() {
return Ok(Stream::Tcp(
FramedStream::connect(target, local, &conf, ms_timeout).await?,
));
return FramedStream::connect(target, local, &conf, ms_timeout).await;
}
if let Some(target_addr) = target.resolve() {
if let Some(local_addr) = local {
if local_addr.is_ipv6() && target_addr.is_ipv4() {
let resolved_target = query_nip_io(target_addr).await?;
return Ok(Stream::Tcp(
FramedStream::new(resolved_target, Some(local_addr), ms_timeout).await?,
));
if let Some(target) = target.resolve() {
if let Some(local) = local {
if local.is_ipv6() && target.is_ipv4() {
let target = query_nip_io(target).await?;
return FramedStream::new(target, Some(local), ms_timeout).await;
}
}
}
Ok(Stream::Tcp(
FramedStream::new(target, local, ms_timeout).await?,
))
FramedStream::new(target, local, ms_timeout).await
}
#[inline]
@@ -215,14 +166,6 @@ async fn test_target(target: &str) -> ResultType<SocketAddr> {
.context(format!("Failed to look up host for {target}"))
}
#[inline]
pub async fn new_direct_udp_for(target: &str) -> ResultType<(Arc<UdpSocket>, SocketAddr)> {
let peer_addr = test_target(target).await?;
let local_addr = Config::get_any_listen_addr(peer_addr.is_ipv4());
let socket = UdpSocket::bind(local_addr).await?;
Ok((Arc::new(socket), peer_addr))
}
#[inline]
pub async fn new_udp_for(
target: &str,
-149
View File
@@ -1,149 +0,0 @@
use crate::{config, tcp, websocket, ResultType};
#[cfg(feature = "webrtc")]
use crate::webrtc;
use sodiumoxide::crypto::secretbox::Key;
use std::net::SocketAddr;
use tokio::net::TcpStream;
// support Websocket and tcp.
pub enum Stream {
#[cfg(feature = "webrtc")]
WebRTC(webrtc::WebRTCStream),
WebSocket(websocket::WsFramedStream),
Tcp(tcp::FramedStream),
}
impl Stream {
#[inline]
pub fn set_send_timeout(&mut self, ms: u64) {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.set_send_timeout(ms),
Stream::WebSocket(s) => s.set_send_timeout(ms),
Stream::Tcp(s) => s.set_send_timeout(ms),
}
}
#[inline]
pub fn set_raw(&mut self) {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.set_raw(),
Stream::WebSocket(s) => s.set_raw(),
Stream::Tcp(s) => s.set_raw(),
}
}
#[inline]
pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.send_bytes(bytes).await,
Stream::WebSocket(s) => s.send_bytes(bytes).await,
Stream::Tcp(s) => s.send_bytes(bytes).await,
}
}
#[inline]
pub async fn send_raw(&mut self, bytes: Vec<u8>) -> ResultType<()> {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.send_raw(bytes).await,
Stream::WebSocket(s) => s.send_raw(bytes).await,
Stream::Tcp(s) => s.send_raw(bytes).await,
}
}
#[inline]
pub fn set_key(&mut self, key: Key) {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.set_key(key),
Stream::WebSocket(s) => s.set_key(key),
Stream::Tcp(s) => s.set_key(key),
}
}
#[inline]
pub fn is_secured(&self) -> bool {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.is_secured(),
Stream::WebSocket(s) => s.is_secured(),
Stream::Tcp(s) => s.is_secured(),
}
}
#[inline]
pub async fn next_timeout(
&mut self,
timeout: u64,
) -> Option<Result<bytes::BytesMut, std::io::Error>> {
match self {
#[cfg(feature = "webrtc")]
Stream::WebRTC(s) => s.next_timeout(timeout).await,
Stream::WebSocket(s) => s.next_timeout(timeout).await,
Stream::Tcp(s) => s.next_timeout(timeout).await,
}
}
/// establish connect from websocket
#[inline]
pub async fn connect_websocket(
url: impl AsRef<str>,
local_addr: Option<SocketAddr>,
proxy_conf: Option<&config::Socks5Server>,
timeout_ms: u64,
) -> ResultType<Self> {
let ws_stream =
websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?;
log::debug!("WebSocket connection established");
Ok(Self::WebSocket(ws_stream))
}
/// send message
#[inline]
pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> {
match self {
#[cfg(feature = "webrtc")]
Self::WebRTC(s) => s.send(msg).await,
Self::WebSocket(ws) => ws.send(msg).await,
Self::Tcp(tcp) => tcp.send(msg).await,
}
}
/// receive message
#[inline]
pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> {
match self {
#[cfg(feature = "webrtc")]
Self::WebRTC(s) => s.next().await,
Self::WebSocket(ws) => ws.next().await,
Self::Tcp(tcp) => tcp.next().await,
}
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
match self {
#[cfg(feature = "webrtc")]
Self::WebRTC(s) => s.local_addr(),
Self::WebSocket(ws) => ws.local_addr(),
Self::Tcp(tcp) => tcp.local_addr(),
}
}
#[inline]
pub fn from(stream: TcpStream, stream_addr: SocketAddr) -> Self {
Self::Tcp(tcp::FramedStream::from(stream, stream_addr))
}
#[inline]
#[cfg(feature = "webrtc")]
pub fn get_webrtc_stream(&self) -> Option<webrtc::WebRTCStream> {
match self {
Self::WebRTC(s) => Some(s.clone()),
_ => None,
}
}
}
+6 -6
View File
@@ -22,16 +22,16 @@ use tokio_socks::IntoTargetAddr;
use tokio_util::codec::Framed;
pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {}
pub struct DynTcpStream(pub Box<dyn TcpStreamTrait + Send + Sync>);
pub struct DynTcpStream(pub(crate) Box<dyn TcpStreamTrait + Send + Sync>);
#[derive(Clone)]
pub struct Encrypt(pub Key, pub u64, pub u64);
pub struct Encrypt(Key, u64, u64);
pub struct FramedStream(
pub Framed<DynTcpStream, BytesCodec>,
pub SocketAddr,
pub Option<Encrypt>,
pub u64,
pub(crate) Framed<DynTcpStream, BytesCodec>,
pub(crate) SocketAddr,
pub(crate) Option<Encrypt>,
pub(crate) u64,
);
impl Deref for FramedStream {
-121
View File
@@ -1,121 +0,0 @@
use std::{collections::HashMap, sync::RwLock};
use crate::config::allow_insecure_tls_fallback;
#[derive(Debug, Clone, Copy)]
pub enum TlsType {
Plain,
NativeTls,
Rustls,
}
lazy_static::lazy_static! {
static ref URL_TLS_TYPE: RwLock<HashMap<String, TlsType>> = RwLock::new(HashMap::new());
static ref URL_TLS_DANGER_ACCEPT_INVALID_CERTS: RwLock<HashMap<String, bool>> = RwLock::new(HashMap::new());
}
#[inline]
pub fn is_plain(url: &str) -> bool {
url.starts_with("ws://") || url.starts_with("http://")
}
// Extract domain from URL.
// e.g., "https://example.com/path" -> "example.com"
// "https://example.com:8080/path" -> "example.com:8080"
// See the tests for more examples.
#[inline]
fn get_domain_and_port_from_url(url: &str) -> &str {
// Remove scheme (e.g., http://, https://, ws://, wss://)
let scheme_end = url.find("://").map(|pos| pos + 3).unwrap_or(0);
let url2 = &url[scheme_end..];
// If userinfo is present, domain is after last '@'
let after_at = match url2.rfind('@') {
Some(pos) => &url2[pos + 1..],
None => url2,
};
// Find the end of domain (before '/' or '?')
let domain_end = after_at.find(&['/', '?'][..]).unwrap_or(after_at.len());
&after_at[..domain_end]
}
#[inline]
pub fn upsert_tls_cache(url: &str, tls_type: TlsType, danger_accept_invalid_cert: bool) {
if is_plain(url) {
return;
}
let domain_port = get_domain_and_port_from_url(url);
// Use curly braces to ensure the lock is released immediately.
{
URL_TLS_TYPE
.write()
.unwrap()
.insert(domain_port.to_string(), tls_type);
}
{
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
.write()
.unwrap()
.insert(domain_port.to_string(), danger_accept_invalid_cert);
}
}
#[inline]
pub fn reset_tls_cache() {
// Use curly braces to ensure the lock is released immediately.
{
URL_TLS_TYPE.write().unwrap().clear();
}
{
URL_TLS_DANGER_ACCEPT_INVALID_CERTS.write().unwrap().clear();
}
}
#[inline]
pub fn get_cached_tls_type(url: &str) -> Option<TlsType> {
if is_plain(url) {
return Some(TlsType::Plain);
}
let domain_port = get_domain_and_port_from_url(url);
URL_TLS_TYPE.read().unwrap().get(domain_port).cloned()
}
#[inline]
pub fn get_cached_tls_accept_invalid_cert(url: &str) -> Option<bool> {
if !allow_insecure_tls_fallback() {
return Some(false);
}
if is_plain(url) {
return Some(false);
}
let domain_port = get_domain_and_port_from_url(url);
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
.read()
.unwrap()
.get(domain_port)
.cloned()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_domain_and_port_from_url() {
for (url, expected_domain_port) in vec![
("http://example.com", "example.com"),
("https://example.com", "example.com"),
("ws://example.com/path", "example.com"),
("wss://example.com:8080/path", "example.com:8080"),
("https://user:pass@example.com", "example.com"),
("https://example.com?query=param", "example.com"),
("https://example.com:8443?query=param", "example.com:8443"),
("ftp://example.com/resource", "example.com"), // ftp scheme
("example.com/path", "example.com"), // no scheme
("example.com:8080/path", "example.com:8080"),
] {
let domain_port = get_domain_and_port_from_url(url);
assert_eq!(domain_port, expected_domain_port);
}
}
}
-257
View File
@@ -1,257 +0,0 @@
use crate::ResultType;
use rustls_pki_types::{ServerName, UnixTime};
use std::sync::Arc;
use tokio_rustls::rustls::{self, client::WebPkiServerVerifier, ClientConfig};
use tokio_rustls::rustls::{
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
DigitallySignedStruct, Error as TLSError, SignatureScheme,
};
// https://github.com/seanmonstar/reqwest/blob/fd61bc93e6f936454ce0b978c6f282f06eee9287/src/tls.rs#L608
#[derive(Debug)]
pub(crate) struct NoVerifier;
impl ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, TLSError> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
SignatureScheme::ED448,
]
}
}
/// A certificate verifier that tries a primary verifier first,
/// and falls back to a platform verifier if the primary fails.
#[cfg(any(target_os = "android", target_os = "ios"))]
#[derive(Debug)]
struct FallbackPlatformVerifier {
primary: Arc<dyn ServerCertVerifier>,
fallback: Arc<dyn ServerCertVerifier>,
}
#[cfg(any(target_os = "android", target_os = "ios"))]
impl FallbackPlatformVerifier {
fn with_platform_fallback(
primary: Arc<dyn ServerCertVerifier>,
provider: Arc<rustls::crypto::CryptoProvider>,
) -> Result<Self, TLSError> {
#[cfg(target_os = "android")]
if !crate::config::ANDROID_RUSTLS_PLATFORM_VERIFIER_INITIALIZED
.load(std::sync::atomic::Ordering::Relaxed)
{
return Err(TLSError::General(
"rustls-platform-verifier not initialized".to_string(),
));
}
let fallback = Arc::new(rustls_platform_verifier::Verifier::new(provider)?);
Ok(Self { primary, fallback })
}
}
#[cfg(any(target_os = "android", target_os = "ios"))]
impl ServerCertVerifier for FallbackPlatformVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls_pki_types::CertificateDer<'_>,
intermediates: &[rustls_pki_types::CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, TLSError> {
match self.primary.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Ok(verified) => Ok(verified),
Err(primary_err) => {
match self.fallback.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Ok(verified) => Ok(verified),
Err(fallback_err) => {
log::error!(
"Both primary and fallback verifiers failed to verify server certificate, primary error: {:?}, fallback error: {:?}",
primary_err,
fallback_err
);
Err(primary_err)
}
}
}
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
// Both WebPkiServerVerifier and rustls_platform_verifier use the same signature verification implementation.
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L278
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/crypto/mod.rs#L17
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/android.rs#L9
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/apple.rs#L6
self.primary.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
// Same implementation as verify_tls12_signature.
self.primary.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
// Both WebPkiServerVerifier and rustls_platform_verifier use the same crypto provider,
// so their supported signature schemes are identical.
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L172C52-L172C85
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/android.rs#L327
// https://github.com/rustls/rustls-platform-verifier/blob/1099f161bfc5e3ac7f90aad88b1bf788e72906cb/rustls-platform-verifier/src/verification/apple.rs#L304
self.primary.supported_verify_schemes()
}
}
fn webpki_server_verifier(
provider: Arc<rustls::crypto::CryptoProvider>,
) -> ResultType<Arc<WebPkiServerVerifier>> {
// Load root certificates from both bundled webpki_roots and system-native certificate stores.
// This approach is consistent with how reqwest and tokio-tungstenite handle root certificates.
// https://github.com/snapview/tokio-tungstenite/blob/35d110c24c9d030d1608ec964d70c789dfb27452/src/tls.rs#L95
// https://github.com/seanmonstar/reqwest/blob/b126ca49da7897e5d676639cdbf67a0f6838b586/src/async_impl/client.rs#L643
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let rustls_native_certs::CertificateResult { certs, errors, .. } =
rustls_native_certs::load_native_certs();
if !errors.is_empty() {
log::warn!("native root CA certificate loading errors: {errors:?}");
}
root_cert_store.add_parsable_certificates(certs);
// Build verifier using with_root_certificates behavior (WebPkiServerVerifier without CRLs).
// Both reqwest and tokio-tungstenite use this approach.
// https://github.com/seanmonstar/reqwest/blob/b126ca49da7897e5d676639cdbf67a0f6838b586/src/async_impl/client.rs#L749
// https://github.com/snapview/tokio-tungstenite/blob/35d110c24c9d030d1608ec964d70c789dfb27452/src/tls.rs#L127
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/client/builder.rs#L47
// with_root_certificates creates a WebPkiServerVerifier without revocation checking:
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L177
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L168
// Since no CRL is provided (as is the case here), we must explicitly set allow_unknown_revocation_status()
// to match the behavior of with_root_certificates, which allows unknown revocation status by default.
// https://github.com/rustls/rustls/blob/1ee126adb3352a2dcd72420dcd6040351a6ddc1e/rustls/src/webpki/server_verifier.rs#L37
// Note: build() only returns an error if the root certificate store is empty, which won't happen here.
let verifier = rustls::client::WebPkiServerVerifier::builder_with_provider(
Arc::new(root_cert_store),
provider.clone(),
)
.allow_unknown_revocation_status()
.build()
.map_err(|e| anyhow::anyhow!(e))?;
Ok(verifier)
}
pub fn client_config(danger_accept_invalid_cert: bool) -> ResultType<ClientConfig> {
if danger_accept_invalid_cert {
client_config_danger()
} else {
client_config_safe()
}
}
pub fn client_config_safe() -> ResultType<ClientConfig> {
// Use the default builder which uses the default protocol versions and crypto provider.
// The with_protocol_versions API has been removed in rustls master branch:
// https://github.com/rustls/rustls/pull/2599
// This approach is consistent with tokio-tungstenite's usage:
// https://github.com/snapview/tokio-tungstenite/blob/35d110c24c9d030d1608ec964d70c789dfb27452/src/tls.rs#L126
let config_builder = rustls::ClientConfig::builder();
let provider = config_builder.crypto_provider().clone();
let webpki_verifier = webpki_server_verifier(provider.clone())?;
#[cfg(any(target_os = "android", target_os = "ios"))]
{
match FallbackPlatformVerifier::with_platform_fallback(webpki_verifier.clone(), provider) {
Ok(fallback_verifier) => {
let config = config_builder
.dangerous()
.with_custom_certificate_verifier(Arc::new(fallback_verifier))
.with_no_client_auth();
Ok(config)
}
Err(e) => {
log::error!(
"Failed to create fallback verifier: {:?}, use webpki verifier instead",
e
);
let config = config_builder
.with_webpki_verifier(webpki_verifier)
.with_no_client_auth();
Ok(config)
}
}
}
#[cfg(not(any(target_os = "android", target_os = "ios")))]
{
let config = config_builder
.with_webpki_verifier(webpki_verifier)
.with_no_client_auth();
Ok(config)
}
}
pub fn client_config_danger() -> ResultType<ClientConfig> {
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth();
Ok(config)
}
-770
View File
@@ -1,770 +0,0 @@
use std::collections::HashMap;
use std::io::{Error, ErrorKind};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use webrtc::api::setting_engine::SettingEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::RTCDataChannel;
use webrtc::ice::mdns::MulticastDnsMode;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::peer_connection::RTCPeerConnection;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine;
use bytes::{Bytes, BytesMut};
use tokio::sync::watch;
use tokio::sync::Mutex;
use tokio::time::timeout;
use url::Url;
use crate::config;
use crate::protobuf::Message;
use crate::sodiumoxide::crypto::secretbox::Key;
use crate::ResultType;
pub struct WebRTCStream {
pc: Arc<RTCPeerConnection>,
stream: Arc<Mutex<Arc<RTCDataChannel>>>,
state_notify: watch::Receiver<bool>,
send_timeout: u64,
}
/// Standard maximum message size for WebRTC data channels (RFC 8831, 65535 bytes).
/// Most browsers, including Chromium, enforce this protocol limit.
const DATA_CHANNEL_BUFFER_SIZE: u16 = u16::MAX;
// use 3 public STUN servers to find out the NAT type, 2 must be the same address but different ports
// https://stackoverflow.com/questions/72805316/determine-nat-mapping-behaviour-using-two-stun-servers
// luckily nextcloud supports two ports for STUN
// unluckily webrtc-rs does not use the same port to do the STUN request
static DEFAULT_ICE_SERVERS: [&str; 3] = [
"stun:stun.cloudflare.com:3478",
"stun:stun.nextcloud.com:3478",
"stun:stun.nextcloud.com:443",
];
lazy_static::lazy_static! {
static ref SESSIONS: Arc::<Mutex<HashMap<String, WebRTCStream>>> = Default::default();
}
impl Clone for WebRTCStream {
fn clone(&self) -> Self {
WebRTCStream {
pc: self.pc.clone(),
stream: self.stream.clone(),
state_notify: self.state_notify.clone(),
send_timeout: self.send_timeout,
}
}
}
impl WebRTCStream {
#[inline]
fn get_remote_offer(endpoint: &str) -> ResultType<String> {
// Ensure the endpoint starts with the "webrtc://" prefix
if !endpoint.starts_with("webrtc://") {
return Err(
Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into(),
);
}
// Extract the Base64-encoded SDP part
let encoded_sdp = &endpoint["webrtc://".len()..];
// Decode the Base64 string
let decoded_bytes = BASE64_STANDARD
.decode(encoded_sdp)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP"))?;
Ok(String::from_utf8(decoded_bytes).map_err(|_| {
Error::new(
ErrorKind::InvalidInput,
"Failed to convert decoded bytes to UTF-8",
)
})?)
}
#[inline]
fn sdp_to_endpoint(sdp: &str) -> String {
let encoded_sdp = BASE64_STANDARD.encode(sdp);
format!("webrtc://{}", encoded_sdp)
}
#[inline]
fn get_key_for_sdp(sdp: &RTCSessionDescription) -> ResultType<String> {
let binding = sdp.unmarshal()?;
let Some(fingerprint) = binding.attribute("fingerprint") else {
// find fingerprint attribute in media descriptions
for media in &binding.media_descriptions {
if media.media_name.media != "application" {
continue;
}
if let Some(fp) = media
.attributes
.iter()
.find(|x| x.key == "fingerprint")
.and_then(|x| x.value.clone())
{
return Ok(fp);
}
}
return Err(anyhow::anyhow!("SDP fingerprint attribute not found"));
};
Ok(fingerprint.to_string())
}
#[inline]
fn get_key_for_sdp_json(sdp_json: &str) -> ResultType<String> {
if sdp_json.is_empty() {
return Ok("".to_string());
}
let sdp = serde_json::from_str::<RTCSessionDescription>(&sdp_json)?;
Self::get_key_for_sdp(&sdp)
}
#[inline]
async fn get_key_for_peer(pc: &Arc<RTCPeerConnection>, is_local: bool) -> ResultType<String> {
let Some(desc) = (match is_local {
true => pc.local_description().await,
false => pc.remote_description().await,
}) else {
return Err(anyhow::anyhow!("PeerConnection description is not set"));
};
Self::get_key_for_sdp(&desc)
}
#[inline]
fn get_ice_server_from_url(url: &str) -> Option<RTCIceServer> {
// standard url format with turn scheme: turn://user:pass@host:port
match Url::parse(url) {
Ok(u) => {
if u.scheme() == "turn"
|| u.scheme() == "turns"
|| u.scheme() == "stun"
|| u.scheme() == "stuns"
{
Some(RTCIceServer {
urls: vec![format!(
"{}:{}:{}",
u.scheme(),
u.host_str().unwrap_or_default(),
u.port().unwrap_or(3478)
)],
username: u.username().to_string(),
credential: u.password().unwrap_or_default().to_string(),
..Default::default()
})
} else {
None
}
}
Err(_) => None,
}
}
#[inline]
fn get_ice_servers() -> Vec<RTCIceServer> {
let mut ice_servers = Vec::new();
let cfg = config::Config::get_option(config::keys::OPTION_ICE_SERVERS);
let mut has_stun = false;
for url in cfg.split(',').map(str::trim) {
if let Some(ice_server) = Self::get_ice_server_from_url(url) {
// Detect STUN in user config
if ice_server
.urls
.iter()
.any(|u| u.starts_with("stun:") || u.starts_with("stuns:"))
{
has_stun = true;
}
ice_servers.push(ice_server);
}
}
// If there is no STUN (either TURN-only or empty config) → prepend defaults
if !has_stun {
ice_servers.insert(
0,
RTCIceServer {
urls: DEFAULT_ICE_SERVERS.iter().map(|s| s.to_string()).collect(),
..Default::default()
},
);
}
ice_servers
}
pub async fn new(
remote_endpoint: &str,
force_relay: bool,
ms_timeout: u64,
) -> ResultType<Self> {
log::debug!("New webrtc stream to endpoint: {}", remote_endpoint);
let remote_offer = if remote_endpoint.is_empty() {
"".into()
} else {
Self::get_remote_offer(remote_endpoint)?
};
let mut key = Self::get_key_for_sdp_json(&remote_offer)?;
let sessions_lock = SESSIONS.lock().await;
if let Some(cached_stream) = sessions_lock.get(&key) {
if !key.is_empty() {
log::debug!("Start webrtc with cached peer");
return Ok(cached_stream.clone());
}
}
drop(sessions_lock);
let start_local_offer = remote_offer.is_empty();
// Create a SettingEngine and enable Detach
let mut s = SettingEngine::default();
s.detach_data_channels();
s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled);
// Create the API object
let api = APIBuilder::new().with_setting_engine(s).build();
// Prepare the configuration, get ICE servers from config
let config = RTCConfiguration {
ice_servers: Self::get_ice_servers(),
ice_transport_policy: if force_relay {
RTCIceTransportPolicy::Relay
} else {
RTCIceTransportPolicy::All
},
..Default::default()
};
let (notify_tx, notify_rx) = watch::channel(false);
// Create a new RTCPeerConnection
let pc = Arc::new(api.new_peer_connection(config).await?);
let bootstrap_dc = if start_local_offer {
let dc_open_notify = notify_tx.clone();
// Create a data channel with label "bootstrap"
let dc = pc.create_data_channel("bootstrap", None).await?;
dc.on_open(Box::new(move || {
log::debug!("Local data channel bootstrap open.");
let _ = dc_open_notify.send(true);
Box::pin(async {})
}));
dc
} else {
// Wait for the data channel to be created by the remote peer
// Here we create a dummy data channel to satisfy the type system
Arc::new(RTCDataChannel::default())
};
let stream = Arc::new(Mutex::new(bootstrap_dc));
if !start_local_offer {
// Register data channel creation handling
let dc_open_notify = notify_tx.clone();
let stream_for_dc = stream.clone();
pc.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
let d_label = dc.label().to_owned();
let dc_open_notify2 = dc_open_notify.clone();
let stream_for_dc_clone = stream_for_dc.clone();
log::debug!("Remote data channel {} ready", d_label);
Box::pin(async move {
let mut stream_lock = stream_for_dc_clone.lock().await;
*stream_lock = dc.clone();
drop(stream_lock);
dc.on_open(Box::new(move || {
let _ = dc_open_notify2.send(true);
Box::pin(async {})
}));
})
}));
}
// This will notify you when the peer has connected/disconnected
let stream_for_close = stream.clone();
let pc_for_close = pc.clone();
pc.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
let stream_for_close2 = stream_for_close.clone();
let on_connection_notify = notify_tx.clone();
let pc_for_close2 = pc_for_close.clone();
Box::pin(async move {
log::debug!("WebRTC session peer connection state: {}", s);
match s {
RTCPeerConnectionState::Disconnected
| RTCPeerConnectionState::Failed
| RTCPeerConnectionState::Closed => {
let _ = on_connection_notify.send(true);
log::debug!("WebRTC session closing due to disconnected");
let _ = stream_for_close2.lock().await.close().await;
log::debug!("WebRTC session stream closed");
let mut sessions_lock = SESSIONS.lock().await;
match Self::get_key_for_peer(&pc_for_close2, start_local_offer).await {
Ok(k) => {
sessions_lock.remove(&k);
log::debug!("WebRTC session removed key: {}", k);
}
Err(e) => {
log::error!(
"Failed to extract key for peer during session cleanup: {:?}",
e
);
// Fallback: try to remove any session associated with this peer connection
let keys_to_remove: Vec<String> = sessions_lock
.iter()
.filter_map(|(key, session)| {
if Arc::ptr_eq(&session.pc, &pc_for_close2) {
Some(key.clone())
} else {
None
}
})
.collect();
for k in keys_to_remove {
sessions_lock.remove(&k);
log::debug!("WebRTC session removed by fallback key: {}", k);
}
}
}
}
_ => {}
}
})
}));
// process offer/answer
if start_local_offer {
let sdp = pc.create_offer(None).await?;
let mut gather_complete = pc.gathering_complete_promise().await;
pc.set_local_description(sdp.clone()).await?;
let _ = gather_complete.recv().await;
log::debug!("local offer:\n{}", sdp.sdp);
// get local sdp key
key = Self::get_key_for_sdp(&sdp)?;
log::debug!("Start webrtc with local key: {}", key);
} else {
let sdp = serde_json::from_str::<RTCSessionDescription>(&remote_offer)?;
pc.set_remote_description(sdp.clone()).await?;
let answer = pc.create_answer(None).await?;
let mut gather_complete = pc.gathering_complete_promise().await;
pc.set_local_description(answer).await?;
let _ = gather_complete.recv().await;
log::debug!("remote offer:\n{}", sdp.sdp);
// get remote sdp key
key = Self::get_key_for_sdp(&sdp)?;
log::debug!("Start webrtc with remote key: {}", key);
}
let mut final_lock = SESSIONS.lock().await;
if let Some(session) = final_lock.get(&key) {
pc.close().await.ok();
return Ok(session.clone());
}
let webrtc_stream = Self {
pc,
stream,
state_notify: notify_rx,
send_timeout: ms_timeout,
};
final_lock.insert(key, webrtc_stream.clone());
Ok(webrtc_stream)
}
#[inline]
pub async fn get_local_endpoint(&self) -> ResultType<String> {
if let Some(local_desc) = self.pc.local_description().await {
let sdp = serde_json::to_string(&local_desc)?;
let endpoint = Self::sdp_to_endpoint(&sdp);
Ok(endpoint)
} else {
Err(anyhow::anyhow!("Local desc is not set"))
}
}
#[inline]
pub async fn set_remote_endpoint(&self, endpoint: &str) -> ResultType<()> {
let offer = Self::get_remote_offer(endpoint)?;
log::debug!("WebRTC set remote sdp: {}", offer);
let sdp = serde_json::from_str::<RTCSessionDescription>(&offer)?;
self.pc.set_remote_description(sdp).await?;
Ok(())
}
#[inline]
pub fn set_raw(&mut self) {
// not-supported
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
}
#[inline]
pub fn set_send_timeout(&mut self, ms: u64) {
self.send_timeout = ms;
}
#[inline]
pub fn set_key(&mut self, _key: Key) {
// not-supported
// WebRTC uses built-in DTLS encryption for secure communication.
// DTLS handles key exchange and encryption automatically, so explicit key management is not required.
}
#[inline]
pub fn is_secured(&self) -> bool {
true
}
#[inline]
pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
self.send_raw(msg.write_to_bytes()?).await
}
#[inline]
pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
self.send_bytes(Bytes::from(msg)).await
}
#[inline]
async fn wait_for_connect_result(&mut self) {
if *self.state_notify.borrow() {
return;
}
let _ = self.state_notify.changed().await;
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
if self.send_timeout > 0 {
match timeout(
Duration::from_millis(self.send_timeout),
self.wait_for_connect_result(),
)
.await
{
Ok(_) => {}
Err(_) => {
self.pc.close().await.ok();
return Err(Error::new(
ErrorKind::TimedOut,
"WebRTC send wait for connect timeout",
)
.into());
}
}
} else {
self.wait_for_connect_result().await;
}
let stream = self.stream.lock().await.clone();
stream.send(&bytes).await?;
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
self.wait_for_connect_result().await;
let stream = self.stream.lock().await.clone();
// TODO reuse buffer?
let mut buffer = BytesMut::zeroed(DATA_CHANNEL_BUFFER_SIZE as usize);
let dc = stream.detach().await.ok()?;
let n = match dc.read(&mut buffer).await {
Ok(n) => n,
Err(err) => {
self.pc.close().await.ok();
return Some(Err(Error::new(
ErrorKind::Other,
format!("data channel read error: {}", err),
)));
}
};
if n == 0 {
self.pc.close().await.ok();
return Some(Err(Error::new(
ErrorKind::Other,
"data channel read exited with 0 bytes",
)));
}
buffer.truncate(n);
Some(Ok(buffer))
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
match timeout(Duration::from_millis(ms), self.next()).await {
Ok(res) => res,
Err(_) => None,
}
}
}
pub fn is_webrtc_endpoint(endpoint: &str) -> bool {
// use sdp base64 json string as endpoint, or prefix webrtc:
endpoint.starts_with("webrtc://")
}
#[cfg(test)]
mod tests {
use crate::config;
use crate::webrtc::WebRTCStream;
use crate::webrtc::DEFAULT_ICE_SERVERS;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
#[test]
fn test_webrtc_ice_url() {
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://example.com:3478")
.unwrap_or_default()
.urls[0],
"turn:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://example.com")
.unwrap_or_default()
.urls[0],
"turn:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://123@example.com")
.unwrap_or_default()
.username,
"123"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://123@example.com")
.unwrap_or_default()
.credential,
""
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("turn://123:321@example.com")
.unwrap_or_default()
.credential,
"321"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("stun://example.com:3478")
.unwrap_or_default()
.urls[0],
"stun:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_server_from_url("http://123:123@example.com:3478"),
None
);
config::Config::set_option("ice-servers".to_string(), "".to_string());
assert_eq!(
WebRTCStream::get_ice_servers()[0].urls[0],
DEFAULT_ICE_SERVERS[0].to_string()
);
config::Config::set_option(
"ice-servers".to_string(),
",stun://example.com,turn://example.com,sdf".to_string(),
);
assert_eq!(
WebRTCStream::get_ice_servers()[0].urls[0],
"stun:example.com:3478"
);
assert_eq!(
WebRTCStream::get_ice_servers()[1].urls[0],
"turn:example.com:3478"
);
assert_eq!(WebRTCStream::get_ice_servers().len(), 2);
config::Config::set_option(
"ice-servers".to_string(),
"".to_string(),
);
}
#[test]
fn test_webrtc_session_key() {
let mut sdp_str = "".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
)
.unwrap_or_default(),
""
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=fingerprint:sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88
a=group:BUNDLE 0
a=extmap-allow-mixed
m=application 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
).unwrap_or_default(),
"sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88"
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=group:BUNDLE 0
a=extmap-allow-mixed
m=application 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=fingerprint:sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
).unwrap_or_default(),
"sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88"
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=group:BUNDLE 0
a=extmap-allow-mixed
m=application 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT"
.to_owned();
assert!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
)
.is_err(),
"can not find fingerprint attribute"
);
sdp_str = "\
v=0
o=- 7400546379179479477 208696200 IN IP4 0.0.0.0
s=-
t=0 0
a=group:BUNDLE 0
a=extmap-allow-mixed
m=audio 9 UDP/DTLS/SCTP webrtc-datachannel
c=IN IP4 0.0.0.0
a=fingerprint:sha-256 97:52:D6:1F:1E:87:6C:DA:B8:21:95:64:A5:85:89:FA:02:71:C7:4D:B3:FD:25:92:40:FB:6B:65:24:3C:79:88
a=setup:actpass
a=mid:0
a=sendrecv
a=sctp-port:5000
a=ice-ufrag:RMWjjpXfpXbDPdMz
a=ice-pwd:BtIqlWHfwhsJdFiBROeLuEbNmYfHxRfT".to_owned();
assert!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer(sdp_str).unwrap_or_default()
)
.is_err(),
"can not find datachannel fingerprint attribute"
);
assert!(
WebRTCStream::get_key_for_sdp(
&RTCSessionDescription::offer("".to_owned()).unwrap_or_default()
)
.is_err(),
"invalid sdp should error"
);
assert!(
WebRTCStream::get_key_for_sdp_json("{}").is_err(),
"empty sdp json should error"
);
assert!(
WebRTCStream::get_key_for_sdp_json("{ss}").is_err(),
"invalid sdp json should error"
);
let endpoint = "webrtc://eyJ0eXBlIjoiYW5zd2VyIiwic2RwIjoidj0wXHJcbm89LSA0MTA1NDk3NTY2NDgyMTQzODEwIDYwMzk1NzQw\
MCBJTiBJUDQgMC4wLjAuMFxyXG5zPS1cclxudD0wIDBcclxuYT1maW5nZXJwcmludDpzaGEtMjU2IDYxOjYwOjc0OjQwOjI4OkNFOjBCOjBDOjc1OjRCOj\
EwOjlBOkVFOjc3OkY1OjQ0OjU3Ojg0OjUxOkRCOjA0OjkyOjRBOjEwOjFDOjRFOjVGOjdFOkYxOkIzOjcxOjIyXHJcbmE9Z3JvdXA6QlVORExFIDBcclxu\
YT1leHRtYXAtYWxsb3ctbWl4ZWRcclxubT1hcHBsaWNhdGlvbiA5IFVEUC9EVExTL1NDVFAgd2VicnRjLWRhdGFjaGFubmVsXHJcbmM9SU4gSVA0IDAuMC\
4wLjBcclxuYT1zZXR1cDphY3RpdmVcclxuYT1taWQ6MFxyXG5hPXNlbmRyZWN2XHJcbmE9c2N0cC1wb3J0OjUwMDBcclxuYT1pY2UtdWZyYWc6SHlnU1Rr\
V2RsRlpHRG1XWlxyXG5hPWljZS1wd2Q6SkJneFZWaGZveVhHdHZha1VWcnBQeHVOSVpMU3llS1pcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDEgdWRwID\
IxMzA3MDY0MzEgMTkyLjE2OC4xLjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDIgdWRwIDIxMzA3MDY0MzEgMTkyLjE2OC4x\
LjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6MTg2MTA0NTE5MCAxIHVkcCAxNjk0NDk4ODE1IDE0LjIxMi42OC4xMiAyNzAwNCB0eXAgc3JmbH\
ggcmFkZHIgMC4wLjAuMCBycG9ydCA2NDAwOFxyXG5hPWNhbmRpZGF0ZToxODYxMDQ1MTkwIDIgdWRwIDE2OTQ0OTg4MTUgMTQuMjEyLjY4LjEyIDI3MDA0\
IHR5cCBzcmZseCByYWRkciAwLjAuMC4wIHJwb3J0IDY0MDA4XHJcbmE9ZW5kLW9mLWNhbmRpZGF0ZXNcclxuIn0=".to_owned();
assert_eq!(
WebRTCStream::get_key_for_sdp_json(
&WebRTCStream::get_remote_offer(&endpoint).unwrap_or_default()
).unwrap_or_default(),
"sha-256 61:60:74:40:28:CE:0B:0C:75:4B:10:9A:EE:77:F5:44:57:84:51:DB:04:92:4A:10:1C:4E:5F:7E:F1:B3:71:22"
);
}
#[tokio::test]
async fn test_webrtc_new_stream() {
let mut endpoint = "webrtc://sdfsdf".to_owned();
assert!(
WebRTCStream::new(&endpoint, false, 10000).await.is_err(),
"invalid webrtc endpoint should error"
);
endpoint = "wss://sdfsdf".to_owned();
assert!(
WebRTCStream::new(&endpoint, false, 10000).await.is_err(),
"invalid webrtc endpoint should error"
);
assert!(
WebRTCStream::new("", false, 10000).await.is_ok(),
"local webrtc endpoint should ok"
);
endpoint = "webrtc://eyJ0eXBlIjoiYW5zd2VyIiwic2RwIjoidj0wXHJcbm89LSA0MTA1NDk3NTY2NDgyMTQzODEwIDYwMzk1NzQw\
MCBJTiBJUDQgMC4wLjAuMFxyXG5zPS1cclxudD0wIDBcclxuYT1maW5nZXJwcmludDpzaGEtMjU2IDYxOjYwOjc0OjQwOjI4OkNFOjBCOjBDOjc1OjRCOj\
EwOjlBOkVFOjc3OkY1OjQ0OjU3Ojg0OjUxOkRCOjA0OjkyOjRBOjEwOjFDOjRFOjVGOjdFOkYxOkIzOjcxOjIyXHJcbmE9Z3JvdXA6QlVORExFIDBcclxu\
YT1leHRtYXAtYWxsb3ctbWl4ZWRcclxubT1hcHBsaWNhdGlvbiA5IFVEUC9EVExTL1NDVFAgd2VicnRjLWRhdGFjaGFubmVsXHJcbmM9SU4gSVA0IDAuMC\
4wLjBcclxuYT1zZXR1cDphY3RpdmVcclxuYT1taWQ6MFxyXG5hPXNlbmRyZWN2XHJcbmE9c2N0cC1wb3J0OjUwMDBcclxuYT1pY2UtdWZyYWc6SHlnU1Rr\
V2RsRlpHRG1XWlxyXG5hPWljZS1wd2Q6SkJneFZWaGZveVhHdHZha1VWcnBQeHVOSVpMU3llS1pcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDEgdWRwID\
IxMzA3MDY0MzEgMTkyLjE2OC4xLjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6OTYzOTg4MzQ4IDIgdWRwIDIxMzA3MDY0MzEgMTkyLjE2OC4x\
LjIgNjQwMDcgdHlwIGhvc3RcclxuYT1jYW5kaWRhdGU6MTg2MTA0NTE5MCAxIHVkcCAxNjk0NDk4ODE1IDE0LjIxMi42OC4xMiAyNzAwNCB0eXAgc3JmbH\
ggcmFkZHIgMC4wLjAuMCBycG9ydCA2NDAwOFxyXG5hPWNhbmRpZGF0ZToxODYxMDQ1MTkwIDIgdWRwIDE2OTQ0OTg4MTUgMTQuMjEyLjY4LjEyIDI3MDA0\
IHR5cCBzcmZseCByYWRkciAwLjAuMC4wIHJwb3J0IDY0MDA4XHJcbmE9ZW5kLW9mLWNhbmRpZGF0ZXNcclxuIn0=".to_owned();
assert!(
WebRTCStream::new(&endpoint, false, 10000).await.is_err(),
"connect to an 'answer' webrtc endpoint should error"
);
}
}
-531
View File
@@ -1,531 +0,0 @@
use crate::{
config::{
keys::OPTION_RELAY_SERVER, use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT,
},
protobuf::Message,
socket_client::split_host_port,
sodiumoxide::crypto::secretbox::Key,
tcp::Encrypt,
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
use anyhow::bail;
use async_recursion::async_recursion;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use std::{
io::{Error, ErrorKind},
net::SocketAddr,
sync::Arc,
time::Duration,
};
use tokio::{net::TcpStream, time::timeout};
use tokio_native_tls::native_tls::TlsConnector;
use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite::protocol::Message as WsMessage, Connector,
MaybeTlsStream, WebSocketStream,
};
use tungstenite::client::IntoClientRequest;
use tungstenite::protocol::Role;
pub struct WsFramedStream {
stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
addr: SocketAddr,
encrypt: Option<Encrypt>,
send_timeout: u64,
}
impl WsFramedStream {
#[inline]
fn get_connector(
tls_type: &TlsType,
danger_accept_invalid_certs: bool,
) -> ResultType<Option<Connector>> {
match tls_type {
TlsType::Plain => Ok(Some(Connector::Plain)),
TlsType::NativeTls => {
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(danger_accept_invalid_certs)
.build()?;
Ok(Some(Connector::NativeTls(connector)))
}
TlsType::Rustls => {
let connector = match crate::verifier::client_config(danger_accept_invalid_certs) {
Ok(client_config) => Some(Connector::Rustls(Arc::new(client_config))),
Err(e) => {
log::warn!(
"Failed to get client config: {:?}, fallback to default connector",
e
);
None
}
};
Ok(connector)
}
}
}
async fn connect(
url: &str,
ms_timeout: u64,
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
// to-do: websocket proxy.
let tls_type = get_cached_tls_type(url);
let is_tls_type_cached = tls_type.is_some();
let tls_type = tls_type.unwrap_or(TlsType::Rustls);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await
}
#[async_recursion]
async fn try_connect(
url: &str,
ms_timeout: u64,
tls_type: TlsType,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
original_danger_accept_invalid_certs: Option<bool>,
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let ws_config = None;
let disable_nagle = false;
let request = url
.into_client_request()
.map_err(|e| Error::new(ErrorKind::Other, e))?;
let connector =
Self::get_connector(&tls_type, danger_accept_invalid_cert.unwrap_or(false))?;
match timeout(
Duration::from_millis(ms_timeout),
connect_async_tls_with_config(request, ws_config, disable_nagle, connector),
)
.await?
{
Ok((ws_stream, _)) => {
upsert_tls_cache(url, tls_type, danger_accept_invalid_cert.unwrap_or(false));
Ok(ws_stream)
}
Err(e) => match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) {
(TlsType::Rustls, _, None) => {
log::warn!(
"WebSocket connection with rustls-tls failed, try accept invalid certs: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
Some(true),
original_danger_accept_invalid_certs,
)
.await
}
(TlsType::Rustls, false, Some(_)) => {
log::warn!(
"WebSocket connection with rustls-tls failed, try native-tls: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
TlsType::NativeTls,
is_tls_type_cached,
original_danger_accept_invalid_certs,
original_danger_accept_invalid_certs,
)
.await
}
(TlsType::NativeTls, _, None) => {
log::warn!(
"WebSocket connection with native-tls failed, try accept invalid certs: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
Some(true),
original_danger_accept_invalid_certs,
)
.await
}
_ => {
log::error!(
"WebSocket connection failed with tls_type {:?}: {}, {:?}",
tls_type,
url,
e
);
bail!(e)
}
},
}
}
pub async fn new<T: AsRef<str>>(
url: T,
_local_addr: Option<SocketAddr>,
_proxy_conf: Option<&Socks5Server>,
ms_timeout: u64,
) -> ResultType<Self> {
let stream = Self::connect(url.as_ref(), ms_timeout).await?;
let addr = match stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
MaybeTlsStream::NativeTls(tls) => tls.get_ref().get_ref().get_ref().peer_addr()?,
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};
let ws = Self {
stream,
addr,
encrypt: None,
send_timeout: ms_timeout,
};
Ok(ws)
}
#[inline]
pub fn set_raw(&mut self) {
self.encrypt = None;
}
#[inline]
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
let ws_stream =
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
.await;
Ok(Self {
stream: ws_stream,
addr,
encrypt: None,
send_timeout: 0,
})
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
self.addr
}
#[inline]
pub fn set_send_timeout(&mut self, ms: u64) {
self.send_timeout = ms;
}
#[inline]
pub fn set_key(&mut self, key: Key) {
self.encrypt = Some(Encrypt::new(key));
}
#[inline]
pub fn is_secured(&self) -> bool {
self.encrypt.is_some()
}
#[inline]
pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
self.send_raw(msg.write_to_bytes()?).await
}
#[inline]
pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
let mut msg = msg;
if let Some(key) = self.encrypt.as_mut() {
msg = key.enc(&msg);
}
self.send_bytes(Bytes::from(msg)).await
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
let msg = WsMessage::Binary(bytes);
if self.send_timeout > 0 {
timeout(
Duration::from_millis(self.send_timeout),
self.stream.send(msg),
)
.await??
} else {
self.stream.send(msg).await?
};
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
while let Some(msg) = self.stream.next().await {
let msg = match msg {
Ok(msg) => msg,
Err(e) => {
log::error!("{}", e);
return Some(Err(Error::new(
ErrorKind::Other,
format!("WebSocket protocol error: {}", e),
)));
}
};
match msg {
WsMessage::Binary(data) => {
let mut bytes = BytesMut::from(&data[..]);
if let Some(key) = self.encrypt.as_mut() {
if let Err(err) = key.dec(&mut bytes) {
return Some(Err(err));
}
}
return Some(Ok(bytes));
}
WsMessage::Text(text) => {
let bytes = BytesMut::from(text.as_bytes());
return Some(Ok(bytes));
}
WsMessage::Close(_) => {
return None;
}
_ => {
continue;
}
}
}
None
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
match timeout(Duration::from_millis(ms), self.next()).await {
Ok(res) => res,
Err(_) => None,
}
}
}
pub fn is_ws_endpoint(endpoint: &str) -> bool {
endpoint.starts_with("ws://") || endpoint.starts_with("wss://")
}
/**
* Core function to convert an endpoint to WebSocket format
*
* Converts between different address formats:
* 1. IPv4 address with/without port -> ws://ipv4:port
* 2. IPv6 address with/without port -> ws://[ipv6]:port
* 3. Domain with/without port -> ws(s)://domain/ws/path
*
* @param endpoint The endpoint to convert
* @return The converted WebSocket endpoint
*/
pub fn check_ws(endpoint: &str) -> String {
if !use_ws() {
return endpoint.to_string();
}
if endpoint.is_empty() {
return endpoint.to_string();
}
if is_ws_endpoint(endpoint) {
return endpoint.to_string();
}
let Some((endpoint_host, endpoint_port)) = split_host_port(endpoint) else {
debug_assert!(false, "endpoint doesn't have port");
return endpoint.to_string();
};
let custom_rendezvous_server = Config::get_rendezvous_server();
let relay_server = Config::get_option(OPTION_RELAY_SERVER);
let rendezvous_port = split_host_port(&custom_rendezvous_server)
.map(|(_, p)| p)
.unwrap_or(RENDEZVOUS_PORT);
let relay_port = split_host_port(&relay_server)
.map(|(_, p)| p)
.unwrap_or(RELAY_PORT);
let (relay, dst_port) = if endpoint_port == rendezvous_port {
// rendezvous
(false, endpoint_port + 2)
} else if endpoint_port == rendezvous_port - 1 {
// online
(false, endpoint_port + 3)
} else if endpoint_port == relay_port || endpoint_port == rendezvous_port + 1 {
// relay
// https://github.com/rustdesk/rustdesk/blob/6ffbcd1375771f2482ec4810680623a269be70f1/src/rendezvous_mediator.rs#L615
// https://github.com/rustdesk/rustdesk-server/blob/235a3c326ceb665e941edb50ab79faa1208f7507/src/relay_server.rs#L83, based on relay port.
(true, endpoint_port + 2)
} else {
// fallback relay
// for controlling side, relay server is passed from the controlled side, not related to local config.
(true, endpoint_port + 2)
};
let (address, is_domain) = if crate::is_ip_str(endpoint) {
(format!("{}:{}", endpoint_host, dst_port), false)
} else {
let domain_path = if relay { "/ws/relay" } else { "/ws/id" };
(format!("{}{}", endpoint_host, domain_path), true)
};
let protocol = if is_domain {
let api_server = Config::get_option("api-server");
if api_server.starts_with("https") {
"wss"
} else {
"ws"
}
} else {
"ws"
};
format!("{}://{}", protocol, address)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{keys, Config};
#[test]
fn test_check_ws() {
// enable websocket
Config::set_option(keys::OPTION_ALLOW_WEBSOCKET.to_string(), "Y".to_string());
// not set custom-rendezvous-server
Config::set_option("custom-rendezvous-server".to_string(), "".to_string());
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
assert_eq!(check_ws("rustdesk.com:21115"), "ws://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21116"), "ws://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21117"), "ws://rustdesk.com/ws/relay");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
Config::set_option(
"api-server".to_string(),
"https://api.rustdesk.com".to_string(),
);
assert_eq!(
check_ws("[0:0:0:0:0:0:0:1]:21115"),
"ws://[0:0:0:0:0:0:0:1]:21118"
);
assert_eq!(
check_ws("[0:0:0:0:0:0:0:1]:21116"),
"ws://[0:0:0:0:0:0:0:1]:21118"
);
assert_eq!(
check_ws("[0:0:0:0:0:0:0:1]:21117"),
"ws://[0:0:0:0:0:0:0:1]:21119"
);
assert_eq!(check_ws("rustdesk.com:21115"), "wss://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21116"), "wss://rustdesk.com/ws/id");
assert_eq!(
check_ws("rustdesk.com:21117"),
"wss://rustdesk.com/ws/relay"
);
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("rustdesk.com:21115"), "wss://rustdesk.com/ws/id");
assert_eq!(check_ws("rustdesk.com:21116"), "wss://rustdesk.com/ws/id");
assert_eq!(
check_ws("rustdesk.com:34567"),
"wss://rustdesk.com/ws/relay"
);
// set custom-rendezvous-server without port
Config::set_option(
"custom-rendezvous-server".to_string(),
"127.0.0.1".to_string(),
);
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569");
// set custom-rendezvous-server without default port
Config::set_option(
"custom-rendezvous-server".to_string(),
"127.0.0.1".to_string(),
);
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("127.0.0.1:21115"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:21116"), "ws://127.0.0.1:21118");
assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569");
// set custom-rendezvous-server with custom port
Config::set_option(
"custom-rendezvous-server".to_string(),
"127.0.0.1:23456".to_string(),
);
Config::set_option("relay-server".to_string(), "".to_string());
Config::set_option("api-server".to_string(), "".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23457"), "ws://127.0.0.1:23459");
// set relay-server without port
Config::set_option("relay-server".to_string(), "127.0.0.1".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with default port
Config::set_option("relay-server".to_string(), "127.0.0.1:21117".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:21117"), "ws://127.0.0.1:21119");
// set relay-server with custom port
Config::set_option("relay-server".to_string(), "127.0.0.1:34567".to_string());
assert_eq!(check_ws("127.0.0.1:23455"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:23456"), "ws://127.0.0.1:23458");
assert_eq!(check_ws("127.0.0.1:34567"), "ws://127.0.0.1:34569");
}
}