use crate::{ config::Socks5Server, protobuf::Message, sodiumoxide::crypto::secretbox::{self, Key, Nonce}, ResultType, }; use bytes::{BufMut, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use std::{ io::{Error, ErrorKind}, net::SocketAddr, time::Duration, }; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, }; use tungstenite::protocol::Role; #[derive(Clone)] pub struct Encrypt(Key, u64, u64); pub struct WsFramedStream { stream: WebSocketStream>, addr: SocketAddr, encrypt: Option, send_timeout: u64, // read_buf: BytesMut, } impl WsFramedStream { pub async fn new>( url: T, local_addr: Option, proxy_conf: Option<&Socks5Server>, ms_timeout: u64, ) -> ResultType { let url_str = url.as_ref(); if let Some(proxy_conf) = proxy_conf { // use proxy connect let url_obj = url::Url::parse(url_str)?; let host = url_obj .host_str() .ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?; let port = url_obj .port() .unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 }); let socket = tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port)) .await?; let tcp_stream = socket.into_inner(); let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream); let ws_stream = WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await; let addr = match ws_stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; Ok(Self { stream: ws_stream, addr, encrypt: None, send_timeout: ms_timeout, }) } else { log::info!("{:?}", url_str); let ws_url = format!("ws://{}", url_str); let (stream, _) = connect_async(ws_url).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, #[cfg(feature = "native-tls")] MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, #[cfg(feature = "rustls")] MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; Ok(Self { stream, addr, encrypt: None, send_timeout: ms_timeout, }) } } pub fn set_raw(&mut self) {} pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; Ok(Self { stream: ws_stream, addr, encrypt: None, send_timeout: 0, // read_buf: BytesMut::new(), }) } pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self { let ws_stream = WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; Self { stream: ws_stream, addr, encrypt: None, send_timeout: 0, // read_buf: BytesMut::new(), } } pub fn local_addr(&self) -> SocketAddr { self.addr } pub fn set_send_timeout(&mut self, ms: u64) { self.send_timeout = ms; } pub fn set_key(&mut self, key: Key) { self.encrypt = Some(Encrypt::new(key)); } pub fn is_secured(&self) -> bool { self.encrypt.is_some() } #[inline] pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { self.send_raw(msg.write_to_bytes()?).await } #[inline] pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { let mut msg = msg; if let Some(key) = self.encrypt.as_mut() { msg = key.enc(&msg); } self.send_bytes(bytes::Bytes::from(msg)).await } #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { let msg = WsMessage::Binary(Bytes::from(bytes)); if self.send_timeout > 0 { let send_future = self.stream.send(msg); timeout(Duration::from_millis(self.send_timeout), send_future) .await .map_err(|_| Error::new(ErrorKind::TimedOut, "Send timeout"))? .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; } else { self.stream .send(msg) .await .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; } Ok(()) } #[inline] pub async fn next(&mut self) -> Option> { log::info!("test"); loop { match self.stream.next().await? { Ok(WsMessage::Binary(data)) => { let mut bytes = BytesMut::from(&data[..]); if let Some(key) = self.encrypt.as_mut() { if let Err(e) = key.dec(&mut bytes) { return Some(Err(e)); } } return Some(Ok(bytes)); } Ok(WsMessage::Ping(ping)) => { if let Err(e) = self.stream.send(WsMessage::Pong(ping)).await { return Some(Err(Error::new( ErrorKind::Other, format!("Failed to send pong: {}", e), ))); } continue; } Ok(WsMessage::Close(_)) => return None, Ok(_) => continue, Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), } } } #[inline] pub async fn next_timeout(&mut self, ms: u64) -> Option> { match timeout(Duration::from_millis(ms), self.next()).await { Ok(res) => res, Err(_) => None, } } } impl Encrypt { pub fn new(key: Key) -> Self { Self(key, 0, 0) } pub fn dec(&mut self, bytes: &mut BytesMut) -> Result<(), Error> { if bytes.len() <= 1 { return Ok(()); } self.2 += 1; let nonce = get_nonce(self.2); match secretbox::open(bytes, &nonce, &self.0) { Ok(res) => { bytes.clear(); bytes.put_slice(&res); Ok(()) } Err(()) => Err(Error::new(ErrorKind::Other, "decryption error")), } } pub fn enc(&mut self, data: &[u8]) -> Vec { self.1 += 1; let nonce = get_nonce(self.1); secretbox::seal(data, &nonce, &self.0) } } fn get_nonce(seqnum: u64) -> Nonce { let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]); nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes()); nonce }