From 8ae4651bc7466c192443f29e7e05029819c36708 Mon Sep 17 00:00:00 2001 From: lc Date: Thu, 13 Nov 2025 16:53:04 +0800 Subject: [PATCH] make webrtc-rs optional feature --- Cargo.toml | 8 +- examples/webrtc.rs | 218 ++++++++++--------------------------------- src/lib.rs | 5 + src/socket_client.rs | 4 +- src/webrtc.rs | 153 ++++++++++++++++++------------ src/webrtc_dummy.rs | 67 +++++++++++++ 6 files changed, 223 insertions(+), 232 deletions(-) create mode 100644 src/webrtc_dummy.rs diff --git a/Cargo.toml b/Cargo.toml index e952e86..72f0919 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,10 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["webrtc"] +webrtc = ["dep:webrtc"] + [dependencies] # new flexi_logger failed on rustc 1.75 flexi_logger = { version = "0.27", features = ["async"] } @@ -61,7 +65,7 @@ rustls-pki-types = "1.11" rustls-native-certs = "0.8" webpki-roots = "1.0.4" async-recursion = "1.1" -webrtc = "0.14.0" +webrtc = { version = "0.14.0", optional = true } [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" @@ -73,7 +77,7 @@ protobuf-codegen = { version = "3.7" } [dev-dependencies] clap = "4.5.51" -webrtc-signal = "0.1.1" +webrtc = "0.14.0" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = [ diff --git a/examples/webrtc.rs b/examples/webrtc.rs index abd7a0b..5a5e909 100644 --- a/examples/webrtc.rs +++ b/examples/webrtc.rs @@ -1,51 +1,37 @@ -use std::io::Write; -use std::sync::Arc; +extern crate hbb_common; -use bytes::{Bytes, BytesMut}; +use std::io::Write; +use bytes::Bytes; use clap::{Arg, Command}; use anyhow::Result; use tokio::time::Duration; -use webrtc::api::APIBuilder; -use webrtc::api::setting_engine::SettingEngine; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::math_rand_alpha; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc_signal::{self as signal}; - -// example from https://github.com/webrtc-rs/webrtc/tree/master/examples/examples/data-channels #[tokio::main] async fn main() -> Result<()> { - let mut app = Command::new("data-channels") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of Data-Channels.") - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) + let app = Command::new("webrtc-stream") + .about("An example of webrtc stream using hbb_common and webrtc-rs") .arg( Arg::new("debug") .long("debug") .short('d') + .action(clap::ArgAction::SetTrue) .help("Prints debug log information"), + ) + .arg( + Arg::new("offer") + .long("offer") + .short('o') + .help("set offer from other endpoint"), ); let matches = app.clone().get_matches(); - if matches.contains_id("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - let debug = matches.contains_id("debug"); if debug { + println!("Debug log enabled"); env_logger::Builder::new() .format(|buf, record| { writeln!( @@ -58,173 +44,67 @@ async fn main() -> Result<()> { record.args() ) }) - .filter(None, log::LevelFilter::Trace) + .filter(None, log::LevelFilter::Debug) .init(); } - // Everything below is the WebRTC-rs API! Thanks for using it ❤️. - // Create a SettingEngine and enable Detach - let mut s = SettingEngine::default(); - s.detach_data_channels(); - - // Create the API object - let api = APIBuilder::new() - .with_setting_engine(s) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() + let remote_endpoint = if let Some(endpoint) = matches.get_one::("offer") { + endpoint.to_string() + } else { + "".to_string() }; - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); + let webrtc_stream = hbb_common::webrtc::WebRTCStream::new(&remote_endpoint, 30000).await?; + // Print the offer to be sent to the other peer + webrtc_stream.get_local_endpoint().await; - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - let bootstrap = peer_connection.create_data_channel("bootstrap", None).await?; - let bootstrap_clone = Arc::clone(&bootstrap); - bootstrap.on_open(Box::new(move || { - println!("Data channel bootstrap open."); - Box::pin(async move { - let _raw = match bootstrap_clone.detach().await { - Ok(raw) => raw, - Err(err) => { - println!("data channel detach got err: {err}"); - return; - } - }; - }) - })); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. - // It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - - // Register data channel creation handling - peer_connection.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d3 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open."); - - Box::pin(async move { - tokio::spawn(async move { - let _ = read_loop(d2).await; - }); - - // Handle writing to the data channel - tokio::spawn(async move { - let _ = write_loop(d3).await; - }); - }) - })); - }) - })); - - // Wait for the offer to be pasted - println!("Wait for the offer to be pasted"); - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - println!("{json_str}"); - let b64 = signal::encode(&json_str); - println!("--------------------- Copy the below base64 to browser --------------------"); - println!("{b64}"); - } else { - println!("generate local_description failed!"); + if remote_endpoint.is_empty() { + // Wait for the answer to be pasted + println!("Wait for the answer to be pasted"); + // readline blocking + let line = std::io::stdin() + .lines() + .next() + .ok_or_else(|| anyhow::anyhow!("No input received"))??; + webrtc_stream.set_remote_endpoint(&line).await?; } + let s1 = hbb_common::Stream::WebRTC(webrtc_stream.clone()); + tokio::spawn(async move { + let _ = read_loop(s1).await; + }); + + let s2 = hbb_common::Stream::WebRTC(webrtc_stream.clone()); + tokio::spawn(async move { + let _ = write_loop(s2).await; + }); + println!("Press ctrl-c to stop"); tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } _ = tokio::signal::ctrl_c() => { println!(); } }; - peer_connection.close().await?; - Ok(()) } // read_loop shows how to read from the datachannel directly -async fn read_loop(dc: Arc) -> Result<()> { - let mut buffer = BytesMut::zeroed(4096); +async fn read_loop(mut stream: hbb_common::Stream) -> Result<()> { loop { - let d = dc.detach().await?; - println!("RTCDatachannel detach ok"); - let n = match d.read(&mut buffer).await { - Ok(n) => n, - Err(err) => { - println!("Datachannel closed; Exit the read_loop: {err}"); - return Ok(()); - } - }; - - if n == 0 { - println!("Datachannel read 0 byte; Exit the read_loop"); + let Some(res) = stream.next().await else { + println!("Datachannel closed; Exit the read_loop"); return Ok(()); - } - println!( - "Message from DataChannel: {}", - String::from_utf8(buffer[..n].to_vec())? + }; + println!("Message from DataChannel: {}", + String::from_utf8(res.unwrap().to_vec())? ); } } // write_loop shows how to write to the datachannel directly -async fn write_loop(d: Arc) -> Result<()> { - let mut result = Result::::Ok(0); +async fn write_loop(mut stream: hbb_common::Stream) -> Result<()> { + let mut result = Result::<()>::Ok(()); while result.is_ok() { let timeout = tokio::time::sleep(Duration::from_secs(5)); tokio::pin!(timeout); @@ -233,7 +113,7 @@ async fn write_loop(d: Arc) -> Result<()> { _ = timeout.as_mut() =>{ let message = math_rand_alpha(15); println!("Sending '{message}'"); - result = d.send(&Bytes::from(message)).await.map_err(Into::into); + result = stream.send_bytes(Bytes::from(message)).await; } }; } diff --git a/src/lib.rs b/src/lib.rs index 1504290..5d3e600 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,7 +59,12 @@ pub mod fingerprint; pub use flexi_logger; pub mod stream; pub mod websocket; +#[cfg(feature = "webrtc")] pub mod webrtc; +#[cfg(not(feature = "webrtc"))] +pub mod webrtc_dummy; +#[cfg(not(feature = "webrtc"))] +pub use webrtc_dummy as webrtc; #[cfg(any(target_os = "android", target_os = "ios"))] pub use rustls_platform_verifier; pub use stream::Stream; diff --git a/src/socket_client.rs b/src/socket_client.rs index c6db932..1f568ff 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -3,9 +3,10 @@ use crate::{ tcp::FramedStream, udp::FramedSocket, websocket::{self, check_ws, is_ws_endpoint}, - webrtc::{self, is_webrtc_endpoint}, ResultType, Stream, }; +#[cfg(feature = "webrtc")] +use crate::webrtc::{self, is_webrtc_endpoint}; use anyhow::Context; use std::{net::SocketAddr, sync::Arc}; use tokio::net::{ToSocketAddrs, UdpSocket}; @@ -130,6 +131,7 @@ pub async fn connect_tcp< target: T, ms_timeout: u64, ) -> ResultType { + #[cfg(feature = "webrtc")] if is_webrtc_endpoint(&target.to_string()) { return Ok(Stream::WebRTC( webrtc::WebRTCStream::new(&target.to_string(), ms_timeout).await?, diff --git a/src/webrtc.rs b/src/webrtc.rs index 5b61eee..b276e77 100644 --- a/src/webrtc.rs +++ b/src/webrtc.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc}; +use std::sync::Arc; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::io::{Error, ErrorKind}; use std::time::Duration; @@ -13,23 +13,25 @@ use webrtc::peer_connection::RTCPeerConnection; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::ice::mdns::MulticastDnsMode; use crate::{ protobuf::Message, sodiumoxide::crypto::secretbox::Key, ResultType, }; -use bytes::{Bytes, BytesMut}; -use tokio::{time::timeout}; -use tokio::sync::Notify; -use tokio::sync::Mutex; + use base64::Engine; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use bytes::{Bytes, BytesMut}; +use tokio::time::timeout; +use tokio::sync::watch; +use tokio::sync::Mutex; pub struct WebRTCStream { pc: Arc, stream: Arc, - notify: Arc, + state_notify: watch::Receiver, send_timeout: u64, } @@ -45,7 +47,7 @@ impl Clone for WebRTCStream { WebRTCStream { pc: self.pc.clone(), stream: self.stream.clone(), - notify: self.notify.clone(), + state_notify: self.state_notify.clone(), send_timeout: self.send_timeout, } } @@ -53,38 +55,40 @@ impl Clone for WebRTCStream { impl WebRTCStream { - pub fn get_remote_offer(endpoint: &str) -> Option { + pub fn get_remote_offer(endpoint: &str) -> ResultType { // Ensure the endpoint starts with the "webrtc://" prefix if !endpoint.starts_with("webrtc://") { - return None; + return Err(Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into()); } // Extract the Base64-encoded SDP part let encoded_sdp = &endpoint["webrtc://".len()..]; - // Decode the Base64 string - let decoded_bytes = BASE64_STANDARD.decode(encoded_sdp).ok()?; - let decoded_sdp = String::from_utf8(decoded_bytes).ok()?; - - Some(decoded_sdp) + let decoded_bytes = BASE64_STANDARD.decode(encoded_sdp).map_err(|_| + Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP") + )?; + Ok(String::from_utf8(decoded_bytes).map_err(|_| { + Error::new(ErrorKind::InvalidInput, "Failed to convert decoded bytes to UTF-8") + })?) } - pub async fn new>( - webrtc_endpoint: T, + pub fn sdp_to_endpoint(sdp: &str) -> String { + let encoded_sdp = BASE64_STANDARD.encode(sdp); + format!("webrtc://{}", encoded_sdp) + } + + pub async fn new( + remote_endpoint: &str, ms_timeout: u64, ) -> ResultType { - log::debug!("Start webrtc with endpoint: {}", webrtc_endpoint.as_ref()); - let remote_offer: String = match Self::get_remote_offer(webrtc_endpoint.as_ref()) { - Some(offer) => offer, - None => { - return Err(Error::new( - ErrorKind::InvalidInput, - "Invalid WebRTC endpoint format", - ).into()); - } + log::debug!("New webrtc stream with endpoint: {}", remote_endpoint); + let remote_offer = if remote_endpoint.is_empty() { + "".into() + } else { + Self::get_remote_offer(remote_endpoint)? }; - let key = remote_offer.to_string(); + let mut key = remote_offer.clone(); let mut lock = SESSIONS.lock().await; let contains = lock.contains_key(&key); if contains { @@ -92,10 +96,10 @@ impl WebRTCStream { return Ok(lock.get(&key).unwrap().clone()); } - log::debug!("Start webrtc with offer: {}", remote_offer); // Create a SettingEngine and enable Detach let mut s = SettingEngine::default(); s.detach_data_channels(); + s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); // Create the API object let api = APIBuilder::new() @@ -111,67 +115,96 @@ impl WebRTCStream { ..Default::default() }; - let notify = Arc::new(Notify::new()); - let notify_tx = notify.clone(); + let (notify_tx, notify_rx) = watch::channel(false); + let on_open_notify = notify_tx.clone(); // Create a new RTCPeerConnection let peer_connection = Arc::new(api.new_peer_connection(config).await?); - let bootstrap = peer_connection.create_data_channel("bootstrap", None).await?; - bootstrap.on_open(Box::new(move || { + let data_channel = peer_connection.create_data_channel("bootstrap", None).await?; + data_channel.on_open(Box::new(move || { log::debug!("Data channel bootstrap open."); - notify_tx.notify_waiters(); + let _ = on_open_notify.send(true); Box::pin(async {}) })); // This will notify you when the peer has connected/disconnected - let notify_tx2 = notify.clone(); + let on_connection_notify = notify_tx.clone(); peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { log::debug!("Peer Connection State has changed: {}", s); if s == RTCPeerConnectionState::Disconnected { - notify_tx2.notify_waiters(); + let _ = on_connection_notify.send(true); } // TODO clear SESSIONS entry? Box::pin(async {}) })); - let offer = serde_json::from_str::(&remote_offer)?; - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - // Create an answer - let answer = peer_connection.create_answer(None).await?; - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - let _ = gather_complete.recv().await; + // Register data channel creation handling + let on_open_notify2 = notify_tx.clone(); + peer_connection.on_data_channel(Box::new(move |dc: Arc| { + let d_label = dc.label().to_owned(); + log::debug!("Remote data channel {}", d_label); + let notify = on_open_notify2.clone(); + Box::pin(async move { + dc.on_open(Box::new(move || { + let _ = notify.send(true); + Box::pin(async {}) + })); + }) + })); - let ds = WebRTCStream { + if remote_offer.is_empty() { + let sdp = peer_connection.create_offer(None).await?; + let mut gather_complete = peer_connection.gathering_complete_promise().await; + peer_connection.set_local_description(sdp.clone()).await?; + let _ = gather_complete.recv().await; + + let final_sdp = peer_connection.local_description().await.ok_or_else(|| { + Error::new(ErrorKind::Other, "Failed to get local description after gathering") + })?; + key = serde_json::to_string(&final_sdp).unwrap_or_default(); + log::debug!("Start webrtc with local: {}", key); + } else { + let sdp = serde_json::from_str::(&remote_offer)?; + peer_connection.set_remote_description(sdp).await?; + let answer = peer_connection.create_answer(None).await?; + let mut gather_complete = peer_connection.gathering_complete_promise().await; + peer_connection.set_local_description(answer).await?; + let _ = gather_complete.recv().await; + log::debug!("Start webrtc with remote: {}", remote_offer); + } + + let webrtc_stream = WebRTCStream { pc: peer_connection, - stream: bootstrap, - notify: notify, + stream: data_channel, + state_notify: notify_rx, send_timeout: ms_timeout, }; - // log the answer - match ds.get_local_endpoint().await { - Some(local_endpoint) => log::debug!("WebRTC local endpoint: {}", local_endpoint), - None => log::debug!("WebRTC local endpoint: "), - } - - lock.insert(key, ds.clone()); - Ok(ds) + lock.insert(key, webrtc_stream.clone()); + Ok(webrtc_stream) } #[inline] pub async fn get_local_endpoint(&self) -> Option { if let Some(local_desc) = self.pc.local_description().await { - let sdp = serde_json::to_string(&local_desc).ok()?; - Some(format!("webrtc://{}", BASE64_STANDARD.encode(sdp))) + let sdp = serde_json::to_string(&local_desc).unwrap_or_default(); + let endpoint = Self::sdp_to_endpoint(&sdp); + log::debug!("WebRTC get local endpoint: {}", endpoint); + Some(endpoint) } else { None } } + #[inline] + pub async fn set_remote_endpoint(&self, endpoint: &str) -> ResultType<()> { + let offer = Self::get_remote_offer(endpoint)?; + log::debug!("WebRTC set remote sdp: {}", offer); + let sdp = serde_json::from_str::(&offer)?; + self.pc.set_remote_description(sdp).await?; + Ok(()) + } + #[inline] pub fn set_raw(&mut self) { // not-supported @@ -208,8 +241,7 @@ impl WebRTCStream { } pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - // wait for connected or disconnected - self.notify.notified().await; + let _ = self.state_notify.changed().await; self.stream.send(&bytes).await?; Ok(()) } @@ -217,7 +249,7 @@ impl WebRTCStream { #[inline] pub async fn next(&mut self) -> Option> { // wait for connected or disconnected - self.notify.notified().await; + let _ = self.state_notify.changed().await; if self.stream.ready_state() != RTCDataChannelState::Open { return Some(Err(Error::new( ErrorKind::Other, @@ -243,6 +275,7 @@ impl WebRTCStream { "data channel read exited with 0 bytes", ))); } + log::debug!("WebRTCStream read {} bytes", n); buffer.truncate(n); Some(Ok(buffer)) } diff --git a/src/webrtc_dummy.rs b/src/webrtc_dummy.rs new file mode 100644 index 0000000..a6f9344 --- /dev/null +++ b/src/webrtc_dummy.rs @@ -0,0 +1,67 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::io::Error; + +use bytes::{Bytes, BytesMut}; + +use crate::{ + protobuf::Message, + sodiumoxide::crypto::secretbox::Key, + ResultType, +}; + +pub struct WebRTCStream { + // mock struct +} + +impl WebRTCStream { + + #[inline] + pub fn set_raw(&mut self) { + } + + #[inline] + pub fn local_addr(&self) -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } + + #[inline] + pub fn set_send_timeout(&mut self, _ms: u64) { + } + + #[inline] + pub fn set_key(&mut self, _key: Key) { + } + + #[inline] + pub fn is_secured(&self) -> bool { + false + } + + #[inline] + pub async fn send(&mut self, _msg: &impl Message) -> ResultType<()> { + Ok(()) + } + + #[inline] + pub async fn send_raw(&mut self, _msg: Vec) -> ResultType<()> { + Ok(()) + } + + pub async fn send_bytes(&mut self, _bytes: Bytes) -> ResultType<()> { + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option> { + None + } + + #[inline] + pub async fn next_timeout(&mut self, _ms: u64) -> Option> { + None + } +} + +pub fn is_webrtc_endpoint(_endpoint: &str) -> bool { + false +}