diff --git a/Cargo.toml b/Cargo.toml index b5a70f2..e952e86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ rustls-pki-types = "1.11" rustls-native-certs = "0.8" webpki-roots = "1.0.4" async-recursion = "1.1" +webrtc = "0.14.0" [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" @@ -70,6 +71,10 @@ machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" } [build-dependencies] protobuf-codegen = { version = "3.7" } +[dev-dependencies] +clap = "4.5.51" +webrtc-signal = "0.1.1" + [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = [ "winuser", diff --git a/examples/webrtc.rs b/examples/webrtc.rs new file mode 100644 index 0000000..abd7a0b --- /dev/null +++ b/examples/webrtc.rs @@ -0,0 +1,243 @@ +use std::io::Write; +use std::sync::Arc; + +use bytes::{Bytes, BytesMut}; + +use clap::{Arg, Command}; +use anyhow::Result; +use tokio::time::Duration; + +use webrtc::api::APIBuilder; +use webrtc::api::setting_engine::SettingEngine; +use webrtc::data_channel::RTCDataChannel; +use webrtc::ice_transport::ice_server::RTCIceServer; +use webrtc::peer_connection::configuration::RTCConfiguration; +use webrtc::peer_connection::math_rand_alpha; +use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; +use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; + +use webrtc_signal::{self as signal}; + +// example from https://github.com/webrtc-rs/webrtc/tree/master/examples/examples/data-channels +#[tokio::main] +async fn main() -> Result<()> { + let mut app = Command::new("data-channels") + .version("0.1.0") + .author("Rain Liu ") + .about("An example of Data-Channels.") + .arg( + Arg::new("FULLHELP") + .help("Prints more detailed help information") + .long("fullhelp"), + ) + .arg( + Arg::new("debug") + .long("debug") + .short('d') + .help("Prints debug log information"), + ); + + let matches = app.clone().get_matches(); + + if matches.contains_id("FULLHELP") { + app.print_long_help().unwrap(); + std::process::exit(0); + } + + let debug = matches.contains_id("debug"); + if debug { + env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, log::LevelFilter::Trace) + .init(); + } + + // Everything below is the WebRTC-rs API! Thanks for using it ❤️. + // Create a SettingEngine and enable Detach + let mut s = SettingEngine::default(); + s.detach_data_channels(); + + // Create the API object + let api = APIBuilder::new() + .with_setting_engine(s) + .build(); + + // Prepare the configuration + let config = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: vec!["stun:stun.l.google.com:19302".to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + // Create a new RTCPeerConnection + let peer_connection = Arc::new(api.new_peer_connection(config).await?); + + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + + let bootstrap = peer_connection.create_data_channel("bootstrap", None).await?; + let bootstrap_clone = Arc::clone(&bootstrap); + bootstrap.on_open(Box::new(move || { + println!("Data channel bootstrap open."); + Box::pin(async move { + let _raw = match bootstrap_clone.detach().await { + Ok(raw) => raw, + Err(err) => { + println!("data channel detach got err: {err}"); + return; + } + }; + }) + })); + + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { + println!("Peer Connection State has changed: {s}"); + + if s == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. + // It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = done_tx.try_send(()); + } + + Box::pin(async {}) + })); + + + // Register data channel creation handling + peer_connection.on_data_channel(Box::new(move |d: Arc| { + let d_label = d.label().to_owned(); + let d_id = d.id(); + println!("New DataChannel {d_label} {d_id}"); + + // Register channel opening handling + Box::pin(async move { + let d2 = Arc::clone(&d); + let d3 = Arc::clone(&d); + let d_label2 = d_label.clone(); + let d_id2 = d_id; + d.on_open(Box::new(move || { + println!("Data channel '{d_label2}'-'{d_id2}' open."); + + Box::pin(async move { + tokio::spawn(async move { + let _ = read_loop(d2).await; + }); + + // Handle writing to the data channel + tokio::spawn(async move { + let _ = write_loop(d3).await; + }); + }) + })); + }) + })); + + // Wait for the offer to be pasted + println!("Wait for the offer to be pasted"); + let line = signal::must_read_stdin()?; + let desc_data = signal::decode(line.as_str())?; + let offer = serde_json::from_str::(&desc_data)?; + + // Set the remote SessionDescription + peer_connection.set_remote_description(offer).await?; + + // Create an answer + let answer = peer_connection.create_answer(None).await?; + + // Create channel that is blocked until ICE Gathering is complete + let mut gather_complete = peer_connection.gathering_complete_promise().await; + + // Sets the LocalDescription, and starts our UDP listeners + peer_connection.set_local_description(answer).await?; + + // Block until ICE Gathering is complete, disabling trickle ICE + // we do this because we only can exchange one signaling message + // in a production application you should exchange ICE Candidates via OnICECandidate + let _ = gather_complete.recv().await; + + // Output the answer in base64 so we can paste it in browser + if let Some(local_desc) = peer_connection.local_description().await { + let json_str = serde_json::to_string(&local_desc)?; + println!("{json_str}"); + let b64 = signal::encode(&json_str); + println!("--------------------- Copy the below base64 to browser --------------------"); + println!("{b64}"); + } else { + println!("generate local_description failed!"); + } + + println!("Press ctrl-c to stop"); + tokio::select! { + _ = done_rx.recv() => { + println!("received done signal!"); + } + _ = tokio::signal::ctrl_c() => { + println!(); + } + }; + + peer_connection.close().await?; + + Ok(()) +} + +// read_loop shows how to read from the datachannel directly +async fn read_loop(dc: Arc) -> Result<()> { + let mut buffer = BytesMut::zeroed(4096); + loop { + let d = dc.detach().await?; + println!("RTCDatachannel detach ok"); + let n = match d.read(&mut buffer).await { + Ok(n) => n, + Err(err) => { + println!("Datachannel closed; Exit the read_loop: {err}"); + return Ok(()); + } + }; + + if n == 0 { + println!("Datachannel read 0 byte; Exit the read_loop"); + return Ok(()); + } + println!( + "Message from DataChannel: {}", + String::from_utf8(buffer[..n].to_vec())? + ); + } +} + +// write_loop shows how to write to the datachannel directly +async fn write_loop(d: Arc) -> Result<()> { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + let message = math_rand_alpha(15); + println!("Sending '{message}'"); + result = d.send(&Bytes::from(message)).await.map_err(Into::into); + } + }; + } + println!("Datachannel write not ok; Exit the write_loop"); + + Ok(()) +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 851e4b1..1504290 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,6 +59,7 @@ pub mod fingerprint; pub use flexi_logger; pub mod stream; pub mod websocket; +pub mod webrtc; #[cfg(any(target_os = "android", target_os = "ios"))] pub use rustls_platform_verifier; pub use stream::Stream; diff --git a/src/socket_client.rs b/src/socket_client.rs index f0e5b05..c6db932 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -3,6 +3,7 @@ use crate::{ tcp::FramedStream, udp::FramedSocket, websocket::{self, check_ws, is_ws_endpoint}, + webrtc::{self, is_webrtc_endpoint}, ResultType, Stream, }; use anyhow::Context; @@ -129,6 +130,11 @@ pub async fn connect_tcp< target: T, ms_timeout: u64, ) -> ResultType { + if is_webrtc_endpoint(&target.to_string()) { + return Ok(Stream::WebRTC( + webrtc::WebRTCStream::new(&target.to_string(), ms_timeout).await?, + )); + } let target_str = check_ws(&target.to_string()); if is_ws_endpoint(&target_str) { return Ok(Stream::WebSocket( diff --git a/src/stream.rs b/src/stream.rs index 987d9be..16db2c5 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,10 +1,11 @@ -use crate::{config, tcp, websocket, ResultType}; +use crate::{config, tcp, websocket, webrtc, ResultType}; use sodiumoxide::crypto::secretbox::Key; use std::net::SocketAddr; use tokio::net::TcpStream; // support Websocket and tcp. pub enum Stream { + WebRTC(webrtc::WebRTCStream), WebSocket(websocket::WsFramedStream), Tcp(tcp::FramedStream), } @@ -13,6 +14,7 @@ impl Stream { #[inline] pub fn set_send_timeout(&mut self, ms: u64) { match self { + Stream::WebRTC(s) => s.set_send_timeout(ms), Stream::WebSocket(s) => s.set_send_timeout(ms), Stream::Tcp(s) => s.set_send_timeout(ms), } @@ -21,6 +23,7 @@ impl Stream { #[inline] pub fn set_raw(&mut self) { match self { + Stream::WebRTC(s) => s.set_raw(), Stream::WebSocket(s) => s.set_raw(), Stream::Tcp(s) => s.set_raw(), } @@ -29,6 +32,7 @@ impl Stream { #[inline] pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { match self { + Stream::WebRTC(s) => s.send_bytes(bytes).await, Stream::WebSocket(s) => s.send_bytes(bytes).await, Stream::Tcp(s) => s.send_bytes(bytes).await, } @@ -37,6 +41,7 @@ impl Stream { #[inline] pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { match self { + Stream::WebRTC(s) => s.send_raw(bytes).await, Stream::WebSocket(s) => s.send_raw(bytes).await, Stream::Tcp(s) => s.send_raw(bytes).await, } @@ -45,6 +50,7 @@ impl Stream { #[inline] pub fn set_key(&mut self, key: Key) { match self { + Stream::WebRTC(s) => s.set_key(key), Stream::WebSocket(s) => s.set_key(key), Stream::Tcp(s) => s.set_key(key), } @@ -53,6 +59,7 @@ impl Stream { #[inline] pub fn is_secured(&self) -> bool { match self { + Stream::WebRTC(s) => s.is_secured(), Stream::WebSocket(s) => s.is_secured(), Stream::Tcp(s) => s.is_secured(), } @@ -64,6 +71,7 @@ impl Stream { timeout: u64, ) -> Option> { match self { + Stream::WebRTC(s) => s.next_timeout(timeout).await, Stream::WebSocket(s) => s.next_timeout(timeout).await, Stream::Tcp(s) => s.next_timeout(timeout).await, } @@ -87,6 +95,7 @@ impl Stream { #[inline] pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { match self { + Self::WebRTC(s) => s.send(msg).await, Self::WebSocket(ws) => ws.send(msg).await, Self::Tcp(tcp) => tcp.send(msg).await, } @@ -96,6 +105,7 @@ impl Stream { #[inline] pub async fn next(&mut self) -> Option> { match self { + Self::WebRTC(s) => s.next().await, Self::WebSocket(ws) => ws.next().await, Self::Tcp(tcp) => tcp.next().await, } @@ -104,6 +114,7 @@ impl Stream { #[inline] pub fn local_addr(&self) -> SocketAddr { match self { + Self::WebRTC(s) => s.local_addr(), Self::WebSocket(ws) => ws.local_addr(), Self::Tcp(tcp) => tcp.local_addr(), } diff --git a/src/webrtc.rs b/src/webrtc.rs new file mode 100644 index 0000000..5b61eee --- /dev/null +++ b/src/webrtc.rs @@ -0,0 +1,269 @@ +use std::sync::{Arc}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::io::{Error, ErrorKind}; +use std::time::Duration; +use std::collections::HashMap; + +use webrtc::api::APIBuilder; +use webrtc::api::setting_engine::SettingEngine; +use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::data_channel_state::RTCDataChannelState; +use webrtc::ice_transport::ice_server::RTCIceServer; +use webrtc::peer_connection::RTCPeerConnection; +use webrtc::peer_connection::configuration::RTCConfiguration; +use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; +use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; + +use crate::{ + protobuf::Message, + sodiumoxide::crypto::secretbox::Key, + ResultType, +}; +use bytes::{Bytes, BytesMut}; +use tokio::{time::timeout}; +use tokio::sync::Notify; +use tokio::sync::Mutex; +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; + +pub struct WebRTCStream { + pc: Arc, + stream: Arc, + notify: Arc, + send_timeout: u64, +} + +/// message size limit for Chromium +const DATA_CHANNEL_BUFFER_SIZE: u16 = u16::MAX; + +lazy_static::lazy_static! { + static ref SESSIONS: Arc::>> = Default::default(); +} + +impl Clone for WebRTCStream { + fn clone(&self) -> Self { + WebRTCStream { + pc: self.pc.clone(), + stream: self.stream.clone(), + notify: self.notify.clone(), + send_timeout: self.send_timeout, + } + } +} + +impl WebRTCStream { + + pub fn get_remote_offer(endpoint: &str) -> Option { + // Ensure the endpoint starts with the "webrtc://" prefix + if !endpoint.starts_with("webrtc://") { + return None; + } + + // Extract the Base64-encoded SDP part + let encoded_sdp = &endpoint["webrtc://".len()..]; + + // Decode the Base64 string + let decoded_bytes = BASE64_STANDARD.decode(encoded_sdp).ok()?; + let decoded_sdp = String::from_utf8(decoded_bytes).ok()?; + + Some(decoded_sdp) + } + + pub async fn new>( + webrtc_endpoint: T, + ms_timeout: u64, + ) -> ResultType { + log::debug!("Start webrtc with endpoint: {}", webrtc_endpoint.as_ref()); + let remote_offer: String = match Self::get_remote_offer(webrtc_endpoint.as_ref()) { + Some(offer) => offer, + None => { + return Err(Error::new( + ErrorKind::InvalidInput, + "Invalid WebRTC endpoint format", + ).into()); + } + }; + + let key = remote_offer.to_string(); + let mut lock = SESSIONS.lock().await; + let contains = lock.contains_key(&key); + if contains { + log::debug!("Start webrtc with cached peer"); + return Ok(lock.get(&key).unwrap().clone()); + } + + log::debug!("Start webrtc with offer: {}", remote_offer); + // Create a SettingEngine and enable Detach + let mut s = SettingEngine::default(); + s.detach_data_channels(); + + // Create the API object + let api = APIBuilder::new() + .with_setting_engine(s) + .build(); + + // Prepare the configuration + let config = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: vec!["stun:stun.cloudflare.com:3478".to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + let notify = Arc::new(Notify::new()); + let notify_tx = notify.clone(); + // Create a new RTCPeerConnection + let peer_connection = Arc::new(api.new_peer_connection(config).await?); + let bootstrap = peer_connection.create_data_channel("bootstrap", None).await?; + bootstrap.on_open(Box::new(move || { + log::debug!("Data channel bootstrap open."); + notify_tx.notify_waiters(); + Box::pin(async {}) + })); + + // This will notify you when the peer has connected/disconnected + let notify_tx2 = notify.clone(); + peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { + log::debug!("Peer Connection State has changed: {}", s); + if s == RTCPeerConnectionState::Disconnected { + notify_tx2.notify_waiters(); + } + + // TODO clear SESSIONS entry? + Box::pin(async {}) + })); + + let offer = serde_json::from_str::(&remote_offer)?; + // Set the remote SessionDescription + peer_connection.set_remote_description(offer).await?; + // Create an answer + let answer = peer_connection.create_answer(None).await?; + // Create channel that is blocked until ICE Gathering is complete + let mut gather_complete = peer_connection.gathering_complete_promise().await; + // Sets the LocalDescription, and starts our UDP listeners + peer_connection.set_local_description(answer).await?; + let _ = gather_complete.recv().await; + + let ds = WebRTCStream { + pc: peer_connection, + stream: bootstrap, + notify: notify, + send_timeout: ms_timeout, + }; + + // log the answer + match ds.get_local_endpoint().await { + Some(local_endpoint) => log::debug!("WebRTC local endpoint: {}", local_endpoint), + None => log::debug!("WebRTC local endpoint: "), + } + + lock.insert(key, ds.clone()); + Ok(ds) + } + + #[inline] + pub async fn get_local_endpoint(&self) -> Option { + if let Some(local_desc) = self.pc.local_description().await { + let sdp = serde_json::to_string(&local_desc).ok()?; + Some(format!("webrtc://{}", BASE64_STANDARD.encode(sdp))) + } else { + None + } + } + + #[inline] + pub fn set_raw(&mut self) { + // not-supported + } + + #[inline] + pub fn local_addr(&self) -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } + + #[inline] + pub fn set_send_timeout(&mut self, ms: u64) { + self.send_timeout = ms; + } + + #[inline] + pub fn set_key(&mut self, _key: Key) { + // not-supported + } + + #[inline] + pub fn is_secured(&self) -> bool { + true + } + + #[inline] + pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { + self.send_raw(msg.write_to_bytes()?).await + } + + #[inline] + pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { + self.send_bytes(Bytes::from(msg)).await + } + + pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + // wait for connected or disconnected + self.notify.notified().await; + self.stream.send(&bytes).await?; + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option> { + // wait for connected or disconnected + self.notify.notified().await; + if self.stream.ready_state() != RTCDataChannelState::Open { + return Some(Err(Error::new( + ErrorKind::Other, + "data channel is closed", + ))); + } + + // TODO reuse buffer? + let mut buffer = BytesMut::zeroed(DATA_CHANNEL_BUFFER_SIZE as usize); + let dc = self.stream.detach().await.ok()?; + let n = match dc.read(&mut buffer).await { + Ok(n) => n, + Err(err) => { + return Some(Err(Error::new( + ErrorKind::Other, + format!("data channel read error: {}", err), + ))); + } + }; + if n == 0 { + return Some(Err(Error::new( + ErrorKind::Other, + "data channel read exited with 0 bytes", + ))); + } + buffer.truncate(n); + Some(Ok(buffer)) + } + + #[inline] + pub async fn next_timeout(&mut self, ms: u64) -> Option> { + match timeout(Duration::from_millis(ms), self.next()).await { + Ok(res) => res, + Err(_) => None, + } + } +} + +pub fn is_webrtc_endpoint(endpoint: &str) -> bool { + // use sdp base64 json string as endpoint, or prefix webrtc: + endpoint.starts_with("webrtc://") +} + +#[cfg(test)] +mod tests { + #[test] + fn test_dc() { + } +}