diff --git a/Cargo.lock b/Cargo.lock index 8077c21..45ce31e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,6 +195,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base32" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "022dfe9eb35f19ebbcb51e0b40a5ab759f46ad60cadf7297e0bd085afb50e076" + [[package]] name = "base64" version = "0.13.0" @@ -414,6 +420,12 @@ dependencies = [ "toml 0.5.9", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation" version = "0.9.4" @@ -567,12 +579,13 @@ dependencies = [ [[package]] name = "digest" -version = "0.10.3" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -668,6 +681,22 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "email-encoding" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a87260449b06739ee78d6281c68d2a0ff3e3af64a78df63d3a1aeb3c06997c8a" +dependencies = [ + "base64 0.22.1", + "memchr", +] + +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1087,6 +1116,7 @@ dependencies = [ "ipnetwork", "jsonwebtoken", "lazy_static", + "lettre", "local-ip-address", "mac_address", "machine-uid 0.2.0", @@ -1101,7 +1131,10 @@ dependencies = [ "serde_json", "sodiumoxide", "sqlx", + "tokio", "tokio-tungstenite", + "toml 0.7.8", + "totp-rs", "tower-http", "tungstenite", "uuid", @@ -1163,6 +1196,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "0.2.7" @@ -1294,6 +1336,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.8.1" @@ -1438,6 +1490,33 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" +[[package]] +name = "lettre" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76bd09637ae3ec7bd605b8e135e757980b3968430ff2b1a4a94fb7769e50166d" +dependencies = [ + "async-trait", + "base64 0.21.7", + "email-encoding", + "email_address", + "fastrand", + "futures-io", + "futures-util", + "httpdate", + "idna 0.3.0", + "mime", + "nom 7.1.1", + "once_cell", + "quoted_printable", + "rustls 0.21.12", + "rustls-pemfile 1.0.0", + "socket2 0.4.4", + "tokio", + "tokio-rustls 0.24.1", + "webpki-roots 0.23.1", +] + [[package]] name = "lexical-core" version = "0.7.6" @@ -2083,6 +2162,12 @@ dependencies = [ "proc-macro2 1.0.93", ] +[[package]] +name = "quoted_printable" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3866219251662ec3b26fc217e3e05bf9c4f84325234dfb96bf0bf840889e49" + [[package]] name = "rand" version = "0.8.5" @@ -2404,6 +2489,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" +[[package]] +name = "rustls-webpki" +version = "0.100.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6a5fc258f1c1276dfe3016516945546e2d5383911efc0fc4f1cdc5df3a4ae3" +dependencies = [ + "ring 0.16.20", + "untrusted 0.7.1", +] + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -2558,6 +2653,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.2" @@ -3141,6 +3247,19 @@ dependencies = [ "winnow", ] +[[package]] +name = "totp-rs" +version = "5.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b36a9dd327e9f401320a2cb4572cc76ff43742bcfc3291f871691050f140ba" +dependencies = [ + "base32", + "constant_time_eq", + "hmac", + "sha1", + "sha2", +] + [[package]] name = "tower" version = "0.4.12" @@ -3338,7 +3457,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c" dependencies = [ "form_urlencoded", - "idna", + "idna 0.2.3", "matches", "percent-encoding", ] @@ -3503,6 +3622,15 @@ dependencies = [ "webpki", ] +[[package]] +name = "webpki-roots" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" +dependencies = [ + "rustls-webpki 0.100.3", +] + [[package]] name = "webpki-roots" version = "0.25.4" diff --git a/Cargo.toml b/Cargo.toml index 68ee6d6..f5972f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,10 @@ path = "src/utils.rs" [dependencies] hbb_common = { path = "libs/hbb_common" } +tokio = { version = "1", features = ["fs", "io-util"] } +totp-rs = { version = "5.4", default-features = false } +lettre = { version = "0.10", default-features = false, features = ["smtp-transport", "tokio1-rustls-tls", "builder"] } +toml = "0.7" serde_derive = "1.0" serde = "1.0" serde_json = "1.0" diff --git a/src/api/ab/legacy.rs b/src/api/ab/legacy.rs new file mode 100644 index 0000000..a624c9c --- /dev/null +++ b/src/api/ab/legacy.rs @@ -0,0 +1,161 @@ +//! Legacy single-blob address book — `GET /api/ab` and `POST /api/ab`. +//! +//! Activated when the operator sets `--ab-legacy-mode=on` (which makes +//! `/api/ab/personal` 404 — the documented signal in CONSOLE_API.md §4.2). +//! The wire shape is a JSON-string field `data` whose contents are a second +//! JSON object: `{tags, peers, tag_colors}`. We translate to/from the +//! normalized M2 schema on the personal AB. + +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::state::AppState; +use crate::database::AbPeerRow; +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::Json; +use serde_json::{json, Map, Value}; +use std::sync::Arc; + +pub async fn get( + Extension(state): Extension>, + user: AuthedUser, +) -> Result, ApiError> { + let guid = state + .db + .ab_get_or_create_personal(user.user_id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + // Pull all peers and all tags. Page size 1000 is fine — legacy clients + // expected a single blob anyway. + let (_total, peers) = state + .db + .ab_list_peers(&guid, 0, 10_000) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let tags = state + .db + .ab_list_tags(&guid) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let mut tag_colors = Map::new(); + let tag_names: Vec<&str> = tags.iter().map(|t| t.name.as_str()).collect(); + for t in &tags { + tag_colors.insert(t.name.clone(), Value::from(t.color)); + } + let peer_arr: Vec = peers + .iter() + .map(|p| { + json!({ + "id": p.id, + "alias": p.alias, + "tags": p.tags, + "username": p.username, + "hostname": p.hostname, + "platform": p.platform, + "hash": p.hash, + }) + }) + .collect(); + let inner = json!({ + "tags": tag_names, + "peers": peer_arr, + "tag_colors": Value::String(serde_json::to_string(&tag_colors).unwrap_or_default()), + }); + Ok(Json(json!({ "data": inner.to_string() }))) +} + +#[derive(serde::Deserialize)] +pub struct LegacyPostBody { + pub data: String, +} + +pub async fn put( + Extension(state): Extension>, + user: AuthedUser, + Json(body): Json, +) -> Result { + let guid = state + .db + .ab_get_or_create_personal(user.user_id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let inner: Value = serde_json::from_str(&body.data) + .map_err(|e| ApiError::BadRequest(format!("data is not valid json: {}", e)))?; + // Tag colors are stored as a JSON-encoded string field (Flutter wraps + // the map in another JSON layer). Tolerate either an inline map or the + // doubly-encoded form. + let tag_colors_map: Map = match inner.get("tag_colors") { + Some(Value::String(s)) => serde_json::from_str(s).unwrap_or_default(), + Some(Value::Object(m)) => m.clone(), + _ => Map::new(), + }; + let tag_names: Vec = inner + .get("tags") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }) + .unwrap_or_default(); + let tags: Vec<(String, i64)> = tag_names + .iter() + .map(|n| { + let color = tag_colors_map + .get(n) + .and_then(|v| v.as_i64()) + .unwrap_or(0); + (n.clone(), color) + }) + .collect(); + + let peer_arr = inner + .get("peers") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + let mut peers: Vec = Vec::with_capacity(peer_arr.len()); + for p in peer_arr { + let id = p + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + if id.is_empty() { + continue; + } + let tags = p + .get("tags") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }) + .unwrap_or_default(); + peers.push(AbPeerRow { + id, + alias: get_str(&p, "alias"), + note: String::new(), + password: String::new(), + hash: get_str(&p, "hash"), + username: get_str(&p, "username"), + hostname: get_str(&p, "hostname"), + platform: get_str(&p, "platform"), + tags, + }); + } + state + .db + .ab_legacy_replace(&guid, &tags, &peers) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} + +fn get_str(v: &Value, k: &str) -> String { + v.get(k) + .and_then(|x| x.as_str()) + .unwrap_or_default() + .to_string() +} diff --git a/src/api/ab/mod.rs b/src/api/ab/mod.rs new file mode 100644 index 0000000..90d5bc1 --- /dev/null +++ b/src/api/ab/mod.rs @@ -0,0 +1,6 @@ +pub mod legacy; +pub mod peers; +pub mod profiles; +pub mod rules; +pub mod settings; +pub mod tags; diff --git a/src/api/ab/peers.rs b/src/api/ab/peers.rs new file mode 100644 index 0000000..dae6521 --- /dev/null +++ b/src/api/ab/peers.rs @@ -0,0 +1,198 @@ +use crate::api::ab::rules::{enforce, Rule}; +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::pagination::Page; +use crate::api::state::AppState; +use crate::database::AbPeerInsert; +use axum::extract::{Extension, Path, Query}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::Json; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::sync::Arc; + +/// `serde_urlencoded` (axum's query decoder) does not honour +/// `#[serde(flatten)]`, so the pagination fields are spelled out inline. +#[derive(Debug, Deserialize)] +pub struct AbQuery { + /// guid sent in the query string for `/api/ab/peers?ab=`. + pub ab: String, + #[serde(default = "default_current")] + pub current: i64, + #[serde(default = "default_page_size", rename = "pageSize")] + pub page_size: i64, +} + +fn default_current() -> i64 { + 1 +} +fn default_page_size() -> i64 { + 100 +} + +impl AbQuery { + fn offset(&self) -> i64 { + (self.current.max(1) - 1) * self.limit() + } + fn limit(&self) -> i64 { + self.page_size.clamp(1, 1000) + } +} + +/// `POST /api/ab/peers?ab=` — paginated peer list inside an AB. +/// Wire shape matches the Flutter `Peer` decoder; only fields documented in +/// CONSOLE_API.md §4.4 are surfaced. +#[derive(Debug, Serialize)] +struct PeerOut { + id: String, + alias: String, + tags: Vec, + note: String, + #[serde(skip_serializing_if = "String::is_empty")] + password: String, + #[serde(skip_serializing_if = "String::is_empty")] + hash: String, + #[serde(skip_serializing_if = "String::is_empty")] + username: String, + #[serde(skip_serializing_if = "String::is_empty")] + hostname: String, + #[serde(skip_serializing_if = "String::is_empty")] + platform: String, +} + +pub async fn list( + Extension(state): Extension>, + user: AuthedUser, + Query(q): Query, +) -> Result { + enforce(&state, user.user_id, &q.ab, Rule::Read).await?; + let (total, rows) = state + .db + .ab_list_peers(&q.ab, q.offset(), q.limit()) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let data: Vec = rows + .into_iter() + .map(|r| PeerOut { + id: r.id, + alias: r.alias, + tags: r.tags, + note: r.note, + password: r.password, + hash: r.hash, + username: r.username, + hostname: r.hostname, + platform: r.platform, + }) + .collect(); + Ok((StatusCode::OK, Json(Page { total, data }))) +} + +#[derive(Debug, Deserialize)] +pub struct PeerAddBody { + pub id: String, + #[serde(default)] + pub alias: Option, + #[serde(default)] + pub tags: Option>, + #[serde(default)] + pub note: Option, + #[serde(default)] + pub password: Option, + #[serde(default)] + pub hash: Option, + #[serde(default)] + pub username: Option, + #[serde(default)] + pub hostname: Option, + #[serde(default)] + pub platform: Option, +} + +/// `POST /api/ab/peer/add/{guid}` — insert one peer. **Returns HTTP 200 +/// with an empty body on success**, or `{"error":"..."}` JSON body on failure +/// (also HTTP 200). The Flutter `_jsonDecodeActionResp` at +/// flutter/lib/models/ab_model.dart:2002 treats *any* non-empty success body +/// as an error to surface — including `{}` (which produces the literal string +/// "null"), so action endpoints must reply with truly empty bodies. +pub async fn add( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(body): Json, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + if body.id.is_empty() { + return Err(ApiError::BadRequest("id required".into())); + } + let max = state.cfg.ab_max_peers_per_book; + let count = state + .db + .ab_count_peers(&guid) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + if count >= max { + return Err(ApiError::Forbidden("exceed_max_devices".into())); + } + state + .db + .ab_peer_insert( + &guid, + AbPeerInsert { + id: &body.id, + alias: body.alias.as_deref(), + note: body.note.as_deref(), + password: body.password.as_deref(), + hash: body.hash.as_deref(), + username: body.username.as_deref(), + hostname: body.hostname.as_deref(), + platform: body.platform.as_deref(), + }, + body.tags.as_deref(), + ) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} + +/// `PUT /api/ab/peer/update/{guid}` — partial update. Body always carries +/// `id`, plus any subset of mutable fields. Empty success body, see `add`. +pub async fn update( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(body): Json, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + let id = body + .get("id") + .and_then(|v| v.as_str()) + .ok_or_else(|| ApiError::BadRequest("id required".into()))?; + let updated = state + .db + .ab_peer_partial_update(&guid, id, &body) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + if !updated { + return Err(ApiError::Forbidden("peer not found".into())); + } + Ok(StatusCode::OK) +} + +/// `DELETE /api/ab/peer/{guid}` — body is a JSON array of peer IDs. Empty +/// success body, see `add`. +pub async fn delete( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(ids): Json>, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + state + .db + .ab_peers_delete(&guid, &ids) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} diff --git a/src/api/ab/profiles.rs b/src/api/ab/profiles.rs new file mode 100644 index 0000000..c3c28a3 --- /dev/null +++ b/src/api/ab/profiles.rs @@ -0,0 +1,71 @@ +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::pagination::{Page, PageQuery}; +use crate::api::state::AppState; +use axum::extract::{Extension, Query}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::Json; +use serde::Serialize; +use serde_json::{json, Value}; +use std::sync::Arc; + +/// `POST /api/ab/personal` — returns the caller's personal AB GUID, creating +/// it if missing. When `--ab-legacy-mode=on` is configured, returns 404 to +/// signal "this server speaks the legacy single-blob protocol" (the client +/// then falls back to GET/POST /api/ab). +pub async fn personal( + Extension(state): Extension>, + user: AuthedUser, +) -> Result, ApiError> { + if state.cfg.ab_legacy_mode { + return Err(ApiError::NotFound); + } + let guid = state + .db + .ab_get_or_create_personal(user.user_id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(Json(json!({ "guid": guid }))) +} + +/// `POST /api/ab/shared/profiles` — paginated list of shared address books +/// the caller can see. Wire shape matches the Flutter `AbProfile` decoder at +/// flutter/lib/common/hbbs/hbbs.dart:258. +#[derive(Debug, Serialize)] +struct AbProfileOut { + guid: String, + name: String, + owner: String, + note: String, + rule: i64, + #[serde(skip_serializing_if = "Option::is_none")] + info: Option, +} + +pub async fn shared_profiles( + Extension(state): Extension>, + user: AuthedUser, + Query(page): Query, +) -> Result { + let (total, rows) = state + .db + .ab_list_shared_for_user(user.user_id, page.offset(), page.limit()) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let data = rows + .into_iter() + .map(|r| AbProfileOut { + guid: r.guid, + name: r.name, + owner: r.owner, + note: r.note, + rule: r.rule, + info: r + .info_json + .as_deref() + .and_then(|s| serde_json::from_str(s).ok()), + }) + .collect(); + Ok((StatusCode::OK, Json(Page { total, data }))) +} diff --git a/src/api/ab/rules.rs b/src/api/ab/rules.rs new file mode 100644 index 0000000..9e01d63 --- /dev/null +++ b/src/api/ab/rules.rs @@ -0,0 +1,49 @@ +use crate::api::error::ApiError; +use crate::api::state::AppState; + +/// Share-rule levels for a shared address book. Wire integers match the +/// Flutter client's `ShareRule` enum at flutter/lib/common/hbbs/hbbs.dart:210. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum Rule { + Read = 1, + ReadWrite = 2, + Full = 3, +} + +impl Rule { + pub fn from_i64(v: i64) -> Option { + match v { + 1 => Some(Rule::Read), + 2 => Some(Rule::ReadWrite), + 3 => Some(Rule::Full), + _ => None, + } + } +} + +/// Enforce that `caller` has at least `needed` access on `ab_guid`. Used at +/// the top of every AB handler. Resolution lives in +/// `Database::ab_resolve_rule` and considers (a) AB ownership and (b) the +/// largest matching rule across direct and device-group shares. +pub async fn enforce( + state: &AppState, + caller_user_id: i64, + ab_guid: &str, + needed: Rule, +) -> Result<(), ApiError> { + let resolved = state + .db + .ab_resolve_rule(caller_user_id, ab_guid) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let Some(have) = resolved.and_then(Rule::from_i64) else { + // Either the AB doesn't exist or the caller has no relationship with + // it. Surface as "not allowed" so we don't leak existence. + return Err(ApiError::Forbidden("not allowed".into())); + }; + if have >= needed { + Ok(()) + } else { + Err(ApiError::Forbidden("read-only".into())) + } +} diff --git a/src/api/ab/settings.rs b/src/api/ab/settings.rs new file mode 100644 index 0000000..09e8a3b --- /dev/null +++ b/src/api/ab/settings.rs @@ -0,0 +1,16 @@ +use crate::api::middleware::AuthedUser; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::Json; +use serde_json::{json, Value}; +use std::sync::Arc; + +/// `POST /api/ab/settings` — capability/limit probe. The Flutter client +/// (ab_model.dart:230-258) calls this once per pull cycle to learn +/// `max_peer_one_ab`. Auth is required even though there is no body. +pub async fn settings( + Extension(state): Extension>, + _user: AuthedUser, +) -> Json { + Json(json!({ "max_peer_one_ab": state.cfg.ab_max_peers_per_book })) +} diff --git a/src/api/ab/tags.rs b/src/api/ab/tags.rs new file mode 100644 index 0000000..3a3a891 --- /dev/null +++ b/src/api/ab/tags.rs @@ -0,0 +1,122 @@ +use crate::api::ab::rules::{enforce, Rule}; +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::state::AppState; +use axum::extract::{Extension, Path}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::Json; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// `POST /api/ab/tags/{guid}` — list tags. Wire shape is a bare JSON array +/// `[{name, color}]`, NOT the `Page` envelope. +#[derive(Debug, Serialize)] +struct TagOut { + name: String, + color: i64, +} + +pub async fn list( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::Read).await?; + let rows = state + .db + .ab_list_tags(&guid) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let out: Vec = rows + .into_iter() + .map(|t| TagOut { + name: t.name, + color: t.color, + }) + .collect(); + Ok((StatusCode::OK, Json(out))) +} + +#[derive(Debug, Deserialize)] +pub struct TagAddBody { + pub name: String, + pub color: i64, +} + +pub async fn add( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(body): Json, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + if body.name.is_empty() { + return Err(ApiError::BadRequest("name required".into())); + } + state + .db + .ab_tag_insert(&guid, &body.name, body.color) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} + +#[derive(Debug, Deserialize)] +pub struct TagRenameBody { + #[serde(rename = "old")] + pub old_name: String, + #[serde(rename = "new")] + pub new_name: String, +} + +pub async fn rename( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(body): Json, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + state + .db + .ab_tag_rename(&guid, &body.old_name, &body.new_name) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} + +#[derive(Debug, Deserialize)] +pub struct TagUpdateBody { + pub name: String, + pub color: i64, +} + +pub async fn update( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(body): Json, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + state + .db + .ab_tag_update_color(&guid, &body.name, body.color) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} + +pub async fn delete( + Extension(state): Extension>, + user: AuthedUser, + Path(guid): Path, + Json(names): Json>, +) -> Result { + enforce(&state, user.user_id, &guid, Rule::ReadWrite).await?; + state + .db + .ab_tags_delete(&guid, &names) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} diff --git a/src/api/audit/alarm.rs b/src/api/audit/alarm.rs new file mode 100644 index 0000000..2f6c36c --- /dev/null +++ b/src/api/audit/alarm.rs @@ -0,0 +1,38 @@ +//! `POST /api/audit/alarm` — security alarm (IP whitelist hit, brute-force +//! thresholds). Wire shape from CONSOLE_API.md §7.3: +//! `{ id, uuid, typ: int, info: stringified-JSON }`. + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::Json; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct AlarmAuditBody { + #[serde(default)] + pub id: String, + #[serde(default)] + pub uuid: String, + #[serde(default)] + pub typ: i64, + #[serde(default)] + pub info: String, +} + +pub async fn alarm( + Extension(state): Extension>, + Json(body): Json, +) -> Result { + if body.id.is_empty() { + return Err(ApiError::BadRequest("id required".into())); + } + state + .db + .audit_alarm_insert(&body.id, body.typ, &body.info) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} diff --git a/src/api/audit/conn.rs b/src/api/audit/conn.rs new file mode 100644 index 0000000..aee702e --- /dev/null +++ b/src/api/audit/conn.rs @@ -0,0 +1,49 @@ +//! `POST /api/audit/conn` — fire-and-forget connection log entry. The client +//! ([src/server/connection.rs:1248-1279](file:///Users/sn0/Desktop/rustdesk/src/server/connection.rs#L1248)) +//! emits this on every accepted session, no Authorization header. We answer +//! with `{"guid":"..."}` so the client can pass that guid back later in +//! `PUT /api/audit` (CONSOLE_API.md §7.1). + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::Json; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct ConnAuditBody { + #[serde(default)] + pub id: String, + #[serde(default)] + pub uuid: String, + #[serde(default)] + pub conn_id: i64, + #[serde(default)] + pub session_id: i64, + #[serde(default)] + pub ip: String, + #[serde(default)] + pub action: String, +} + +pub async fn conn( + Extension(state): Extension>, + Json(body): Json, +) -> Result, ApiError> { + if body.id.is_empty() { + return Err(ApiError::BadRequest("id required".into())); + } + let action = if body.action.is_empty() { + "new" + } else { + body.action.as_str() + }; + let guid = state + .db + .audit_conn_insert(&body.id, body.conn_id, body.session_id, &body.ip, action) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(Json(json!({ "guid": guid }))) +} diff --git a/src/api/audit/file.rs b/src/api/audit/file.rs new file mode 100644 index 0000000..6b7d305 --- /dev/null +++ b/src/api/audit/file.rs @@ -0,0 +1,50 @@ +//! `POST /api/audit/file` — file transfer log entry (CONSOLE_API.md §7.2). +//! `info` arrives as a stringified JSON object; we store it verbatim. + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::Json; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct FileAuditBody { + #[serde(default)] + pub id: String, + #[serde(default)] + pub uuid: String, + #[serde(default)] + pub peer_id: String, + #[serde(default, rename = "type")] + pub direction: i64, + #[serde(default)] + pub path: String, + #[serde(default)] + pub is_file: bool, + #[serde(default)] + pub info: String, +} + +pub async fn file( + Extension(state): Extension>, + Json(body): Json, +) -> Result { + if body.id.is_empty() { + return Err(ApiError::BadRequest("id required".into())); + } + state + .db + .audit_file_insert( + &body.id, + &body.peer_id, + body.direction, + &body.path, + body.is_file, + &body.info, + ) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(StatusCode::OK) +} diff --git a/src/api/audit/mod.rs b/src/api/audit/mod.rs new file mode 100644 index 0000000..a027e2a --- /dev/null +++ b/src/api/audit/mod.rs @@ -0,0 +1,4 @@ +pub mod alarm; +pub mod conn; +pub mod file; +pub mod note; diff --git a/src/api/audit/note.rs b/src/api/audit/note.rs new file mode 100644 index 0000000..51bc8c8 --- /dev/null +++ b/src/api/audit/note.rs @@ -0,0 +1,39 @@ +//! `PUT /api/audit` — operator end-of-session note. Sent from the Flutter +//! `_showConnEndAuditDialogCloseCanceled` flow at +//! [flutter/lib/common/widgets/dialog.dart:1656](file:///Users/sn0/Desktop/rustdesk/flutter/lib/common/widgets/dialog.dart#L1656). +//! Bearer-authenticated. + +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::Json; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct NoteBody { + pub guid: String, + #[serde(default)] + pub note: String, +} + +pub async fn note( + Extension(state): Extension>, + _user: AuthedUser, + Json(body): Json, +) -> Result { + if body.guid.is_empty() { + return Err(ApiError::BadRequest("guid required".into())); + } + let updated = state + .db + .audit_conn_update_note(&body.guid, &body.note) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + if !updated { + return Err(ApiError::NotFound); + } + Ok(StatusCode::OK) +} diff --git a/src/api/auth.rs b/src/api/auth.rs new file mode 100644 index 0000000..85a6add --- /dev/null +++ b/src/api/auth.rs @@ -0,0 +1,361 @@ +use crate::api::email; +use crate::api::error::ApiError; +use crate::api::middleware::{sha256_token, AuthedUser}; +use crate::api::state::AppState; +use crate::api::users::{verify_password, UserPayload}; +use crate::database::UserRow; +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::Json; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::sync::Arc; +use totp_rs::{Algorithm, Secret, TOTP}; + +const EMAIL_CODE_TTL_SECS: i64 = 600; + +/// `LoginRequest` mirrors the Flutter client at +/// flutter/lib/common/hbbs/hbbs.dart:133. M1 only consults `username`, +/// `password`, and `type`; the other fields are tolerated for forward-compat. +#[derive(Debug, Deserialize)] +pub struct LoginRequest { + #[serde(default)] + pub username: Option, + #[serde(default)] + pub password: Option, + #[serde(default)] + pub id: Option, + #[serde(default)] + pub uuid: Option, + #[serde(default, rename = "type")] + pub kind: Option, + #[serde(default, rename = "deviceInfo")] + pub device_info: Option, + // Tolerated, ignored in M1: + #[serde(default)] + pub auto_login: Option, + #[serde(default, rename = "verificationCode")] + pub verification_code: Option, + #[serde(default, rename = "tfaCode")] + pub tfa_code: Option, + #[serde(default)] + pub secret: Option, +} + +#[derive(Debug, Deserialize)] +pub struct IdUuidBody { + #[serde(default)] + pub id: Option, + #[serde(default)] + pub uuid: Option, +} + +pub async fn login_options_head() -> StatusCode { + StatusCode::OK +} + +pub async fn login_options(Extension(state): Extension>) -> Json> { + // Static base set from config (account / email_code), plus a dynamic + // `oidc/` entry per enabled provider in the DB. Recomputed per + // request so adding a provider via SQL takes effect without a restart. + let mut out = state.cfg.login_options.clone(); + if !state.cfg.public_base_url.is_empty() { + if let Ok(providers) = state.db.oidc_provider_list_enabled().await { + for p in providers { + out.push(format!("oidc/{}", p.name)); + } + } + } + Json(out) +} + +const TFA_CHALLENGE_TTL_SECS: i64 = 300; + +pub async fn login( + Extension(state): Extension>, + Json(req): Json, +) -> Result, ApiError> { + // Branch on `type`. Empty / "account" is the password path; "tfa_code" + // is the second leg of a TOTP challenge issued earlier in this same + // dance. Reject anything else for now — M4 will add email_code etc. + let kind = req.kind.as_deref().unwrap_or("account"); + match kind { + "account" | "" => login_account(state, req).await, + "tfa_code" => login_tfa_code(state, req).await, + "email_code" => login_email_code(state, req).await, + other => Err(ApiError::BadRequest(format!( + "unsupported login type: {}", + other + ))), + } +} + +/// Two-leg passwordless login by email. Leg 1 (no `verificationCode`) mints a +/// fresh 6-digit code and emails it to the user (or logs to stdout when SMTP +/// is unconfigured). Leg 2 (with `verificationCode`) verifies the code, +/// consumes it, and issues an access token. +async fn login_email_code( + state: Arc, + req: LoginRequest, +) -> Result, ApiError> { + // The Flutter client passes the email/username in the `username` field; + // accept it either as a literal email or as a username we can map to one. + let identifier = req + .username + .as_deref() + .filter(|s| !s.is_empty()) + .ok_or_else(|| ApiError::BadRequest("username (email) required".into()))?; + let user = resolve_user_by_identifier(&state, identifier).await?; + let email = if !user.email.is_empty() { + user.email.clone() + } else if user.username.contains('@') { + // Operator bootstraps users with email-as-username — accept that. + user.username.clone() + } else { + return Err(ApiError::BadRequest( + "user has no email address on file".into(), + )); + }; + + if let Some(code) = req + .verification_code + .as_deref() + .filter(|s| !s.is_empty()) + { + // Leg 2: verify. + let supplied_hash = sodiumoxide::crypto::hash::sha256::hash(code.as_bytes()) + .as_ref() + .to_vec(); + let ok = state + .db + .email_code_verify(&email, &supplied_hash) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + if !ok { + return Err(ApiError::BadCredentials); + } + if user.status == 0 { + return Err(ApiError::AccountDisabled); + } + if user.status == -1 { + return Err(ApiError::Unverified); + } + return issue_session(&state, &req, &user).await; + } + + // Leg 1: mint + send a fresh code. + let (code, code_hash) = email::mint_code(); + state + .db + .email_code_create(&email, &code_hash, EMAIL_CODE_TTL_SECS) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + if let Err(e) = email::send_login_code(state.cfg.email.as_ref(), &email, &code).await { + return Err(ApiError::Internal(e)); + } + Ok(Json(json!({ "type": "email_check" }))) +} + +async fn resolve_user_by_identifier( + state: &AppState, + identifier: &str, +) -> Result { + if identifier.contains('@') { + if let Some(u) = state + .db + .user_find_by_email(identifier) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + { + return Ok(u); + } + } + state + .db + .user_find_by_username(identifier) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or(ApiError::BadCredentials) +} + +async fn login_account( + state: Arc, + req: LoginRequest, +) -> Result, ApiError> { + let username = req + .username + .as_deref() + .filter(|s| !s.is_empty()) + .ok_or_else(|| ApiError::BadRequest("username required".into()))?; + let password = req + .password + .as_deref() + .filter(|s| !s.is_empty()) + .ok_or_else(|| ApiError::BadRequest("password required".into()))?; + + let user = state + .db + .user_find_by_username(username) + .await? + .ok_or(ApiError::BadCredentials)?; + + let ok = verify_password(user.password_hash.clone(), password.to_string()).await?; + if !ok { + return Err(ApiError::BadCredentials); + } + if user.status == 0 { + return Err(ApiError::AccountDisabled); + } + if user.status == -1 { + return Err(ApiError::Unverified); + } + + // 2FA gate: if the user has TOTP enrolled, mint a short-lived nonce and + // tell the client we want the TOTP code in a follow-up POST. The client + // echoes the nonce back as `secret`. + if state.db.totp_get_secret(user.id).await?.is_some() { + let nonce = state + .db + .tfa_challenge_create(user.id, TFA_CHALLENGE_TTL_SECS) + .await?; + return Ok(Json(json!({ + "type": "tfa_check", + "tfa_type": "totp", + "secret": nonce, + }))); + } + + issue_session(&state, &req, &user).await +} + +async fn login_tfa_code( + state: Arc, + req: LoginRequest, +) -> Result, ApiError> { + let nonce = req + .secret + .as_deref() + .filter(|s| !s.is_empty()) + .ok_or_else(|| ApiError::BadRequest("secret required".into()))?; + let code = req + .tfa_code + .as_deref() + .filter(|s| !s.is_empty()) + .ok_or_else(|| ApiError::BadRequest("tfaCode required".into()))?; + + let user_id = state + .db + .tfa_challenge_lookup(nonce) + .await? + .ok_or_else(|| ApiError::BadRequest("invalid or expired challenge".into()))?; + let secret_b32 = state + .db + .totp_get_secret(user_id) + .await? + .ok_or_else(|| ApiError::BadRequest("TOTP not enrolled".into()))?; + + if !verify_totp(&secret_b32, code)? { + // Leave the challenge row alive — operators may want short retries. + return Err(ApiError::BadCredentials); + } + state.db.tfa_challenge_consume(nonce).await?; + + let user = state + .db + .user_find_by_id(user_id) + .await? + .ok_or(ApiError::Unauthorized)?; + issue_session(&state, &req, &user).await +} + +/// Build and persist a fresh access token, claim the calling device, and +/// return the standard logged-in response shape. Shared by the password, +/// post-TOTP, post-email-code, and (later) post-OIDC paths. +async fn issue_session( + state: &AppState, + req: &LoginRequest, + user: &UserRow, +) -> Result, ApiError> { + let token = mint_token(); + let sha = sha256_token(&token); + let device_info_json = req + .device_info + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_default(); + state + .db + .token_insert( + user.id, + &sha, + req.id.as_deref().unwrap_or_default(), + req.uuid.as_deref().unwrap_or_default(), + &device_info_json, + state.cfg.session_ttl_secs, + ) + .await?; + // Bind the calling device to this user so /api/peers shows it correctly. + state + .db + .device_claim( + user.id, + req.id.as_deref().unwrap_or_default(), + req.uuid.as_deref().unwrap_or_default(), + ) + .await; + + Ok(Json(json!({ + "access_token": token, + "type": "access_token", + "user": UserPayload::from(user), + }))) +} + +fn verify_totp(secret_b32: &str, code: &str) -> Result { + let secret = Secret::Encoded(secret_b32.to_string()) + .to_bytes() + .map_err(|e| ApiError::Internal(format!("bad TOTP secret: {:?}", e)))?; + let totp = TOTP::new(Algorithm::SHA1, 6, 1, 30, secret) + .map_err(|e| ApiError::Internal(format!("TOTP init: {}", e)))?; + totp.check_current(code) + .map_err(|e| ApiError::Internal(format!("TOTP check: {}", e))) +} + +pub async fn current_user( + Extension(state): Extension>, + user: AuthedUser, + // Body is required by the client but its content is purely advisory. + Json(_body): Json, +) -> Result, ApiError> { + let row = state + .db + .user_find_by_id(user.user_id) + .await? + .ok_or(ApiError::Unauthorized)?; + Ok(Json(UserPayload::from(&row))) +} + +pub async fn logout( + Extension(state): Extension>, + headers: axum::http::HeaderMap, + Json(_body): Json, +) -> StatusCode { + // Best-effort: parse the bearer ourselves so a missing/invalid token still + // returns 200 (matches the client's fire-and-forget logout flow). + if let Some(auth) = headers.get(axum::http::header::AUTHORIZATION) { + if let Ok(s) = auth.to_str() { + if let Some(tok) = s.strip_prefix("Bearer ").map(str::trim) { + if !tok.is_empty() { + let sha = sha256_token(tok); + let _ = state.db.token_delete(&sha).await; + } + } + } + } + StatusCode::OK +} + +pub(crate) fn mint_token() -> String { + let bytes = sodiumoxide::randombytes::randombytes(32); + base64::encode_config(bytes, base64::URL_SAFE_NO_PAD) +} diff --git a/src/api/devices_cli.rs b/src/api/devices_cli.rs new file mode 100644 index 0000000..bf47d04 --- /dev/null +++ b/src/api/devices_cli.rs @@ -0,0 +1,198 @@ +//! `POST /api/devices/cli` — used by `rustdesk --assign --token ...` +//! to enroll a freshly installed device into a tenant slot. +//! +//! Per CONSOLE_API.md §11: bearer-authenticated; the response body is plain +//! text (empty = success, non-empty = informational message). The client +//! prints "Done!" when the body is empty. + +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::state::AppState; +use crate::database::AbPeerInsert; +use axum::extract::Extension; +use axum::http::header; +use axum::response::IntoResponse; +use axum::Json; +use serde::Deserialize; +use serde_json::Value; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct AssignBody { + pub id: String, + pub uuid: String, + #[serde(default)] + pub user_name: Option, + #[serde(default)] + pub strategy_name: Option, + #[serde(default)] + pub address_book_name: Option, + #[serde(default)] + pub address_book_tag: Option, + #[serde(default)] + pub address_book_alias: Option, + #[serde(default)] + pub address_book_password: Option, + #[serde(default)] + pub address_book_note: Option, + #[serde(default)] + pub device_group_name: Option, + #[serde(default)] + pub note: Option, + #[serde(default)] + pub device_username: Option, + #[serde(default)] + pub device_name: Option, +} + +pub async fn assign( + Extension(state): Extension>, + caller: AuthedUser, + Json(body): Json, +) -> Result { + if body.id.is_empty() || body.uuid.is_empty() { + return Err(ApiError::BadRequest("id and uuid required".into())); + } + let mut warnings: Vec = vec![]; + + // Resolve owner. If --user_name was supplied, that's the owner; otherwise + // the caller becomes the owner (matches `rustdesk --assign` flows where + // the operator's account is the destination). + let owner = if let Some(name) = body.user_name.as_deref().filter(|s| !s.is_empty()) { + if !caller.is_admin { + return Err(ApiError::Forbidden( + "admin required to assign to another user".into(), + )); + } + match state + .db + .user_find_by_username(name) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + { + Some(u) => u, + None => { + return Err(ApiError::BadRequest(format!( + "no such user: {}", + name + ))); + } + } + } else { + state + .db + .user_find_by_id(caller.user_id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or(ApiError::Unauthorized)? + }; + + // Bind the device to the owner (mirrors what /api/login's device_claim + // does, but here it's an admin operation rather than user-initiated). + state.db.device_claim(owner.id, &body.id, &body.uuid).await; + + // Address-book entry. We always target the *owner's* personal AB. + if let Some(ab_name) = body.address_book_name.as_deref().filter(|s| !s.is_empty()) { + let _ = ab_name; // M2's get_or_create_personal ignores the name; OSS has one personal AB per user. + let ab_guid = state + .db + .ab_get_or_create_personal(owner.id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let tags: Option> = body + .address_book_tag + .as_deref() + .filter(|s| !s.is_empty()) + .map(|t| t.split(',').map(|s| s.trim().to_string()).collect()); + if let Err(e) = state + .db + .ab_peer_insert( + &ab_guid, + AbPeerInsert { + id: &body.id, + alias: body.address_book_alias.as_deref(), + note: body.address_book_note.as_deref(), + password: body.address_book_password.as_deref(), + hash: None, + username: body.device_username.as_deref(), + hostname: body.device_name.as_deref(), + platform: None, + }, + tags.as_deref(), + ) + .await + { + // Likely a UNIQUE conflict if the peer is already in the AB; + // surface as a warning rather than failing the whole call. + warnings.push(format!("address-book entry not added: {}", e)); + } + } + + // Strategy assignment by name. We attach to the device directly (peer-scoped), + // which is the most-specific tier in our resolver. + if let Some(name) = body.strategy_name.as_deref().filter(|s| !s.is_empty()) { + match resolve_strategy_id(&state, name).await? { + Some(strategy_id) => { + if let Err(e) = state + .db + .strategy_assign_peer(strategy_id, &body.id) + .await + { + warnings.push(format!("strategy assignment failed: {}", e)); + } + } + None => { + warnings.push(format!("strategy {:?} does not exist", name)); + } + } + } + + // Device-group membership: ensure the group exists, ensure the owner is a + // member. We treat the group name as the natural key per the M2 schema. + if let Some(group_name) = body.device_group_name.as_deref().filter(|s| !s.is_empty()) { + if let Err(e) = state + .db + .device_group_ensure_member(group_name, owner.id) + .await + { + warnings.push(format!("device-group assignment failed: {}", e)); + } + } + + // Fields we accept but don't currently persist as discrete columns. These + // travel with the next sysinfo upload anyway (note, device_username, + // device_name end up in `device_sysinfo.payload` JSON). + if body.note.as_deref().map(|s| !s.is_empty()).unwrap_or(false) { + warnings.push( + "--note is currently surfaced via sysinfo only, not persisted as a discrete field" + .into(), + ); + } + + let body_text = if warnings.is_empty() { + String::new() + } else { + warnings.join("\n") + }; + Ok(( + [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], + body_text, + )) +} + +async fn resolve_strategy_id( + state: &AppState, + name: &str, +) -> Result, ApiError> { + state + .db + .strategy_find_by_name(name) + .await + .map_err(|e| ApiError::Internal(e.to_string())) +} + +/// Wrap the `Value` JSON the request _could_ have under `Json` if a +/// future variation needs it. Currently unused; kept for symmetry with other +/// modules that work with raw JSON in/out. +#[allow(dead_code)] +fn ignore_value(_v: Value) {} diff --git a/src/api/email.rs b/src/api/email.rs new file mode 100644 index 0000000..06c15c2 --- /dev/null +++ b/src/api/email.rs @@ -0,0 +1,80 @@ +//! SMTP transport for email-code login. Two modes: +//! +//! - **Production:** `--smtp-host` (and friends) configured → real SMTP via +//! `lettre` with optional STARTTLS + auth. +//! - **Dev:** `--smtp-host` empty → the code is logged to stdout instead. +//! This makes the round-trip testable without standing up a mail server. + +use crate::api::state::EmailConfig; +use hbb_common::log; +use lettre::message::header::ContentType; +use lettre::transport::smtp::authentication::Credentials; +use lettre::transport::smtp::AsyncSmtpTransport; +use lettre::{AsyncTransport, Message, Tokio1Executor}; + +pub async fn send_login_code( + cfg: Option<&EmailConfig>, + to: &str, + code: &str, +) -> Result<(), String> { + if to.is_empty() { + return Err("recipient address is empty".into()); + } + let Some(cfg) = cfg else { + // Dev mode: surface the code so the operator can complete the flow. + log::info!("[email-code] login code for <{}>: {}", to, code); + return Ok(()); + }; + let body = format!( + "Your login code is: {}\n\nIt expires in 10 minutes.\nIf you didn't request this, ignore this email.\n", + code + ); + let message = Message::builder() + .from( + cfg.from + .parse() + .map_err(|e| format!("invalid From address {:?}: {}", cfg.from, e))?, + ) + .to(to.parse().map_err(|e| format!("invalid To address {:?}: {}", to, e))?) + .subject("Your RustDesk login code") + .header(ContentType::TEXT_PLAIN) + .body(body) + .map_err(|e| format!("compose: {}", e))?; + + let mut builder = if cfg.starttls { + AsyncSmtpTransport::::starttls_relay(&cfg.host) + .map_err(|e| format!("STARTTLS init for {}: {}", cfg.host, e))? + } else { + AsyncSmtpTransport::::builder_dangerous(&cfg.host) + } + .port(cfg.port); + if let (Some(user), Some(pass)) = (cfg.username.as_deref(), cfg.password.as_deref()) { + builder = builder.credentials(Credentials::new(user.to_string(), pass.to_string())); + } + let transport = builder.build(); + transport + .send(message) + .await + .map_err(|e| format!("smtp send to {}: {}", cfg.host, e))?; + log::info!("[email-code] code mailed to <{}>", to); + Ok(()) +} + +/// Generate a 6-digit numeric code with cryptographic entropy. Returns the +/// code as a string and its sha256 for storage. +pub fn mint_code() -> (String, Vec) { + // Sample 4 random bytes, fold into 0..1_000_000, format as 6-digit + // zero-padded decimal. 24 bits of entropy is plenty for a 10-minute + // 5-attempt-limit code. + let bytes = sodiumoxide::randombytes::randombytes(4); + let mut n: u32 = 0; + for b in &bytes { + n = (n << 8) | (*b as u32); + } + let n = n % 1_000_000; + let code = format!("{:06}", n); + let hash = sodiumoxide::crypto::hash::sha256::hash(code.as_bytes()) + .as_ref() + .to_vec(); + (code, hash) +} diff --git a/src/api/error.rs b/src/api/error.rs new file mode 100644 index 0000000..c548d4f --- /dev/null +++ b/src/api/error.rs @@ -0,0 +1,55 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use serde_json::json; + +/// Single error type for the management API. Always serializes to +/// `{"error":"..."}` per the protocol spec; the HTTP status is chosen so the +/// client behaves correctly: +/// +/// - 401 Unauthorized clears the local access_token (intentional fallback in +/// the Flutter client — see CONSOLE_API.md §3.6). +/// - 200 OK + JSON `error` for business failures (bad creds, validation). +/// Most non-auth handlers should return BadRequest or Conflict instead so +/// the operator can distinguish them in logs. +#[derive(Debug)] +pub enum ApiError { + Unauthorized, + BadCredentials, + AccountDisabled, + Unverified, + Forbidden(String), + NotFound, + BadRequest(String), + Internal(String), +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let (status, msg) = match self { + ApiError::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized".to_string()), + ApiError::BadCredentials => (StatusCode::UNAUTHORIZED, "bad credentials".to_string()), + ApiError::AccountDisabled => (StatusCode::FORBIDDEN, "account disabled".to_string()), + ApiError::Unverified => (StatusCode::FORBIDDEN, "unverified".to_string()), + // Returning HTTP 200 + {"error": ...} for share-rule rejections. + // Flutter's _jsonDecodeActionResp at ab_model.dart:2002 surfaces + // the JSON `error` field as a toast and stays signed-in; using + // 403 here would trigger the global 401/403 logout path and yank + // the user's session. + ApiError::Forbidden(m) => (StatusCode::OK, m), + ApiError::NotFound => (StatusCode::NOT_FOUND, "not found".to_string()), + ApiError::BadRequest(m) => (StatusCode::BAD_REQUEST, m), + ApiError::Internal(m) => { + hbb_common::log::error!("api internal error: {}", m); + (StatusCode::OK, "internal error".to_string()) + } + }; + (status, Json(json!({ "error": msg }))).into_response() + } +} + +impl From for ApiError { + fn from(e: hbb_common::anyhow::Error) -> Self { + ApiError::Internal(e.to_string()) + } +} diff --git a/src/api/groups.rs b/src/api/groups.rs new file mode 100644 index 0000000..337d7aa --- /dev/null +++ b/src/api/groups.rs @@ -0,0 +1,37 @@ +//! `GET /api/device-group/accessible` — paginated list of device groups the +//! caller is a member of (admin sees all). The Flutter client at +//! flutter/lib/models/group_model.dart:103 silently tolerates errors here, so +//! we keep the behavior tight: empty list when no groups exist, never panic. + +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::pagination::{Page, PageQuery}; +use crate::api::state::AppState; +use axum::extract::{Extension, Query}; +use axum::Json; +use serde::Serialize; +use std::sync::Arc; + +#[derive(Debug, Serialize)] +pub struct DeviceGroupOut { + pub name: String, +} + +pub async fn accessible( + Extension(state): Extension>, + user: AuthedUser, + Query(q): Query, +) -> Result>, ApiError> { + let (total, rows) = state + .db + .groups_list_for_user(user.user_id, user.is_admin, q.offset(), q.limit()) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(Json(Page { + total, + data: rows + .into_iter() + .map(|g| DeviceGroupOut { name: g.name }) + .collect(), + })) +} diff --git a/src/api/heartbeat.rs b/src/api/heartbeat.rs new file mode 100644 index 0000000..e3cce90 --- /dev/null +++ b/src/api/heartbeat.rs @@ -0,0 +1,99 @@ +//! `POST /api/heartbeat` — the agent management loop. The client sends every +//! ~15 s (3 s when active connections exist). The reply may carry, in any +//! combination: +//! - `sysinfo: true` — force the client to re-upload sysinfo immediately, +//! - `disconnect: [conn_id, ...]` — tell the client to drop those sessions, +//! - `modified_at` + `strategy` — push a config-options merge. +//! +//! Auth: none (the client identifies the device by `(id, uuid)` body fields). + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use crate::api::strategy; +use axum::extract::Extension; +use axum::Json; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct HeartbeatBody { + #[serde(default)] + pub id: String, + #[serde(default)] + pub uuid: String, + #[serde(default)] + pub ver: i64, + #[serde(default)] + pub conns: Option>, + #[serde(default)] + pub modified_at: i64, +} + +#[derive(Debug, Serialize)] +pub struct HeartbeatResp { + /// Present-and-truthy → client re-uploads sysinfo immediately. + #[serde(skip_serializing_if = "Option::is_none")] + pub sysinfo: Option, + /// Conn IDs the client should drop. Always present (possibly empty). + pub disconnect: Vec, + /// Strategy version. Echoed back by the client; when it changes, the + /// client re-merges `strategy.config_options` into local config. + pub modified_at: i64, + pub strategy: Value, +} + +pub async fn heartbeat( + Extension(state): Extension>, + Json(body): Json, +) -> Result, ApiError> { + if body.id.is_empty() || body.uuid.is_empty() { + return Err(ApiError::BadRequest("id and uuid required".into())); + } + let conns_json = serde_json::to_string(&body.conns.unwrap_or_default()) + .unwrap_or_else(|_| "[]".into()); + + let needs_sysinfo = state + .db + .sysinfo_heartbeat( + &body.id, + &body.uuid, + body.ver, + &conns_json, + &state.cfg.sysinfo_ver, + ) + .await?; + + // One-shot operator commands queued for this peer (force-disconnect, + // force-sysinfo). Read-and-delete in one transaction. + let mut disconnect: Vec = vec![]; + let mut force_sysinfo = needs_sysinfo; + for cmd in state + .db + .heartbeat_pop_commands(&body.id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + { + match cmd.kind.as_str() { + "disconnect" => { + if let Some(payload) = cmd.payload { + if let Ok(arr) = serde_json::from_str::>(&payload) { + disconnect.extend(arr); + } + } + } + "sysinfo" => force_sysinfo = true, + other => hbb_common::log::warn!("unknown heartbeat_command kind {:?}", other), + } + } + + // Strategy resolution (peer > device-group > user, highest priority wins). + let (modified_at, strategy) = strategy::resolve_for(&state, &body.id).await; + + Ok(Json(HeartbeatResp { + sysinfo: if force_sysinfo { Some(true) } else { None }, + disconnect, + modified_at, + strategy, + })) +} diff --git a/src/api/middleware.rs b/src/api/middleware.rs new file mode 100644 index 0000000..497b6b8 --- /dev/null +++ b/src/api/middleware.rs @@ -0,0 +1,59 @@ +use crate::api::error::ApiError; +use crate::api::state::AppState; +use async_trait::async_trait; +use axum::extract::{FromRequest, RequestParts, TypedHeader}; +use axum::headers::{authorization::Bearer, Authorization}; +use std::sync::Arc; + +pub struct AuthedUser { + pub user_id: i64, + pub name: String, + pub is_admin: bool, +} + +pub fn sha256_token(token: &str) -> Vec { + sodiumoxide::crypto::hash::sha256::hash(token.as_bytes()) + .as_ref() + .to_vec() +} + +#[async_trait] +impl FromRequest for AuthedUser { + type Rejection = ApiError; + + async fn from_request(req: &mut RequestParts) -> Result { + let bearer: TypedHeader> = + TypedHeader::from_request(req).await.map_err(|_| ApiError::Unauthorized)?; + let state: axum::extract::Extension> = + axum::extract::Extension::from_request(req) + .await + .map_err(|_| ApiError::Internal("missing state".into()))?; + let token = bearer.0 .0.token().to_string(); + let sha = sha256_token(&token); + + let (user_id, _exp) = state + .db + .token_lookup(&sha) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or(ApiError::Unauthorized)?; + + // Slide the expiry forward on every authenticated request. + if let Err(e) = state.db.token_touch(&sha, state.cfg.session_ttl_secs).await { + hbb_common::log::warn!("token_touch failed: {}", e); + } + + let user = state + .db + .user_find_by_id(user_id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or(ApiError::Unauthorized)?; + + Ok(Self { + user_id: user.id, + name: user.username, + is_admin: user.is_admin, + }) + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..0948250 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,103 @@ +//! HTTP management API mounted in-process alongside hbbs's rendezvous +//! listeners. The router is wired in via `src/rendezvous_server.rs`'s outer +//! `tokio::select!`. M1 covers auth + heartbeat + sysinfo; later milestones +//! add address book, audit, OIDC, etc. + +pub mod ab; +pub mod audit; +pub mod auth; +pub mod devices_cli; +pub mod email; +pub mod error; +pub mod groups; +pub mod heartbeat; +pub mod middleware; +pub mod oidc; +pub mod pagination; +pub mod peers; +pub mod plugin_sign; +pub mod record; +pub mod state; +pub mod strategy; +pub mod sysinfo; +pub mod twofa; +pub mod users; + +pub use state::AppState; + +use axum::extract::Extension; +use axum::routing::{delete, get, post, put}; +use axum::Router; +use hbb_common::{log, ResultType}; +use std::net::SocketAddr; +use std::sync::Arc; + +pub fn router(state: Arc) -> Router { + Router::new() + // M1: auth + heartbeat + sysinfo + .route( + "/api/login-options", + get(auth::login_options).head(auth::login_options_head), + ) + .route("/api/login", post(auth::login)) + .route("/api/currentUser", post(auth::current_user)) + .route("/api/logout", post(auth::logout)) + .route("/api/heartbeat", post(heartbeat::heartbeat)) + .route("/api/sysinfo_ver", post(sysinfo::sysinfo_ver)) + .route("/api/sysinfo", post(sysinfo::sysinfo)) + // M2: address book — modern (shared + personal) + .route("/api/ab/settings", post(ab::settings::settings)) + .route("/api/ab/personal", post(ab::profiles::personal)) + .route( + "/api/ab/shared/profiles", + post(ab::profiles::shared_profiles), + ) + .route("/api/ab/peers", post(ab::peers::list)) + .route("/api/ab/tags/:guid", post(ab::tags::list)) + .route("/api/ab/peer/add/:guid", post(ab::peers::add)) + .route("/api/ab/peer/update/:guid", put(ab::peers::update)) + .route("/api/ab/peer/:guid", delete(ab::peers::delete)) + .route("/api/ab/tag/add/:guid", post(ab::tags::add)) + .route("/api/ab/tag/rename/:guid", put(ab::tags::rename)) + .route("/api/ab/tag/update/:guid", put(ab::tags::update)) + .route("/api/ab/tag/:guid", delete(ab::tags::delete)) + // M2: address book — legacy single-blob fallback + .route( + "/api/ab", + get(ab::legacy::get).post(ab::legacy::put), + ) + // M2: group / users / peers panel + .route( + "/api/device-group/accessible", + get(groups::accessible), + ) + .route("/api/users", get(users::list)) + .route("/api/peers", get(peers::list)) + // M3: audit + .route("/api/audit/conn", post(audit::conn::conn)) + .route("/api/audit/file", post(audit::file::file)) + .route("/api/audit/alarm", post(audit::alarm::alarm)) + .route("/api/audit", put(audit::note::note)) + // M3: session recording upload + .route("/api/record", post(record::record)) + // M4: TOTP enrollment (admin-only) + .route("/api/2fa/enroll", post(twofa::enroll)) + .route("/api/2fa/unenroll", post(twofa::unenroll)) + // M4: rustdesk --assign target + .route("/api/devices/cli", post(devices_cli::assign)) + // M4: plugin signing (no auth — protocol-level) + .route("/lic/web/api/plugin-sign", post(plugin_sign::plugin_sign)) + // M4: OIDC device-flow login + .route("/api/oidc/auth", post(oidc::auth::auth)) + .route("/api/oidc/auth-query", get(oidc::poll::auth_query)) + .route("/oidc/callback", get(oidc::callback::callback)) + .layer(Extension(state)) +} + +pub async fn serve(addr: SocketAddr, state: Arc) -> ResultType<()> { + log::info!("HTTP API listening on {}", addr); + axum::Server::bind(&addr) + .serve(router(state).into_make_service()) + .await?; + Ok(()) +} diff --git a/src/api/oidc/auth.rs b/src/api/oidc/auth.rs new file mode 100644 index 0000000..05da482 --- /dev/null +++ b/src/api/oidc/auth.rs @@ -0,0 +1,99 @@ +//! `POST /api/oidc/auth` — start the device-flow login. + +use crate::api::error::ApiError; +use crate::api::oidc::{discovery, random_token, require_provider, OIDC_SESSION_TTL_SECS}; +use crate::api::state::AppState; +use crate::database::OidcSessionInsert; +use axum::extract::Extension; +use axum::Json; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct AuthBody { + /// Provider short-name from `oidc_providers.name`. The Flutter client + /// sends this from the `op` field of the OIDC dialog. + #[serde(default)] + pub op: String, + #[serde(default)] + pub id: String, + #[serde(default)] + pub uuid: String, + #[serde(default, rename = "deviceInfo")] + pub device_info: Option, +} + +pub async fn auth( + Extension(state): Extension>, + Json(body): Json, +) -> Result, ApiError> { + if state.cfg.public_base_url.is_empty() { + return Err(ApiError::Internal( + "OIDC requires --public-base-url to be set".into(), + )); + } + if body.op.is_empty() { + return Err(ApiError::BadRequest("op (provider name) required".into())); + } + let provider = require_provider(&state, &body.op).await?; + let disc = discovery::discover(&provider.issuer_url) + .await + .map_err(ApiError::Internal)?; + + let code = random_token(); + let csrf_state = random_token(); + let device_info_json = body + .device_info + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "{}".to_string()); + + let expires_at = chrono::Utc::now().timestamp() + OIDC_SESSION_TTL_SECS; + state + .db + .oidc_session_create(&OidcSessionInsert { + code: &code, + provider: &provider.name, + state: &csrf_state, + client_id_str: &body.id, + client_uuid: &body.uuid, + device_info_json: &device_info_json, + expires_at, + }) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + // Build the IdP authorization URL. + let url = format!( + "{auth}?response_type=code&client_id={cid}&redirect_uri={ru}&scope={scope}&state={state}", + auth = disc.authorization_endpoint, + cid = url_encode(&provider.client_id), + ru = url_encode(&provider.redirect_url), + scope = url_encode(&provider.scopes), + state = url_encode(&csrf_state), + ); + + Ok(Json(json!({ + "code": code, + "url": url, + }))) +} + +/// Inline percent-encoder for the auth URL query string. See +/// `api::twofa::url_encode` for the same routine. +fn url_encode(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for b in s.as_bytes() { + match b { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + out.push(*b as char); + } + _ => { + use std::fmt::Write; + let _ = write!(out, "%{:02X}", b); + } + } + } + out +} diff --git a/src/api/oidc/callback.rs b/src/api/oidc/callback.rs new file mode 100644 index 0000000..a483092 --- /dev/null +++ b/src/api/oidc/callback.rs @@ -0,0 +1,191 @@ +//! `GET /oidc/callback?code=&state=` — browser-facing redirect target. +//! +//! After the user signs in at the IdP, the IdP redirects their browser +//! here. We exchange the IdP code for tokens, fetch userinfo, find/create +//! a local user, mint our access token, and mark the session `success`. +//! The browser sees a small "you can close this window" page; the desktop +//! client picks up the token via `/api/oidc/auth-query`. + +use crate::api::auth::mint_token; +use crate::api::middleware::sha256_token; +use crate::api::oidc::{discovery, require_provider}; +use crate::api::state::AppState; +use axum::extract::{Extension, Query}; +use axum::response::Html; +use serde::Deserialize; +use serde_json::Value; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct CallbackQuery { + #[serde(default)] + pub code: String, + #[serde(default)] + pub state: String, + /// Some IdPs forward an error here on failed auth (e.g. user clicked + /// "deny"). We surface it as the session error and as a friendly page. + #[serde(default)] + pub error: Option, + #[serde(default)] + pub error_description: Option, +} + +pub async fn callback( + Extension(state): Extension>, + Query(q): Query, +) -> Html { + match handle(state, q).await { + Ok(()) => Html(html_page( + "Sign-in complete", + "You can close this window and return to RustDesk.", + )), + Err(msg) => Html(html_page("Sign-in failed", &html_escape(&msg))), + } +} + +async fn handle(state: Arc, q: CallbackQuery) -> Result<(), String> { + if q.state.is_empty() { + return Err("missing state parameter".into()); + } + let session = state + .db + .oidc_session_get_by_state(&q.state) + .await + .map_err(|e| e.to_string())? + .ok_or_else(|| "unknown or expired oidc session (state)".to_string())?; + + if let Some(err) = q.error.as_deref().filter(|s| !s.is_empty()) { + let detail = q + .error_description + .as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or(err); + let _ = state + .db + .oidc_session_fail(&session.code, &format!("idp: {}", detail)) + .await; + return Err(format!("identity provider returned an error: {}", detail)); + } + if q.code.is_empty() { + return Err("missing authorization code".into()); + } + + let provider = require_provider(&state, &session.provider) + .await + .map_err(|e| format!("{:?}", e))?; + let disc = discovery::discover(&provider.issuer_url).await?; + + // Token exchange. + let token_body = match discovery::http_post_form( + &disc.token_endpoint, + &[ + ("grant_type", "authorization_code"), + ("code", &q.code), + ("redirect_uri", &provider.redirect_url), + ("client_id", &provider.client_id), + ("client_secret", &provider.client_secret), + ], + ) + .await + { + Ok(b) => b, + Err(e) => { + let _ = state + .db + .oidc_session_fail(&session.code, &format!("token exchange: {}", e)) + .await; + return Err(e); + } + }; + let token_resp: Value = + serde_json::from_str(&token_body).map_err(|e| format!("parse token resp: {}", e))?; + let access_token = token_resp + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| "token response missing access_token".to_string())?; + + // Fetch userinfo. We trust the userinfo endpoint as the authority on + // the user's identity (sub + optional email + name). + let userinfo_url = disc + .userinfo_endpoint + .as_deref() + .ok_or_else(|| "provider has no userinfo_endpoint".to_string())?; + let userinfo_body = discovery::http_get_with_bearer(userinfo_url, access_token).await?; + let userinfo: Value = serde_json::from_str(&userinfo_body) + .map_err(|e| format!("parse userinfo: {}", e))?; + let sub = userinfo + .get("sub") + .and_then(|v| v.as_str()) + .ok_or_else(|| "userinfo missing sub".to_string())?; + let email = userinfo.get("email").and_then(|v| v.as_str()); + let display_name = userinfo + .get("name") + .and_then(|v| v.as_str()) + .or_else(|| userinfo.get("preferred_username").and_then(|v| v.as_str())); + + let user = state + .db + .user_upsert_oidc(sub, email, display_name) + .await + .map_err(|e| e.to_string())?; + if user.status == 0 { + return Err("user is disabled".into()); + } + + // Mint our own access token, store hashed, mark session complete. + let token = mint_token(); + let sha = sha256_token(&token); + state + .db + .token_insert( + user.id, + &sha, + &session.client_id_str, + &session.client_uuid, + &session.device_info_json, + state.cfg.session_ttl_secs, + ) + .await + .map_err(|e| e.to_string())?; + // Best-effort device claim — same path as `/api/login`. + state + .db + .device_claim(user.id, &session.client_id_str, &session.client_uuid) + .await; + + state + .db + .oidc_session_complete(&session.code, &token, user.id) + .await + .map_err(|e| e.to_string())?; + Ok(()) +} + +fn html_page(title: &str, body: &str) -> String { + format!( + r#" +{title} + +
+

{title}

+

{body}

+
"#, + title = title, + body = body + ) +} + +fn html_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") +} diff --git a/src/api/oidc/discovery.rs b/src/api/oidc/discovery.rs new file mode 100644 index 0000000..9c88b2d --- /dev/null +++ b/src/api/oidc/discovery.rs @@ -0,0 +1,128 @@ +//! `/.well-known/openid-configuration` discovery + in-memory cache. +//! +//! Most OIDC providers serve a JSON document at this URL describing the +//! authorization, token, and userinfo endpoints. Doing discovery once per +//! provider and caching the result keeps the per-login overhead to two +//! HTTP calls (token exchange + userinfo). + +use hbb_common::log; +use once_cell::sync::Lazy; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::Mutex; + +#[derive(Debug, Clone, Deserialize)] +pub struct OidcDiscovery { + pub authorization_endpoint: String, + pub token_endpoint: String, + #[serde(default)] + pub userinfo_endpoint: Option, + #[serde(default)] + pub issuer: Option, +} + +static CACHE: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +/// Fetch (or return cached) discovery document for `issuer_url`. Strips a +/// trailing `/` so the cache key is stable across operator typos. +pub async fn discover(issuer_url: &str) -> Result { + let issuer = issuer_url.trim_end_matches('/').to_string(); + if let Some(d) = CACHE.lock().unwrap().get(&issuer).cloned() { + return Ok(d); + } + let url = format!("{}/.well-known/openid-configuration", issuer); + log::info!("oidc: discovering {}", url); + let body = http_get(&url).await?; + let parsed: OidcDiscovery = serde_json::from_str(&body) + .map_err(|e| format!("discovery parse {}: {}", url, e))?; + CACHE.lock().unwrap().insert(issuer, parsed.clone()); + Ok(parsed) +} + +/// Blocking HTTP GET wrapped in `spawn_blocking`. We use the existing +/// `reqwest::blocking::Client` rather than adding an async client, because +/// (a) discovery happens at most once per provider and (b) the rustdesk +/// reqwest fork is configured for blocking-only use throughout the server. +pub async fn http_get(url: &str) -> Result { + let url = url.to_owned(); + hbb_common::tokio::task::spawn_blocking(move || { + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(15)) + .build() + .map_err(|e| format!("http client build: {}", e))?; + let resp = client + .get(&url) + .send() + .map_err(|e| format!("http get {}: {}", url, e))?; + let status = resp.status(); + let body = resp.text().map_err(|e| format!("read body: {}", e))?; + if !status.is_success() { + return Err(format!("http {} -> {}: {}", url, status, body)); + } + Ok(body) + }) + .await + .map_err(|e| format!("spawn_blocking: {}", e))? +} + +pub async fn http_post_form( + url: &str, + form: &[(&str, &str)], +) -> Result { + let url = url.to_owned(); + let owned: Vec<(String, String)> = form + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + hbb_common::tokio::task::spawn_blocking(move || { + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(15)) + .build() + .map_err(|e| format!("http client build: {}", e))?; + let pairs: Vec<(&str, &str)> = owned + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + let resp = client + .post(&url) + .form(&pairs) + .send() + .map_err(|e| format!("http post {}: {}", url, e))?; + let status = resp.status(); + let body = resp.text().map_err(|e| format!("read body: {}", e))?; + if !status.is_success() { + return Err(format!("http {} -> {}: {}", url, status, body)); + } + Ok(body) + }) + .await + .map_err(|e| format!("spawn_blocking: {}", e))? +} + +pub async fn http_get_with_bearer( + url: &str, + bearer: &str, +) -> Result { + let url = url.to_owned(); + let bearer = bearer.to_owned(); + hbb_common::tokio::task::spawn_blocking(move || { + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(15)) + .build() + .map_err(|e| format!("http client build: {}", e))?; + let resp = client + .get(&url) + .header("Authorization", format!("Bearer {}", bearer)) + .send() + .map_err(|e| format!("http get {}: {}", url, e))?; + let status = resp.status(); + let body = resp.text().map_err(|e| format!("read body: {}", e))?; + if !status.is_success() { + return Err(format!("http {} -> {}: {}", url, status, body)); + } + Ok(body) + }) + .await + .map_err(|e| format!("spawn_blocking: {}", e))? +} diff --git a/src/api/oidc/mod.rs b/src/api/oidc/mod.rs new file mode 100644 index 0000000..81287c5 --- /dev/null +++ b/src/api/oidc/mod.rs @@ -0,0 +1,53 @@ +//! OIDC device-flow login. +//! +//! Wire flow (matching CONSOLE_API.md §3.5): +//! +//! 1. `POST /api/oidc/auth { op: , id, uuid, deviceInfo }` → +//! `{ code: , url: }`. The client +//! opens `url` in the user's browser. +//! 2. The IdP redirects the browser back to our `/oidc/callback?code=...&state=...`. +//! That handler exchanges the IdP code for a token, fetches userinfo, +//! upserts a local user, mints our own access token, and marks the +//! session `success`. +//! 3. The client polls `GET /api/oidc/auth-query?code=&id=&uuid=` until it +//! sees a wrapped `AuthBody` envelope. +//! +//! Auth on the IdP side is handled by the provider's standard OAuth2 +//! authorization-code flow. We keep the hbbs side minimal: discovery via +//! `/.well-known/openid-configuration`, no JWT verification (we +//! trust the userinfo endpoint, authenticated via the access token). + +pub mod auth; +pub mod callback; +pub mod discovery; +pub mod poll; +pub mod providers; + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use crate::database::OidcProviderRow; + +const OIDC_SESSION_TTL_SECS: i64 = 600; // 10 minutes — the user has to sign in fast + +/// Convenience: resolve a provider name to its row, or an ApiError if it +/// doesn't exist or is disabled. +pub(crate) async fn require_provider( + state: &AppState, + name: &str, +) -> Result { + state + .db + .oidc_provider_get(name) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or_else(|| ApiError::BadRequest(format!("unknown OIDC provider: {}", name))) +} + +/// 24 random bytes, base64url-encoded → ~32 characters. Used for both the +/// poll-handle (`code`) and the CSRF state. +pub(crate) fn random_token() -> String { + base64::encode_config( + sodiumoxide::randombytes::randombytes(24), + base64::URL_SAFE_NO_PAD, + ) +} diff --git a/src/api/oidc/poll.rs b/src/api/oidc/poll.rs new file mode 100644 index 0000000..1d18879 --- /dev/null +++ b/src/api/oidc/poll.rs @@ -0,0 +1,86 @@ +//! `GET /api/oidc/auth-query?code=&id=&uuid=` — client poll loop. +//! +//! The Flutter client (src/hbbs_http/account.rs) wraps the response in an +//! outer envelope where the `body` field is itself JSON. We mirror that: +//! +//! `{ "body": "" }` +//! +//! The inner JSON is one of: +//! - while pending: `{"error":"No authed oidc is found"}` — client keeps polling. +//! - on success: the standard AuthBody (`{access_token, type:"access_token", user}`). +//! - on error: `{"error":""}` — client surfaces and stops polling. + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use crate::api::users::UserPayload; +use axum::extract::{Extension, Query}; +use axum::Json; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct PollQuery { + pub code: String, + #[serde(default)] + pub id: String, + #[serde(default)] + pub uuid: String, +} + +pub async fn auth_query( + Extension(state): Extension>, + Query(q): Query, +) -> Result, ApiError> { + let now = chrono::Utc::now().timestamp(); + let session = state + .db + .oidc_session_get_by_code(&q.code) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or_else(|| ApiError::BadRequest("unknown oidc session".into()))?; + if session.expires_at <= now && session.status == "pending" { + // The client treats this as an ordinary "still pending" tick and + // gives up on its own timeout (180 s). + return Ok(wrap_inner(json!({"error": "No authed oidc is found"}))); + } + match session.status.as_str() { + "pending" => Ok(wrap_inner(json!({"error": "No authed oidc is found"}))), + "error" => { + let msg = session + .error + .clone() + .unwrap_or_else(|| "OIDC sign-in failed".to_string()); + Ok(wrap_inner(json!({ "error": msg }))) + } + "success" => { + let access_token = session + .access_token + .clone() + .ok_or_else(|| ApiError::Internal("success session missing token".into()))?; + let user_id = session + .user_id + .ok_or_else(|| ApiError::Internal("success session missing user_id".into()))?; + let user = state + .db + .user_find_by_id(user_id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or_else(|| ApiError::Internal("user vanished mid-flow".into()))?; + let body = json!({ + "access_token": access_token, + "type": "access_token", + "user": UserPayload::from(&user), + }); + Ok(wrap_inner(body)) + } + other => Err(ApiError::Internal(format!( + "unknown oidc status {:?}", + other + ))), + } +} + +fn wrap_inner(inner: Value) -> Json { + Json(json!({ "body": inner.to_string() })) +} diff --git a/src/api/oidc/providers.rs b/src/api/oidc/providers.rs new file mode 100644 index 0000000..2bccad4 --- /dev/null +++ b/src/api/oidc/providers.rs @@ -0,0 +1,94 @@ +//! Operator-supplied provider config. Reads a TOML file shaped like: +//! +//! ```toml +//! [[providers]] +//! name = "google" +//! display_name = "Google" +//! issuer_url = "https://accounts.google.com" +//! client_id = "" +//! client_secret = "" +//! scopes = "openid email profile" +//! ``` +//! +//! Each entry is upserted into the `oidc_providers` table at startup. +//! `redirect_url` is computed from `--public-base-url` + `/oidc/callback`. +//! +//! TOML parsing uses the existing `rust-ini` crate? — no, we'd need a TOML +//! parser. We already have `toml` transitively via several deps; pull it in +//! directly for clarity. + +use crate::database::{Database, OidcProviderRow}; +use hbb_common::log; +use serde::Deserialize; +use std::path::Path; + +#[derive(Debug, Deserialize)] +struct ProvidersFile { + #[serde(default)] + providers: Vec, +} + +#[derive(Debug, Deserialize)] +struct ProviderEntry { + name: String, + #[serde(default)] + display_name: Option, + #[serde(default)] + icon_url: Option, + issuer_url: String, + client_id: String, + client_secret: String, + #[serde(default = "default_scopes")] + scopes: String, + /// Optional override; defaults to `/oidc/callback`. + #[serde(default)] + redirect_url: Option, + #[serde(default = "default_true")] + enabled: bool, +} + +fn default_scopes() -> String { + "openid email profile".to_string() +} +fn default_true() -> bool { + true +} + +pub async fn load_from_file( + db: &Database, + path: &Path, + public_base_url: &str, +) -> Result { + let bytes = std::fs::read_to_string(path) + .map_err(|e| format!("read {}: {}", path.display(), e))?; + let parsed: ProvidersFile = + toml::from_str(&bytes).map_err(|e| format!("parse {}: {}", path.display(), e))?; + let mut count = 0; + for p in parsed.providers { + let redirect_url = p + .redirect_url + .clone() + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| { + let base = public_base_url.trim_end_matches('/'); + format!("{}/oidc/callback", base) + }); + let row = OidcProviderRow { + name: p.name.clone(), + display_name: p.display_name, + icon_url: p.icon_url, + issuer_url: p.issuer_url, + client_id: p.client_id, + client_secret: p.client_secret, + scopes: p.scopes, + redirect_url, + enabled: p.enabled, + }; + db.oidc_provider_upsert(&row) + .await + .map_err(|e| format!("upsert {}: {}", p.name, e))?; + count += 1; + log::info!("oidc: provider {:?} configured", p.name); + } + Ok(count) +} diff --git a/src/api/pagination.rs b/src/api/pagination.rs new file mode 100644 index 0000000..54a145c --- /dev/null +++ b/src/api/pagination.rs @@ -0,0 +1,38 @@ +use serde::{Deserialize, Serialize}; + +/// Query-string pagination for list endpoints. The Flutter client at +/// flutter/lib/models/ab_model.dart and group_model.dart sends +/// `?current=1&pageSize=100` against every paginated list. Field names are +/// spelled explicitly here — `serde(rename_all = "camelCase")` would also +/// rename `current`, which we don't want. +#[derive(Debug, Deserialize)] +pub struct PageQuery { + #[serde(default = "default_current")] + pub current: i64, + #[serde(default = "default_page_size", rename = "pageSize")] + pub page_size: i64, +} + +fn default_current() -> i64 { + 1 +} +fn default_page_size() -> i64 { + 100 +} + +impl PageQuery { + pub fn offset(&self) -> i64 { + let cur = self.current.max(1); + (cur - 1) * self.limit() + } + pub fn limit(&self) -> i64 { + self.page_size.clamp(1, 1000) + } +} + +/// Standard envelope: `{ total, data }`. +#[derive(Debug, Serialize)] +pub struct Page { + pub total: i64, + pub data: Vec, +} diff --git a/src/api/peers.rs b/src/api/peers.rs new file mode 100644 index 0000000..9633f12 --- /dev/null +++ b/src/api/peers.rs @@ -0,0 +1,66 @@ +//! `GET /api/peers` — paginated peer list for the Group tab in the desktop +//! client. Flutter decoder at flutter/lib/common/hbbs/hbbs.dart:77 expects +//! `{ id, user, user_name, device_group_name, note, status, info: {...} }` +//! per row. + +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::pagination::{Page, PageQuery}; +use crate::api::state::AppState; +use axum::extract::{Extension, Query}; +use axum::Json; +use serde::Serialize; +use serde_json::{json, Value}; +use std::sync::Arc; + +#[derive(Debug, Serialize)] +pub struct PeerOut { + pub id: String, + pub user: String, + pub user_name: String, + pub device_group_name: String, + pub note: String, + pub status: i64, + pub info: Value, +} + +pub async fn list( + Extension(state): Extension>, + user: AuthedUser, + Query(q): Query, +) -> Result>, ApiError> { + let (total, rows) = state + .db + .peers_list_accessible(user.user_id, user.is_admin, q.offset(), q.limit()) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + let data: Vec = rows + .into_iter() + .map(|r| { + // Trim the full sysinfo blob to what the client actually reads. + let parsed: Value = serde_json::from_str(&r.sysinfo_payload).unwrap_or(Value::Null); + let pick = |k: &str| -> String { + parsed + .get(k) + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string() + }; + let info = json!({ + "username": pick("username"), + "device_name": pick("hostname"), + "os": pick("os"), + }); + PeerOut { + id: r.id, + user: r.owner_username, + user_name: r.owner_display_name, + device_group_name: r.device_group_name, + note: r.note, + status: r.status, + info, + } + }) + .collect(); + Ok(Json(Page { total, data })) +} diff --git a/src/api/plugin_sign.rs b/src/api/plugin_sign.rs new file mode 100644 index 0000000..b810e41 --- /dev/null +++ b/src/api/plugin_sign.rs @@ -0,0 +1,59 @@ +//! `POST /lic/web/api/plugin-sign` — signs a plugin's status/heartbeat +//! payload with the server's Ed25519 secret. The client (plugin runtime, +//! src/plugin/callback_msg.rs:282-296) sends: +//! +//! `{ "plugin_id": "...", "version": "...", "msg": [u8, u8, ...] }` +//! +//! and expects: +//! +//! `{ "signed_msg": [u8, u8, ...] }` +//! +//! No Authorization header — the client opens this without a token. Auth +//! is implicit via the licence-key shared secret on the rest of the +//! deployment; we just sign whatever is asked. (Pro can additionally +//! validate the plugin against an allowlist; OSS just signs.) + +use crate::api::error::ApiError; +use axum::Json; +use serde::{Deserialize, Serialize}; +use sodiumoxide::crypto::sign; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct PluginSignReq { + #[serde(default)] + pub plugin_id: String, + #[serde(default)] + pub version: String, + pub msg: Vec, +} + +#[derive(Debug, Serialize)] +pub struct PluginSignResp { + pub signed_msg: Vec, +} + +/// The signing key is the same Ed25519 secret hbbs already uses for +/// rendezvous KeyExchange (`id_ed25519`). We pull it from the shared +/// `RendezvousServer.inner.sk` via the AppState — but `AppState` doesn't +/// hold it today, so this handler reads it directly from a process-wide +/// `OnceCell` populated at startup. (See `set_signing_key` below.) +pub async fn plugin_sign( + Json(req): Json, +) -> Result, ApiError> { + let sk = SIGNING_KEY + .get() + .ok_or_else(|| ApiError::Internal("plugin signing not configured".into()))?; + let signed = sign::sign(&req.msg, sk); + Ok(Json(PluginSignResp { signed_msg: signed })) +} + +use once_cell::sync::OnceCell; + +static SIGNING_KEY: OnceCell> = OnceCell::new(); + +/// Called once from `RendezvousServer::start` after the keypair is loaded. +/// A no-op if already set; the server will only ever have one Ed25519 key. +pub fn set_signing_key(sk: sign::SecretKey) { + let _ = SIGNING_KEY.set(Arc::new(sk)); +} diff --git a/src/api/record/mod.rs b/src/api/record/mod.rs new file mode 100644 index 0000000..28fceaf --- /dev/null +++ b/src/api/record/mod.rs @@ -0,0 +1,54 @@ +//! `POST /api/record?type={new|part|tail|remove}&file=&offset=&length=` +//! +//! No Authorization header — clients fire-and-forget. The wire flow is +//! defined in CONSOLE_API.md §8 and src/hbbs_http/record_upload.rs in the +//! client. We dispatch on `?type=` into the storage state machine. + +pub mod storage; + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use axum::body::Bytes; +use axum::extract::{Extension, Query}; +use axum::http::StatusCode; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct RecordQuery { + #[serde(rename = "type")] + pub kind: String, + pub file: String, + #[serde(default)] + pub offset: Option, + #[serde(default)] + pub length: Option, +} + +pub async fn record( + Extension(state): Extension>, + Query(q): Query, + body: Bytes, +) -> Result { + match q.kind.as_str() { + "new" => storage::handle_new(&state, &q.file, "").await?, + "part" => { + let offset = q.offset.unwrap_or(0); + let length = q.length.unwrap_or(body.len()); + storage::handle_part(&state, &q.file, offset, length, &body).await?; + } + "tail" => { + let offset = q.offset.unwrap_or(0); + let length = q.length.unwrap_or(body.len()); + storage::handle_tail(&state, &q.file, offset, length, &body).await?; + } + "remove" => storage::handle_remove(&state, &q.file).await?, + other => { + return Err(ApiError::BadRequest(format!( + "unknown record type {:?}", + other + ))); + } + } + Ok(StatusCode::OK) +} diff --git a/src/api/record/storage.rs b/src/api/record/storage.rs new file mode 100644 index 0000000..fb0ab2a --- /dev/null +++ b/src/api/record/storage.rs @@ -0,0 +1,147 @@ +//! On-disk file IO for `/api/record`. The wire flow lives in +//! [src/hbbs_http/record_upload.rs](file:///Users/sn0/Desktop/rustdesk/src/hbbs_http/record_upload.rs) +//! on the client side: the controller emits `?type=new` once, then a series +//! of `?type=part&offset=N&length=L` chunks, and finally a `?type=tail` +//! header rewrite at offset 0. We mirror that as a tiny state machine. + +use crate::api::error::ApiError; +use crate::api::state::AppState; +use std::path::{Component, Path, PathBuf}; +use tokio::fs::{File, OpenOptions}; +use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom}; + +const TAIL_MAX: usize = 1024; + +/// Reject any filename that contains a path separator or `..` component. +/// The client only ever sends a basename per +/// `record_upload.rs:118-122`, so anything else is suspicious. +pub fn sanitized_path(root: &Path, file: &str) -> Result { + if file.is_empty() { + return Err(ApiError::BadRequest("file required".into())); + } + let p = Path::new(file); + let mut comps = p.components(); + let only = comps.next(); + let extra = comps.next(); + match (only, extra) { + (Some(Component::Normal(name)), None) if !name.is_empty() => Ok(root.join(name)), + _ => Err(ApiError::BadRequest("invalid file name".into())), + } +} + +pub async fn handle_new( + state: &AppState, + file: &str, + peer_id: &str, +) -> Result<(), ApiError> { + let path = sanitized_path(&state.cfg.recording_dir, file)?; + if let Some(dir) = path.parent() { + tokio::fs::create_dir_all(dir) + .await + .map_err(|e| ApiError::Internal(format!("mkdir {}: {}", dir.display(), e)))?; + } + // Truncate (or create) the file. + OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .await + .map_err(|e| ApiError::Internal(format!("create {}: {}", path.display(), e)))?; + state + .db + .recording_new(peer_id, file) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(()) +} + +pub async fn handle_part( + state: &AppState, + file: &str, + offset: u64, + length: usize, + body: &[u8], +) -> Result<(), ApiError> { + if body.len() != length { + hbb_common::log::warn!( + "record part length mismatch: declared={} actual={}", + length, + body.len() + ); + } + let path = sanitized_path(&state.cfg.recording_dir, file)?; + let max = state.cfg.recording_max_size_bytes; + if max > 0 && offset.saturating_add(body.len() as u64) > max { + return Err(ApiError::Forbidden("recording size cap exceeded".into())); + } + let mut f: File = OpenOptions::new() + .write(true) + .create(true) + .open(&path) + .await + .map_err(|e| ApiError::Internal(format!("open {}: {}", path.display(), e)))?; + f.seek(SeekFrom::Start(offset)) + .await + .map_err(|e| ApiError::Internal(format!("seek: {}", e)))?; + f.write_all(body) + .await + .map_err(|e| ApiError::Internal(format!("write: {}", e)))?; + f.flush().await.ok(); + let new_size = offset + body.len() as u64; + state + .db + .recording_set_state(file, "recording", Some(new_size as i64), false) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(()) +} + +pub async fn handle_tail( + state: &AppState, + file: &str, + offset: u64, + length: usize, + body: &[u8], +) -> Result<(), ApiError> { + if offset != 0 { + return Err(ApiError::BadRequest("tail must be at offset 0".into())); + } + if length > TAIL_MAX || body.len() > TAIL_MAX { + return Err(ApiError::BadRequest("tail exceeds 1024 bytes".into())); + } + let path = sanitized_path(&state.cfg.recording_dir, file)?; + let mut f = OpenOptions::new() + .write(true) + .open(&path) + .await + .map_err(|e| ApiError::Internal(format!("open {}: {}", path.display(), e)))?; + f.seek(SeekFrom::Start(0)) + .await + .map_err(|e| ApiError::Internal(format!("seek: {}", e)))?; + f.write_all(body) + .await + .map_err(|e| ApiError::Internal(format!("write tail: {}", e)))?; + f.flush().await.ok(); + state + .db + .recording_set_state(file, "finished", None, true) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(()) +} + +pub async fn handle_remove(state: &AppState, file: &str) -> Result<(), ApiError> { + let path = sanitized_path(&state.cfg.recording_dir, file)?; + if let Err(e) = tokio::fs::remove_file(&path).await { + if e.kind() != std::io::ErrorKind::NotFound { + hbb_common::log::warn!("remove {}: {}", path.display(), e); + } + } + state + .db + .recording_delete(file) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(()) +} diff --git a/src/api/state.rs b/src/api/state.rs new file mode 100644 index 0000000..73bcfe8 --- /dev/null +++ b/src/api/state.rs @@ -0,0 +1,135 @@ +use crate::common::{get_arg, get_arg_or}; +use crate::database::Database; +use std::path::PathBuf; +use std::sync::Arc; + +#[derive(Clone)] +pub struct ApiConfig { + pub login_options: Vec, + pub sysinfo_ver: String, + pub session_ttl_secs: i64, + /// When true, `/api/ab/personal` returns 404, forcing the client into the + /// legacy single-blob AB path (`GET/POST /api/ab`). The default is the + /// modern shared-AB path. + pub ab_legacy_mode: bool, + /// Surfaced verbatim via `/api/ab/settings.max_peer_one_ab`. + pub ab_max_peers_per_book: i64, + /// On-disk root for `/api/record` uploads. Created on first use; one + /// subdirectory per peer-id under here. + pub recording_dir: PathBuf, + /// 0 means unlimited. + pub recording_max_size_bytes: u64, + /// 0 means no retention sweep. + pub audit_retention_days: i64, + /// SMTP transport for email-code login. `None` = dev mode: codes are + /// logged to stdout instead of mailed. + pub email: Option, + /// Externally reachable base URL of this server, e.g. for the OIDC + /// redirect_uri. Empty disables OIDC. + pub public_base_url: String, +} + +/// SMTP wiring for email-code login. +#[derive(Clone, Debug)] +pub struct EmailConfig { + pub host: String, + pub port: u16, + pub username: Option, + pub password: Option, + pub from: String, + pub starttls: bool, +} + +#[derive(Clone)] +pub struct AppState { + pub db: Database, + pub cfg: ApiConfig, +} + +impl AppState { + pub fn new(db: Database) -> Arc { + let ab_legacy_mode = matches!( + get_arg_or("ab-legacy-mode", "off".to_string()) + .to_ascii_lowercase() + .as_str(), + "on" | "y" | "yes" | "true" | "1" + ); + let ab_max_peers_per_book: i64 = get_arg_or("ab-max-peers-per-book", "100".to_string()) + .parse() + .unwrap_or(100); + let recording_dir = + PathBuf::from(get_arg_or("recording-dir", "./recordings".to_string())); + let recording_max_size_bytes: u64 = get_arg_or("recording-max-size-mb", "0".to_string()) + .parse::() + .unwrap_or(0) + .saturating_mul(1024 * 1024); + let audit_retention_days: i64 = get_arg_or("audit-retention-days", "0".to_string()) + .parse() + .unwrap_or(0); + let email = build_email_config(); + let public_base_url = get_arg("public-base-url"); + // login_options advertises every login method this server accepts. + // The Flutter client uses this to render the matching button on the + // sign-in dialog. `email_code` and `oidc/` are opt-in so a + // deployment without SMTP / OIDC doesn't dangle a broken button. + let mut login_options = vec!["account".to_string()]; + if email.is_some() || std::env::var("ALLOW_DEV_EMAIL_CODE").is_ok() { + login_options.push("email_code".to_string()); + } + // OIDC providers are mounted dynamically — actual provider names are + // appended later by the oidc::providers loader once the DB rows exist. + Arc::new(Self { + db, + cfg: ApiConfig { + login_options, + sysinfo_ver: "m1-1".to_string(), + session_ttl_secs: 30 * 86400, + ab_legacy_mode, + ab_max_peers_per_book, + recording_dir, + recording_max_size_bytes, + audit_retention_days, + email, + public_base_url, + }, + }) + } +} + +fn build_email_config() -> Option { + let host = get_arg("smtp-host"); + if host.is_empty() { + return None; + } + let port: u16 = get_arg_or("smtp-port", "587".to_string()) + .parse() + .unwrap_or(587); + let username = { + let u = get_arg("smtp-user"); + if u.is_empty() { None } else { Some(u) } + }; + let password = { + let p = get_arg("smtp-pass"); + if p.is_empty() { None } else { Some(p) } + }; + let from = { + let f = get_arg("smtp-from"); + if f.is_empty() { + format!("noreply@{}", host) + } else { + f + } + }; + let starttls = matches!( + get_arg_or("smtp-tls", "on".to_string()).to_ascii_lowercase().as_str(), + "on" | "y" | "yes" | "true" | "1" + ); + Some(EmailConfig { + host, + port, + username, + password, + from, + starttls, + }) +} diff --git a/src/api/strategy/mod.rs b/src/api/strategy/mod.rs new file mode 100644 index 0000000..1f63031 --- /dev/null +++ b/src/api/strategy/mod.rs @@ -0,0 +1,35 @@ +//! Strategy resolver for the heartbeat path. The actual SQL lives in +//! `Database::strategy_resolve_for` — this module exists to give the +//! heartbeat handler a stable import surface and to centralize how a +//! resolved strategy is converted into the wire-shape JSON the client +//! expects (`strategy.config_options` + `strategy.extra` per +//! CONSOLE_API.md §6.1). + +use crate::api::state::AppState; +use crate::database::ResolvedStrategy; +use serde_json::{json, Value}; + +/// Resolve and serialize a strategy for `peer_id`. Returns +/// `(modified_at, strategy_value)` where `strategy_value` is the JSON object +/// the heartbeat reply embeds under `strategy`. When no strategy applies, we +/// return an empty `{config_options: {}, extra: {}}` and `modified_at = 0`. +pub async fn resolve_for(state: &AppState, peer_id: &str) -> (i64, Value) { + let resolved = state + .db + .strategy_resolve_for(peer_id) + .await + .unwrap_or_default(); + serialize(&resolved) +} + +fn serialize(r: &ResolvedStrategy) -> (i64, Value) { + let cfg: Value = serde_json::from_str(&r.config_options_json).unwrap_or_else(|_| json!({})); + let extra: Value = serde_json::from_str(&r.extra_json).unwrap_or_else(|_| json!({})); + ( + r.modified_at, + json!({ + "config_options": cfg, + "extra": extra, + }), + ) +} diff --git a/src/api/sysinfo.rs b/src/api/sysinfo.rs new file mode 100644 index 0000000..2da8090 --- /dev/null +++ b/src/api/sysinfo.rs @@ -0,0 +1,61 @@ +use crate::api::error::ApiError; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::Json; +use serde_json::Value; +use std::sync::Arc; + +/// Plain-text version string that the client compares against its cached +/// `sysinfo_ver`. Same value the heartbeat handler echoes via the +/// `sysinfo: true` flag. +pub async fn sysinfo_ver(Extension(state): Extension>) -> String { + state.cfg.sysinfo_ver.clone() +} + +/// Bare-string body: `"SYSINFO_UPDATED"` or `"ID_NOT_FOUND"`. The client at +/// /Users/sn0/Desktop/rustdesk/src/hbbs_http/sync.rs:212 does a literal +/// `==` comparison on these — do not wrap in JSON. +pub async fn sysinfo( + Extension(state): Extension>, + Json(payload): Json, +) -> Result { + let id = payload + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + let uuid = payload + .get("uuid") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + if id.is_empty() || uuid.is_empty() { + return Err(ApiError::BadRequest("id and uuid required".into())); + } + + // Tie sysinfo storage to a real rendezvous-registered peer. Without this + // gate, any caller could populate device_sysinfo for arbitrary IDs. + let peer = state + .db + .get_peer(id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + if peer.is_none() { + return Ok("ID_NOT_FOUND".to_string()); + } + + let version = parse_version_number(payload.get("version").and_then(|v| v.as_str())); + state + .db + .sysinfo_upsert(id, uuid, &payload.to_string(), &state.cfg.sysinfo_ver, version) + .await?; + Ok("SYSINFO_UPDATED".to_string()) +} + +fn parse_version_number(s: Option<&str>) -> i64 { + let Some(s) = s else { return 0 }; + // hbb_common encodes "1.4.2" as 1*1_000_000 + 4*1_000 + 2 = 1_004_002. + let mut parts = s.split('.').map(|p| p.parse::().unwrap_or(0)); + let major = parts.next().unwrap_or(0); + let minor = parts.next().unwrap_or(0); + let patch = parts.next().unwrap_or(0); + major * 1_000_000 + minor * 1_000 + patch +} diff --git a/src/api/twofa.rs b/src/api/twofa.rs new file mode 100644 index 0000000..ee1bc6b --- /dev/null +++ b/src/api/twofa.rs @@ -0,0 +1,147 @@ +//! `POST /api/2fa/enroll` — admin-only TOTP enrollment. +//! +//! Generates a fresh 20-byte (160-bit) base32 secret, stores it for the +//! target user, and returns: +//! - `secret_b32` — the literal secret to enter into an authenticator app. +//! - `otpauth_url` — the standard `otpauth://totp/...` URL the same apps +//! accept as a QR-code or pasted-string. +//! +//! There is no client-facing UI for this in the desktop app; operators run it +//! by curl after creating the user. M4's `--bootstrap-admin-username` admin +//! is the natural caller. + +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::state::AppState; +use axum::extract::Extension; +use axum::Json; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::sync::Arc; +use totp_rs::Secret; + +#[derive(Debug, Deserialize)] +pub struct EnrollBody { + /// Either `user_id` or `username` is required. `user_id` wins if both + /// are present. + #[serde(default)] + pub user_id: Option, + #[serde(default)] + pub username: Option, + /// Issuer name shown in the authenticator app. Defaults to "RustDesk". + #[serde(default)] + pub issuer: Option, +} + +#[derive(Debug, Deserialize)] +pub struct UnenrollBody { + #[serde(default)] + pub user_id: Option, + #[serde(default)] + pub username: Option, +} + +pub async fn enroll( + Extension(state): Extension>, + caller: AuthedUser, + Json(body): Json, +) -> Result, ApiError> { + if !caller.is_admin { + return Err(ApiError::Forbidden("admin required".into())); + } + let user = resolve_target(&state, body.user_id, body.username.as_deref()).await?; + + // 20 random bytes -> base32 (the standard size for SHA1 TOTP). + let raw = sodiumoxide::randombytes::randombytes(20); + let secret_b32 = Secret::Raw(raw.clone()).to_encoded().to_string(); + + state + .db + .totp_enroll(user.id, &secret_b32) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let issuer = body + .issuer + .as_deref() + .filter(|s| !s.is_empty()) + .unwrap_or("RustDesk"); + // Build the otpauth:// URL manually rather than depend on totp-rs's + // URL helpers (their API has shifted between minor versions). Format + // per https://github.com/google/google-authenticator/wiki/Key-Uri-Format. + let otpauth_url = format!( + "otpauth://totp/{issuer}:{account}?secret={secret}&issuer={issuer}&algorithm=SHA1&digits=6&period=30", + issuer = url_encode(issuer), + account = url_encode(&user.username), + secret = url_encode(&secret_b32), + ); + + Ok(Json(json!({ + "user_id": user.id, + "username": user.username, + "secret_b32": secret_b32, + "otpauth_url": otpauth_url, + }))) +} + +pub async fn unenroll( + Extension(state): Extension>, + caller: AuthedUser, + Json(body): Json, +) -> Result, ApiError> { + if !caller.is_admin { + return Err(ApiError::Forbidden("admin required".into())); + } + let user = resolve_target(&state, body.user_id, body.username.as_deref()).await?; + let removed = state + .db + .totp_unenroll(user.id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(Json(json!({ "removed": removed }))) +} + +/// Minimal percent-encoder for the otpauth URL fields. Encodes anything +/// outside the unreserved URL set (`A-Za-z0-9-_.~`) — keeps the URL short +/// and avoids pulling in `urlencoding` just for this single call site. +fn url_encode(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for b in s.as_bytes() { + match b { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + out.push(*b as char); + } + _ => { + use std::fmt::Write; + let _ = write!(out, "%{:02X}", b); + } + } + } + out +} + +async fn resolve_target( + state: &AppState, + user_id: Option, + username: Option<&str>, +) -> Result { + if let Some(id) = user_id { + return state + .db + .user_find_by_id(id) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or(ApiError::NotFound); + } + if let Some(name) = username.filter(|s| !s.is_empty()) { + return state + .db + .user_find_by_username(name) + .await + .map_err(|e| ApiError::Internal(e.to_string()))? + .ok_or(ApiError::NotFound); + } + Err(ApiError::BadRequest( + "user_id or username required".into(), + )) +} diff --git a/src/api/users.rs b/src/api/users.rs new file mode 100644 index 0000000..abadf54 --- /dev/null +++ b/src/api/users.rs @@ -0,0 +1,68 @@ +use crate::api::error::ApiError; +use crate::api::middleware::AuthedUser; +use crate::api::pagination::{Page, PageQuery}; +use crate::api::state::AppState; +use crate::database::UserRow; +use axum::extract::{Extension, Query}; +use axum::Json; +use hbb_common::ResultType; +use serde::Serialize; +use std::sync::Arc; + +#[derive(Debug, Serialize)] +pub struct UserPayload { + pub name: String, + pub display_name: String, + pub avatar: String, + pub email: String, + pub note: String, + pub status: i64, + pub is_admin: bool, +} + +impl From<&UserRow> for UserPayload { + fn from(u: &UserRow) -> Self { + Self { + name: u.username.clone(), + display_name: u.display_name.clone(), + avatar: u.avatar.clone(), + email: u.email.clone(), + note: u.note.clone(), + status: u.status, + is_admin: u.is_admin, + } + } +} + +pub async fn hash_password(plain: String) -> ResultType { + Ok( + hbb_common::tokio::task::spawn_blocking(move || bcrypt::hash(plain, 10)) + .await??, + ) +} + +pub async fn verify_password(hash: String, plain: String) -> ResultType { + Ok( + hbb_common::tokio::task::spawn_blocking(move || bcrypt::verify(plain, &hash)) + .await??, + ) +} + +/// `GET /api/users` — paginated list of users visible to the caller. Admin +/// sees all enabled users; non-admin sees themselves plus members of any +/// device-group they share. Flutter decoder at common/hbbs/hbbs.dart:26. +pub async fn list( + Extension(state): Extension>, + user: AuthedUser, + Query(q): Query, +) -> Result>, ApiError> { + let (total, rows) = state + .db + .users_list_accessible(user.user_id, user.is_admin, q.offset(), q.limit()) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + Ok(Json(Page { + total, + data: rows.iter().map(UserPayload::from).collect(), + })) +} diff --git a/src/database.rs b/src/database.rs index fa1b6ed..803e23e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,7 +1,8 @@ use async_trait::async_trait; use hbb_common::{log, ResultType}; use sqlx::{ - sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, SqliteConnection, + sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, Row, + SqliteConnection, }; use std::{ops::DerefMut, str::FromStr}; //use sqlx::postgres::PgPoolOptions; @@ -46,6 +47,148 @@ pub struct Peer { pub status: Option, } +#[derive(Debug, Clone)] +pub struct UserRow { + pub id: i64, + pub username: String, + pub password_hash: String, + pub display_name: String, + pub email: String, + pub note: String, + pub avatar: String, + pub status: i64, + pub is_admin: bool, +} + +pub struct NewUser<'a> { + pub username: &'a str, + pub password_hash: &'a str, + pub display_name: &'a str, + pub is_admin: bool, +} + +#[derive(Debug, Clone)] +pub struct DeviceSysinfoRow { + pub payload: String, + pub sysinfo_ver_seen: String, +} + +#[derive(Debug, Clone)] +pub struct AbProfileRow { + pub guid: String, + pub name: String, + pub owner: String, + pub note: String, + pub rule: i64, + pub info_json: Option, +} + +#[derive(Debug, Clone, Default)] +pub struct AbPeerRow { + pub id: String, + pub alias: String, + pub note: String, + pub password: String, + pub hash: String, + pub username: String, + pub hostname: String, + pub platform: String, + pub tags: Vec, +} + +#[derive(Debug, Clone)] +pub struct AbTagRow { + pub name: String, + pub color: i64, +} + +#[derive(Debug, Clone)] +pub struct DeviceGroupRow { + pub id: i64, + pub name: String, +} + +#[derive(Debug, Clone, Default)] +pub struct PeerListRow { + pub id: String, + pub owner_username: String, + pub owner_display_name: String, + pub device_group_name: String, + pub note: String, + pub status: i64, + /// Raw sysinfo JSON; the handler parses and emits a trimmed `info` object. + pub sysinfo_payload: String, +} + +pub struct AbPeerInsert<'a> { + pub id: &'a str, + pub alias: Option<&'a str>, + pub note: Option<&'a str>, + pub password: Option<&'a str>, + pub hash: Option<&'a str>, + pub username: Option<&'a str>, + pub hostname: Option<&'a str>, + pub platform: Option<&'a str>, +} + +#[derive(Debug, Clone, Default)] +pub struct ResolvedStrategy { + pub modified_at: i64, + /// JSON object map; passed straight into the heartbeat response. + pub config_options_json: String, + pub extra_json: String, +} + +#[derive(Debug, Clone)] +pub struct HeartbeatCommand { + pub kind: String, + pub payload: Option, +} + +#[derive(Debug, Clone)] +pub struct RecordingFile { + pub size: i64, + pub state: String, +} + +#[derive(Debug, Clone)] +pub struct OidcProviderRow { + pub name: String, + pub display_name: Option, + pub icon_url: Option, + pub issuer_url: String, + pub client_id: String, + pub client_secret: String, + pub scopes: String, + pub redirect_url: String, + pub enabled: bool, +} + +pub struct OidcSessionInsert<'a> { + pub code: &'a str, + pub provider: &'a str, + pub state: &'a str, + pub client_id_str: &'a str, + pub client_uuid: &'a str, + pub device_info_json: &'a str, + pub expires_at: i64, +} + +#[derive(Debug, Clone)] +pub struct OidcSessionRow { + pub code: String, + pub provider: String, + pub state: String, + pub client_id_str: String, + pub client_uuid: String, + pub device_info_json: String, + pub expires_at: i64, + pub status: String, + pub access_token: Option, + pub user_id: Option, + pub error: Option, +} + impl Database { pub async fn new(url: &str) -> ResultType { if !std::path::Path::new(url).exists() { @@ -90,9 +233,1741 @@ impl Database { ) .execute(self.pool.get().await?.deref_mut()) .await?; + // M1 schema: users, tokens, device_sysinfo. Runtime form so first-time + // builds don't require DATABASE_URL to already contain these tables. + for stmt in M1_SCHEMA { + sqlx::query(stmt) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + // M2 schema: address books, tags, device groups, accessibility view. + for stmt in M2_SCHEMA { + sqlx::query(stmt) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + // M3 schema: audit log, recordings, strategies, heartbeat commands. + for stmt in M3_SCHEMA { + sqlx::query(stmt) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + // M4 schema: 2FA / email-code / OIDC scaffolding. + for stmt in M4_SCHEMA { + sqlx::query(stmt) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + // Soft-ALTERs run after schema creation. SQLite < 3.35 lacks + // `ADD COLUMN IF NOT EXISTS`; swallow the duplicate-column error + // so re-runs are idempotent. + for stmt in M2_SOFT_ALTERS { + self.try_alter(stmt).await; + } Ok(()) } + async fn try_alter(&self, sql: &str) { + match sqlx::query(sql) + .execute(self.pool.get().await.unwrap().deref_mut()) + .await + { + Ok(_) => {} + Err(e) => { + let msg = e.to_string(); + if !msg.contains("duplicate column name") { + log::warn!("schema migration `{}` failed: {}", sql, msg); + } + } + } + } + + pub async fn count_users(&self) -> ResultType { + let row = sqlx::query("SELECT COUNT(*) AS c FROM users") + .fetch_one(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.try_get::("c")?) + } + + pub async fn warn_if_no_users(&self) { + match self.count_users().await { + Ok(0) => log::warn!( + "users table is empty and no --bootstrap-admin-username/password supplied; \ + /api/login will reject every request" + ), + Ok(_) => {} + Err(e) => log::warn!("count_users failed: {}", e), + } + } + + pub async fn bootstrap_admin(&self, username: &str, password_plain: &str) -> ResultType<()> { + if self.count_users().await? > 0 { + return Ok(()); + } + let plain = password_plain.to_owned(); + let hash = + hbb_common::tokio::task::spawn_blocking(move || bcrypt::hash(plain, 10)).await??; + let display = "Admin"; + let is_admin: i64 = 1; + let status: i64 = 1; + sqlx::query( + "INSERT INTO users(username, password_hash, display_name, status, is_admin) \ + VALUES(?, ?, ?, ?, ?)", + ) + .bind(username) + .bind(&hash) + .bind(display) + .bind(status) + .bind(is_admin) + .execute(self.pool.get().await?.deref_mut()) + .await?; + log::info!("bootstrap admin '{}' created", username); + Ok(()) + } + + pub async fn user_find_by_username(&self, username: &str) -> ResultType> { + let row = sqlx::query( + "SELECT id, username, password_hash, display_name, email, note, avatar, status, is_admin \ + FROM users WHERE username = ?", + ) + .bind(username) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(row_to_user)) + } + + pub async fn user_find_by_id(&self, id: i64) -> ResultType> { + let row = sqlx::query( + "SELECT id, username, password_hash, display_name, email, note, avatar, status, is_admin \ + FROM users WHERE id = ?", + ) + .bind(id) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(row_to_user)) + } + + pub async fn user_insert(&self, u: NewUser<'_>) -> ResultType { + let admin_int: i64 = if u.is_admin { 1 } else { 0 }; + let res = sqlx::query( + "INSERT INTO users(username, password_hash, display_name, is_admin, status) \ + VALUES(?, ?, ?, ?, 1)", + ) + .bind(u.username) + .bind(u.password_hash) + .bind(u.display_name) + .bind(admin_int) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(res.last_insert_rowid()) + } + + pub async fn token_insert( + &self, + user_id: i64, + sha: &[u8], + peer_id: &str, + peer_uuid: &str, + device_info: &str, + ttl_secs: i64, + ) -> ResultType<()> { + let expires_at = chrono::Utc::now() + chrono::Duration::seconds(ttl_secs); + sqlx::query( + "INSERT INTO tokens(user_id, token_sha256, peer_id, peer_uuid, device_info, expires_at) \ + VALUES(?, ?, ?, ?, ?, ?)", + ) + .bind(user_id) + .bind(sha) + .bind(peer_id) + .bind(peer_uuid) + .bind(device_info) + .bind(expires_at) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + /// Return (user_id, expires_at_unix) for a token still valid at `now`. + pub async fn token_lookup(&self, sha: &[u8]) -> ResultType> { + let row = sqlx::query( + "SELECT user_id, strftime('%s', expires_at) AS exp \ + FROM tokens WHERE token_sha256 = ?", + ) + .bind(sha) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + let Some(row) = row else { return Ok(None) }; + let user_id: i64 = row.try_get("user_id")?; + let exp_str: String = row.try_get("exp")?; + let exp: i64 = exp_str.parse().unwrap_or(0); + if exp <= chrono::Utc::now().timestamp() { + return Ok(None); + } + Ok(Some((user_id, exp))) + } + + pub async fn token_touch(&self, sha: &[u8], ttl_secs: i64) -> ResultType<()> { + let expires_at = chrono::Utc::now() + chrono::Duration::seconds(ttl_secs); + sqlx::query( + "UPDATE tokens SET last_used_at = current_timestamp, expires_at = ? \ + WHERE token_sha256 = ?", + ) + .bind(expires_at) + .bind(sha) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn token_delete(&self, sha: &[u8]) -> ResultType<()> { + sqlx::query("DELETE FROM tokens WHERE token_sha256 = ?") + .bind(sha) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn token_purge_expired(&self) -> ResultType { + let res = sqlx::query("DELETE FROM tokens WHERE expires_at <= current_timestamp") + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(res.rows_affected()) + } + + /// Update last_heartbeat_at for a device, inserting the row if missing. + /// Returns true when the cached `sysinfo_ver_seen` differs from `cfg_ver`, + /// signaling the client to re-upload sysinfo. + pub async fn sysinfo_heartbeat( + &self, + id: &str, + uuid: &str, + version: i64, + conns_json: &str, + cfg_ver: &str, + ) -> ResultType { + let existing = sqlx::query( + "SELECT sysinfo_ver_seen FROM device_sysinfo WHERE id = ? AND uuid = ?", + ) + .bind(id) + .bind(uuid) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + match existing { + Some(row) => { + let seen: String = row.try_get("sysinfo_ver_seen")?; + sqlx::query( + "UPDATE device_sysinfo SET version = ?, conns = ?, \ + last_heartbeat_at = current_timestamp, last_seen_at = current_timestamp \ + WHERE id = ? AND uuid = ?", + ) + .bind(version) + .bind(conns_json) + .bind(id) + .bind(uuid) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(seen != cfg_ver) + } + None => { + sqlx::query( + "INSERT INTO device_sysinfo(id, uuid, version, conns) VALUES(?, ?, ?, ?)", + ) + .bind(id) + .bind(uuid) + .bind(version) + .bind(conns_json) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(true) + } + } + } + + /// Persist the full sysinfo payload and mark the device as having acked + /// the current `cfg_ver`. Caller must check that a `peer` row exists for + /// `id` before calling — see api::sysinfo for the ID_NOT_FOUND case. + pub async fn sysinfo_upsert( + &self, + id: &str, + uuid: &str, + payload: &str, + cfg_ver: &str, + version: i64, + ) -> ResultType<()> { + sqlx::query( + "INSERT INTO device_sysinfo(id, uuid, payload, sysinfo_ver_seen, version, updated_at) \ + VALUES(?, ?, ?, ?, ?, current_timestamp) \ + ON CONFLICT(id, uuid) DO UPDATE SET \ + payload = excluded.payload, \ + sysinfo_ver_seen = excluded.sysinfo_ver_seen, \ + version = excluded.version, \ + updated_at = current_timestamp", + ) + .bind(id) + .bind(uuid) + .bind(payload) + .bind(cfg_ver) + .bind(version) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + // =================================================================== + // M2: address book / tags / device groups / accessibility + // =================================================================== + + /// Bind a device (peer_id, peer_uuid) to a user. Upserts so the binding + /// sticks even when the device hasn't sysinfo'd yet — `--assign` from + /// fresh installs is a real flow. Subsequent `sysinfo_heartbeat` calls + /// then UPDATE the existing row and preserve `user_id`. + pub async fn device_claim(&self, user_id: i64, peer_id: &str, peer_uuid: &str) { + if peer_id.is_empty() || peer_uuid.is_empty() { + return; + } + let res = sqlx::query( + "INSERT INTO device_sysinfo(id, uuid, user_id) VALUES(?, ?, ?) \ + ON CONFLICT(id, uuid) DO UPDATE SET user_id = excluded.user_id", + ) + .bind(peer_id) + .bind(peer_uuid) + .bind(user_id) + .execute(self.pool.get().await.unwrap().deref_mut()) + .await; + if let Err(e) = res { + log::warn!("device_claim failed: {}", e); + } + } + + /// Look up the personal AB for a user, creating it if missing. + pub async fn ab_get_or_create_personal(&self, user_id: i64) -> ResultType { + let row = sqlx::query("SELECT guid FROM address_books WHERE owner_user_id = ? AND kind = 0") + .bind(user_id) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + if let Some(r) = row { + return Ok(r.try_get::("guid")?); + } + let guid = uuid::Uuid::new_v4().to_string(); + sqlx::query( + "INSERT INTO address_books(guid, owner_user_id, name, kind) VALUES(?, ?, ?, 0)", + ) + .bind(&guid) + .bind(user_id) + .bind("My address book") + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(guid) + } + + /// Resolve the maximum effective rule for `user_id` against `ab_guid`. + /// Returns 3 (Full) for the owner, the largest matching rule across + /// direct user shares and device-group shares, or None if no access. + pub async fn ab_resolve_rule(&self, user_id: i64, ab_guid: &str) -> ResultType> { + // Owner check first. + let row = sqlx::query("SELECT owner_user_id FROM address_books WHERE guid = ?") + .bind(ab_guid) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + let Some(row) = row else { return Ok(None) }; + let owner: i64 = row.try_get("owner_user_id")?; + if owner == user_id { + return Ok(Some(3)); + } + // Direct or via device-group. + let row = sqlx::query( + "SELECT MAX(rule) AS r FROM address_book_shares \ + WHERE ab_guid = ? AND ( \ + user_id = ? OR \ + group_id IN (SELECT device_group_id FROM device_group_members WHERE user_id = ?) \ + )", + ) + .bind(ab_guid) + .bind(user_id) + .bind(user_id) + .fetch_one(self.pool.get().await?.deref_mut()) + .await?; + let rule: Option = row.try_get("r").ok(); + Ok(rule) + } + + /// List address books shared with `user_id` (excludes their personal AB). + /// Returns (total, page). + pub async fn ab_list_shared_for_user( + &self, + user_id: i64, + offset: i64, + limit: i64, + ) -> ResultType<(i64, Vec)> { + // total + let total_row = sqlx::query( + "SELECT COUNT(DISTINCT ab.guid) AS c \ + FROM address_books ab \ + JOIN address_book_shares s ON s.ab_guid = ab.guid \ + WHERE ab.kind = 1 AND ( \ + s.user_id = ? OR \ + s.group_id IN (SELECT device_group_id FROM device_group_members WHERE user_id = ?) \ + )", + ) + .bind(user_id) + .bind(user_id) + .fetch_one(self.pool.get().await?.deref_mut()) + .await?; + let total: i64 = total_row.try_get("c")?; + // page + let rows = sqlx::query( + "SELECT ab.guid, ab.name, ab.note, ab.info_json, MAX(s.rule) AS rule, \ + COALESCE(u.username, '') AS owner \ + FROM address_books ab \ + JOIN address_book_shares s ON s.ab_guid = ab.guid \ + LEFT JOIN users u ON u.id = ab.owner_user_id \ + WHERE ab.kind = 1 AND ( \ + s.user_id = ? OR \ + s.group_id IN (SELECT device_group_id FROM device_group_members WHERE user_id = ?) \ + ) \ + GROUP BY ab.guid \ + ORDER BY ab.name \ + LIMIT ? OFFSET ?", + ) + .bind(user_id) + .bind(user_id) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + let data = rows + .into_iter() + .map(|r| AbProfileRow { + guid: r.try_get("guid").unwrap_or_default(), + name: r.try_get("name").unwrap_or_default(), + owner: r.try_get("owner").unwrap_or_default(), + note: r.try_get::, _>("note").unwrap_or_default().unwrap_or_default(), + rule: r.try_get("rule").unwrap_or(1), + info_json: r.try_get::, _>("info_json").ok().flatten(), + }) + .collect(); + Ok((total, data)) + } + + /// Page through peers in an address book. Tags are not filled in yet — + /// callers loop over peers and call `ab_peer_tags`. + pub async fn ab_list_peers( + &self, + ab_guid: &str, + offset: i64, + limit: i64, + ) -> ResultType<(i64, Vec)> { + let total_row = + sqlx::query("SELECT COUNT(*) AS c FROM address_book_peers WHERE ab_guid = ?") + .bind(ab_guid) + .fetch_one(self.pool.get().await?.deref_mut()) + .await?; + let total: i64 = total_row.try_get("c")?; + let rows = sqlx::query( + "SELECT peer_id, alias, note, password, hash, username, hostname, platform \ + FROM address_book_peers WHERE ab_guid = ? \ + ORDER BY peer_id \ + LIMIT ? OFFSET ?", + ) + .bind(ab_guid) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + let mut peers: Vec = rows + .into_iter() + .map(|r| AbPeerRow { + id: r.try_get("peer_id").unwrap_or_default(), + alias: r.try_get::, _>("alias").unwrap_or_default().unwrap_or_default(), + note: r.try_get::, _>("note").unwrap_or_default().unwrap_or_default(), + password: r.try_get::, _>("password").unwrap_or_default().unwrap_or_default(), + hash: r.try_get::, _>("hash").unwrap_or_default().unwrap_or_default(), + username: r.try_get::, _>("username").unwrap_or_default().unwrap_or_default(), + hostname: r.try_get::, _>("hostname").unwrap_or_default().unwrap_or_default(), + platform: r.try_get::, _>("platform").unwrap_or_default().unwrap_or_default(), + tags: vec![], + }) + .collect(); + // Fill in tags. One-shot bulk fetch keeps it O(1) round-trip per page. + if !peers.is_empty() { + let tag_rows = sqlx::query( + "SELECT peer_id, tag_name FROM address_book_peer_tags WHERE ab_guid = ?", + ) + .bind(ab_guid) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + let mut by_peer: std::collections::HashMap> = + std::collections::HashMap::new(); + for row in tag_rows { + let pid: String = row.try_get("peer_id").unwrap_or_default(); + let tag: String = row.try_get("tag_name").unwrap_or_default(); + by_peer.entry(pid).or_default().push(tag); + } + for p in peers.iter_mut() { + if let Some(tags) = by_peer.remove(&p.id) { + p.tags = tags; + } + } + } + Ok((total, peers)) + } + + pub async fn ab_count_peers(&self, ab_guid: &str) -> ResultType { + let row = + sqlx::query("SELECT COUNT(*) AS c FROM address_book_peers WHERE ab_guid = ?") + .bind(ab_guid) + .fetch_one(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.try_get("c")?) + } + + pub async fn ab_list_tags(&self, ab_guid: &str) -> ResultType> { + let rows = sqlx::query( + "SELECT name, color FROM address_book_tags WHERE ab_guid = ? ORDER BY name", + ) + .bind(ab_guid) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + Ok(rows + .into_iter() + .map(|r| AbTagRow { + name: r.try_get("name").unwrap_or_default(), + color: r.try_get("color").unwrap_or(0), + }) + .collect()) + } + + pub async fn ab_peer_insert( + &self, + ab_guid: &str, + p: AbPeerInsert<'_>, + tags: Option<&[String]>, + ) -> ResultType<()> { + sqlx::query( + "INSERT INTO address_book_peers \ + (ab_guid, peer_id, alias, note, password, hash, username, hostname, platform) \ + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)", + ) + .bind(ab_guid) + .bind(p.id) + .bind(p.alias) + .bind(p.note) + .bind(p.password) + .bind(p.hash) + .bind(p.username) + .bind(p.hostname) + .bind(p.platform) + .execute(self.pool.get().await?.deref_mut()) + .await?; + if let Some(tags) = tags { + self.ab_peer_replace_tags(ab_guid, p.id, tags).await?; + } + Ok(()) + } + + /// Partial peer update — only fields present in `body` are touched. + pub async fn ab_peer_partial_update( + &self, + ab_guid: &str, + peer_id: &str, + body: &serde_json::Value, + ) -> ResultType { + // Build a dynamic SET list. Restrict to a known column set for safety. + let cols = [ + "alias", "note", "password", "hash", "username", "hostname", "platform", + ]; + let mut sets: Vec<&str> = vec![]; + for c in cols.iter() { + if body.get(*c).is_some() { + sets.push(*c); + } + } + if !sets.is_empty() { + let setlist = sets + .iter() + .map(|c| format!("{} = ?", c)) + .collect::>() + .join(", "); + let sql = format!( + "UPDATE address_book_peers SET {}, updated_at = strftime('%s','now') \ + WHERE ab_guid = ? AND peer_id = ?", + setlist + ); + let mut q = sqlx::query(&sql); + for c in &sets { + q = q.bind(body.get(*c).and_then(|v| v.as_str()).unwrap_or("")); + } + q = q.bind(ab_guid).bind(peer_id); + let res = q.execute(self.pool.get().await?.deref_mut()).await?; + if res.rows_affected() == 0 { + return Ok(false); + } + } + // Tags update if present. + if let Some(tags_v) = body.get("tags") { + if let Some(arr) = tags_v.as_array() { + let tags: Vec = arr + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + self.ab_peer_replace_tags(ab_guid, peer_id, &tags).await?; + } + } + Ok(true) + } + + async fn ab_peer_replace_tags( + &self, + ab_guid: &str, + peer_id: &str, + tags: &[String], + ) -> ResultType<()> { + sqlx::query("DELETE FROM address_book_peer_tags WHERE ab_guid = ? AND peer_id = ?") + .bind(ab_guid) + .bind(peer_id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + for t in tags { + // Ensure the tag row exists; if missing, insert with a default color + // (Flutter's transparent black). Operators can fix later. + sqlx::query( + "INSERT OR IGNORE INTO address_book_tags(ab_guid, name, color) VALUES(?, ?, 0)", + ) + .bind(ab_guid) + .bind(t) + .execute(self.pool.get().await?.deref_mut()) + .await?; + sqlx::query( + "INSERT OR IGNORE INTO address_book_peer_tags(ab_guid, peer_id, tag_name) \ + VALUES(?, ?, ?)", + ) + .bind(ab_guid) + .bind(peer_id) + .bind(t) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + Ok(()) + } + + pub async fn ab_peers_delete(&self, ab_guid: &str, ids: &[String]) -> ResultType { + let mut total: u64 = 0; + for id in ids { + let res = sqlx::query( + "DELETE FROM address_book_peers WHERE ab_guid = ? AND peer_id = ?", + ) + .bind(ab_guid) + .bind(id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + total += res.rows_affected(); + } + Ok(total) + } + + pub async fn ab_tag_insert(&self, ab_guid: &str, name: &str, color: i64) -> ResultType<()> { + sqlx::query( + "INSERT OR REPLACE INTO address_book_tags(ab_guid, name, color) VALUES(?, ?, ?)", + ) + .bind(ab_guid) + .bind(name) + .bind(color) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn ab_tag_rename(&self, ab_guid: &str, old: &str, new: &str) -> ResultType<()> { + // Two-step rename to keep peer_tags in sync. SQLite has no UPDATE + // CASCADE so we touch both tables explicitly inside one transaction. + // The deadpool guard must outlive the transaction borrow, hence the + // explicit `let` binding. + let mut guard = self.pool.get().await?; + let conn: &mut SqliteConnection = guard.deref_mut(); + let mut tx = conn.begin().await?; + sqlx::query("UPDATE address_book_tags SET name = ? WHERE ab_guid = ? AND name = ?") + .bind(new) + .bind(ab_guid) + .bind(old) + .execute(&mut tx) + .await?; + sqlx::query( + "UPDATE address_book_peer_tags SET tag_name = ? WHERE ab_guid = ? AND tag_name = ?", + ) + .bind(new) + .bind(ab_guid) + .bind(old) + .execute(&mut tx) + .await?; + tx.commit().await?; + Ok(()) + } + + pub async fn ab_tag_update_color( + &self, + ab_guid: &str, + name: &str, + color: i64, + ) -> ResultType<()> { + sqlx::query( + "UPDATE address_book_tags SET color = ? WHERE ab_guid = ? AND name = ?", + ) + .bind(color) + .bind(ab_guid) + .bind(name) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn ab_tags_delete(&self, ab_guid: &str, names: &[String]) -> ResultType { + let mut total: u64 = 0; + for n in names { + let res = + sqlx::query("DELETE FROM address_book_tags WHERE ab_guid = ? AND name = ?") + .bind(ab_guid) + .bind(n) + .execute(self.pool.get().await?.deref_mut()) + .await?; + total += res.rows_affected(); + } + Ok(total) + } + + /// Replace the personal AB's contents wholesale — used by the legacy + /// `POST /api/ab` endpoint. Drops all peers and tags, then re-inserts. + pub async fn ab_legacy_replace( + &self, + ab_guid: &str, + tags: &[(String, i64)], + peers: &[AbPeerRow], + ) -> ResultType<()> { + let mut guard = self.pool.get().await?; + let conn: &mut SqliteConnection = guard.deref_mut(); + let mut tx = conn.begin().await?; + sqlx::query("DELETE FROM address_book_peer_tags WHERE ab_guid = ?") + .bind(ab_guid) + .execute(&mut tx) + .await?; + sqlx::query("DELETE FROM address_book_peers WHERE ab_guid = ?") + .bind(ab_guid) + .execute(&mut tx) + .await?; + sqlx::query("DELETE FROM address_book_tags WHERE ab_guid = ?") + .bind(ab_guid) + .execute(&mut tx) + .await?; + for (name, color) in tags { + sqlx::query( + "INSERT INTO address_book_tags(ab_guid, name, color) VALUES(?, ?, ?)", + ) + .bind(ab_guid) + .bind(name) + .bind(color) + .execute(&mut tx) + .await?; + } + for p in peers { + sqlx::query( + "INSERT INTO address_book_peers \ + (ab_guid, peer_id, alias, note, password, hash, username, hostname, platform) \ + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)", + ) + .bind(ab_guid) + .bind(&p.id) + .bind(&p.alias) + .bind(&p.note) + .bind(&p.password) + .bind(&p.hash) + .bind(&p.username) + .bind(&p.hostname) + .bind(&p.platform) + .execute(&mut tx) + .await?; + for t in &p.tags { + sqlx::query( + "INSERT OR IGNORE INTO address_book_peer_tags(ab_guid, peer_id, tag_name) \ + VALUES(?, ?, ?)", + ) + .bind(ab_guid) + .bind(&p.id) + .bind(t) + .execute(&mut tx) + .await?; + } + } + tx.commit().await?; + Ok(()) + } + + /// Device groups visible to a user (admin sees all; non-admin sees groups + /// they're a member of). Returns (total, page). + pub async fn groups_list_for_user( + &self, + user_id: i64, + is_admin: bool, + offset: i64, + limit: i64, + ) -> ResultType<(i64, Vec)> { + if is_admin { + let total: i64 = sqlx::query("SELECT COUNT(*) AS c FROM device_groups") + .fetch_one(self.pool.get().await?.deref_mut()) + .await? + .try_get("c")?; + let rows = sqlx::query( + "SELECT id, name FROM device_groups ORDER BY name LIMIT ? OFFSET ?", + ) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + return Ok(( + total, + rows.into_iter() + .map(|r| DeviceGroupRow { + id: r.try_get("id").unwrap_or(0), + name: r.try_get("name").unwrap_or_default(), + }) + .collect(), + )); + } + let total: i64 = sqlx::query( + "SELECT COUNT(*) AS c FROM device_groups dg \ + JOIN device_group_members m ON m.device_group_id = dg.id \ + WHERE m.user_id = ?", + ) + .bind(user_id) + .fetch_one(self.pool.get().await?.deref_mut()) + .await? + .try_get("c")?; + let rows = sqlx::query( + "SELECT dg.id, dg.name FROM device_groups dg \ + JOIN device_group_members m ON m.device_group_id = dg.id \ + WHERE m.user_id = ? \ + ORDER BY dg.name LIMIT ? OFFSET ?", + ) + .bind(user_id) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + Ok(( + total, + rows.into_iter() + .map(|r| DeviceGroupRow { + id: r.try_get("id").unwrap_or(0), + name: r.try_get("name").unwrap_or_default(), + }) + .collect(), + )) + } + + /// Users visible to a viewer. Admin sees all enabled users; non-admin + /// sees themselves plus any user they share a device-group with. + pub async fn users_list_accessible( + &self, + viewer_id: i64, + is_admin: bool, + offset: i64, + limit: i64, + ) -> ResultType<(i64, Vec)> { + let (count_sql, list_sql): (&str, String) = if is_admin { + ( + "SELECT COUNT(*) AS c FROM users WHERE status = 1", + "SELECT id, username, password_hash, display_name, email, note, avatar, status, is_admin \ + FROM users WHERE status = 1 ORDER BY username LIMIT ? OFFSET ?".to_string(), + ) + } else { + ( + "SELECT COUNT(DISTINCT u.id) AS c FROM users u WHERE u.status = 1 AND ( \ + u.id = ? OR \ + u.id IN ( \ + SELECT m2.user_id FROM device_group_members m1 \ + JOIN device_group_members m2 USING(device_group_id) \ + WHERE m1.user_id = ? \ + ) \ + )", + "SELECT DISTINCT u.id, u.username, u.password_hash, u.display_name, u.email, \ + u.note, u.avatar, u.status, u.is_admin \ + FROM users u WHERE u.status = 1 AND ( \ + u.id = ? OR \ + u.id IN ( \ + SELECT m2.user_id FROM device_group_members m1 \ + JOIN device_group_members m2 USING(device_group_id) \ + WHERE m1.user_id = ? \ + ) \ + ) \ + ORDER BY u.username LIMIT ? OFFSET ?".to_string(), + ) + }; + let total: i64 = if is_admin { + sqlx::query(count_sql) + .fetch_one(self.pool.get().await?.deref_mut()) + .await? + .try_get("c")? + } else { + sqlx::query(count_sql) + .bind(viewer_id) + .bind(viewer_id) + .fetch_one(self.pool.get().await?.deref_mut()) + .await? + .try_get("c")? + }; + let rows = if is_admin { + sqlx::query(&list_sql) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await? + } else { + sqlx::query(&list_sql) + .bind(viewer_id) + .bind(viewer_id) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await? + }; + Ok((total, rows.into_iter().map(row_to_user).collect())) + } + + /// Peers visible to a viewer. Admin sees all; non-admin sees peers they + /// own plus peers owned by users they share a device-group with. + pub async fn peers_list_accessible( + &self, + viewer_id: i64, + is_admin: bool, + offset: i64, + limit: i64, + ) -> ResultType<(i64, Vec)> { + // Common select: device_sysinfo joined to its owner. We pick the + // alphabetically-first device-group name as the surfaced group. + let where_clause = if is_admin { + "1 = 1" + } else { + "(ds.user_id = ? OR ds.user_id IN ( \ + SELECT m2.user_id FROM device_group_members m1 \ + JOIN device_group_members m2 USING(device_group_id) \ + WHERE m1.user_id = ? \ + ))" + }; + let count_sql = format!( + "SELECT COUNT(*) AS c FROM device_sysinfo ds WHERE {}", + where_clause + ); + let list_sql = format!( + "SELECT ds.id AS pid, \ + COALESCE(u.username, '') AS owner_username, \ + COALESCE(u.display_name, '') AS owner_display_name, \ + COALESCE(u.status, 1) AS owner_status, \ + ds.payload AS sysinfo, \ + ( SELECT dg.name FROM device_groups dg \ + JOIN device_group_members mm ON mm.device_group_id = dg.id \ + WHERE mm.user_id = ds.user_id ORDER BY dg.name LIMIT 1 \ + ) AS device_group_name \ + FROM device_sysinfo ds \ + LEFT JOIN users u ON u.id = ds.user_id \ + WHERE {} \ + ORDER BY ds.id LIMIT ? OFFSET ?", + where_clause + ); + let total: i64 = if is_admin { + sqlx::query(&count_sql) + .fetch_one(self.pool.get().await?.deref_mut()) + .await? + .try_get("c")? + } else { + sqlx::query(&count_sql) + .bind(viewer_id) + .bind(viewer_id) + .fetch_one(self.pool.get().await?.deref_mut()) + .await? + .try_get("c")? + }; + let rows = if is_admin { + sqlx::query(&list_sql) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await? + } else { + sqlx::query(&list_sql) + .bind(viewer_id) + .bind(viewer_id) + .bind(limit) + .bind(offset) + .fetch_all(self.pool.get().await?.deref_mut()) + .await? + }; + let data = rows + .into_iter() + .map(|r| PeerListRow { + id: r.try_get("pid").unwrap_or_default(), + owner_username: r.try_get("owner_username").unwrap_or_default(), + owner_display_name: r.try_get("owner_display_name").unwrap_or_default(), + device_group_name: r + .try_get::, _>("device_group_name") + .ok() + .flatten() + .unwrap_or_default(), + note: String::new(), + status: r.try_get("owner_status").unwrap_or(1), + sysinfo_payload: r.try_get("sysinfo").unwrap_or_default(), + }) + .collect(); + Ok((total, data)) + } + + // =================================================================== + // M3: audit / recordings / strategies / heartbeat commands + // =================================================================== + + /// Insert an `audit_conn` row and return its GUID. The client treats the + /// guid as opaque and passes it back later in `PUT /api/audit` to attach + /// an end-of-session note. + pub async fn audit_conn_insert( + &self, + peer_id: &str, + conn_id: i64, + session_id: i64, + ip: &str, + action: &str, + ) -> ResultType { + let guid = uuid::Uuid::new_v4().to_string(); + sqlx::query( + "INSERT INTO audit_conn(guid, peer_id, conn_id, session_id, ip, action, started_at) \ + VALUES(?, ?, ?, ?, ?, ?, strftime('%s','now'))", + ) + .bind(&guid) + .bind(peer_id) + .bind(conn_id) + .bind(session_id) + .bind(ip) + .bind(action) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(guid) + } + + pub async fn audit_conn_update_note(&self, guid: &str, note: &str) -> ResultType { + let res = sqlx::query("UPDATE audit_conn SET note = ? WHERE guid = ?") + .bind(note) + .bind(guid) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(res.rows_affected() > 0) + } + + pub async fn audit_file_insert( + &self, + peer_id: &str, + remote_peer: &str, + direction: i64, + path: &str, + is_file: bool, + info_json: &str, + ) -> ResultType<()> { + sqlx::query( + "INSERT INTO audit_file(peer_id, remote_peer, direction, path, is_file, info_json) \ + VALUES(?, ?, ?, ?, ?, ?)", + ) + .bind(peer_id) + .bind(remote_peer) + .bind(direction) + .bind(path) + .bind(if is_file { 1 } else { 0 }) + .bind(info_json) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn audit_alarm_insert( + &self, + peer_id: &str, + typ: i64, + info_json: &str, + ) -> ResultType<()> { + sqlx::query( + "INSERT INTO audit_alarm(peer_id, typ, info_json) VALUES(?, ?, ?)", + ) + .bind(peer_id) + .bind(typ) + .bind(info_json) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + /// Bulk delete audit rows older than `days` days. Returns number deleted + /// across the three tables. Used by the optional retention sweep. + pub async fn audit_purge_older_than(&self, days: i64) -> ResultType { + if days <= 0 { + return Ok(0); + } + let cutoff = chrono::Utc::now().timestamp() - days * 86400; + let mut total: u64 = 0; + for sql in [ + "DELETE FROM audit_conn WHERE started_at < ?", + "DELETE FROM audit_file WHERE at < ?", + "DELETE FROM audit_alarm WHERE at < ?", + ] { + let res = sqlx::query(sql) + .bind(cutoff) + .execute(self.pool.get().await?.deref_mut()) + .await?; + total += res.rows_affected(); + } + Ok(total) + } + + // ----- Recordings (DB rows; on-disk I/O lives in api::record::storage) ----- + + pub async fn recording_new(&self, peer_id: &str, filename: &str) -> ResultType<()> { + sqlx::query( + "INSERT INTO recordings(filename, peer_id, size, state) \ + VALUES(?, ?, 0, 'new') \ + ON CONFLICT(filename) DO UPDATE SET \ + peer_id = excluded.peer_id, \ + size = 0, state = 'new', \ + started_at = strftime('%s','now'), \ + finished_at = NULL", + ) + .bind(filename) + .bind(peer_id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn recording_set_state( + &self, + filename: &str, + state: &str, + size: Option, + finished: bool, + ) -> ResultType<()> { + if finished { + sqlx::query( + "UPDATE recordings SET state = ?, size = COALESCE(?, size), \ + finished_at = strftime('%s','now') WHERE filename = ?", + ) + .bind(state) + .bind(size) + .bind(filename) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } else { + sqlx::query( + "UPDATE recordings SET state = ?, size = COALESCE(?, size) WHERE filename = ?", + ) + .bind(state) + .bind(size) + .bind(filename) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + Ok(()) + } + + pub async fn recording_get(&self, filename: &str) -> ResultType> { + let row = sqlx::query("SELECT size, state FROM recordings WHERE filename = ?") + .bind(filename) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(|r| RecordingFile { + size: r.try_get("size").unwrap_or(0), + state: r.try_get("state").unwrap_or_default(), + })) + } + + pub async fn recording_delete(&self, filename: &str) -> ResultType<()> { + sqlx::query("DELETE FROM recordings WHERE filename = ?") + .bind(filename) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + // ----- Strategy resolver ----- + + /// Resolve the strategy for a peer. Priority order: direct peer + /// assignment > device-group assignment (via the peer's owner) > user + /// assignment. Returns the strategy with the largest `priority` within + /// the highest-priority tier. If nothing matches, returns the row's + /// `Default`, which the heartbeat handler treats as "no strategy". + pub async fn strategy_resolve_for(&self, peer_id: &str) -> ResultType { + // First try a direct peer assignment. + if let Some(s) = self + .strategy_lookup( + "SELECT s.modified_at, s.config_options_json, s.extra_json \ + FROM strategies s \ + JOIN strategy_assignments sa ON sa.strategy_id = s.id \ + WHERE sa.peer_id = ? \ + ORDER BY sa.priority DESC LIMIT 1", + &[peer_id], + ) + .await? + { + return Ok(s); + } + // Look up the device's owner; without an owner there's nothing to + // join on, so we stop here. + let owner = sqlx::query( + "SELECT user_id FROM device_sysinfo WHERE id = ? AND user_id IS NOT NULL LIMIT 1", + ) + .bind(peer_id) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + let Some(owner_row) = owner else { + return Ok(ResolvedStrategy::default()); + }; + let owner_id: i64 = owner_row.try_get("user_id")?; + let owner_id_str = owner_id.to_string(); + // Device-group assignment: any strategy assigned to a group that the + // owner is a member of. + if let Some(s) = self + .strategy_lookup( + "SELECT s.modified_at, s.config_options_json, s.extra_json \ + FROM strategies s \ + JOIN strategy_assignments sa ON sa.strategy_id = s.id \ + WHERE sa.device_group_id IN ( \ + SELECT device_group_id FROM device_group_members WHERE user_id = ? \ + ) \ + ORDER BY sa.priority DESC LIMIT 1", + &[&owner_id_str], + ) + .await? + { + return Ok(s); + } + // User assignment. + if let Some(s) = self + .strategy_lookup( + "SELECT s.modified_at, s.config_options_json, s.extra_json \ + FROM strategies s \ + JOIN strategy_assignments sa ON sa.strategy_id = s.id \ + WHERE sa.user_id = ? \ + ORDER BY sa.priority DESC LIMIT 1", + &[&owner_id_str], + ) + .await? + { + return Ok(s); + } + Ok(ResolvedStrategy::default()) + } + + async fn strategy_lookup( + &self, + sql: &str, + params: &[&str], + ) -> ResultType> { + let mut q = sqlx::query(sql); + for p in params { + q = q.bind(*p); + } + let row = q + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(|r| ResolvedStrategy { + modified_at: r.try_get("modified_at").unwrap_or(0), + config_options_json: r + .try_get::, _>("config_options_json") + .unwrap_or_default() + .unwrap_or_else(|| "{}".to_string()), + extra_json: r + .try_get::, _>("extra_json") + .unwrap_or_default() + .unwrap_or_else(|| "{}".to_string()), + })) + } + + // =================================================================== + // M4: 2FA (TOTP) + pending challenges + // =================================================================== + + /// Returns the user's TOTP secret if they have enrolled. Used by the + /// login handler to decide whether to issue a `tfa_check` challenge. + pub async fn totp_get_secret(&self, user_id: i64) -> ResultType> { + let row = + sqlx::query("SELECT secret_b32 FROM user_totp_secrets WHERE user_id = ?") + .bind(user_id) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(|r| r.try_get::("secret_b32").unwrap_or_default())) + } + + /// Idempotent — re-enrolling overwrites the existing secret. + pub async fn totp_enroll(&self, user_id: i64, secret_b32: &str) -> ResultType<()> { + sqlx::query( + "INSERT INTO user_totp_secrets(user_id, secret_b32, enrolled_at) \ + VALUES(?, ?, strftime('%s','now')) \ + ON CONFLICT(user_id) DO UPDATE SET \ + secret_b32 = excluded.secret_b32, \ + enrolled_at = strftime('%s','now')", + ) + .bind(user_id) + .bind(secret_b32) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn totp_unenroll(&self, user_id: i64) -> ResultType { + let res = sqlx::query("DELETE FROM user_totp_secrets WHERE user_id = ?") + .bind(user_id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(res.rows_affected() > 0) + } + + /// Issue a short-lived TFA-challenge nonce. The login handler returns + /// this in the `secret` field of `tfa_check`; the client echoes it back + /// alongside the TOTP code. + pub async fn tfa_challenge_create( + &self, + user_id: i64, + ttl_secs: i64, + ) -> ResultType { + let nonce = base64::encode_config( + sodiumoxide::randombytes::randombytes(24), + base64::URL_SAFE_NO_PAD, + ); + let expires_at = chrono::Utc::now().timestamp() + ttl_secs; + sqlx::query( + "INSERT INTO pending_tfa_challenges(secret, user_id, expires_at) \ + VALUES(?, ?, ?)", + ) + .bind(&nonce) + .bind(user_id) + .bind(expires_at) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(nonce) + } + + /// Look up a TFA challenge nonce. Returns the user_id if the row exists + /// and has not expired; otherwise None. Does NOT delete the row — the + /// caller deletes after the TOTP code itself has been verified, so a + /// failed code attempt does not invalidate the challenge. + pub async fn tfa_challenge_lookup(&self, nonce: &str) -> ResultType> { + let row = sqlx::query( + "SELECT user_id, expires_at FROM pending_tfa_challenges WHERE secret = ?", + ) + .bind(nonce) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + let Some(row) = row else { return Ok(None) }; + let expires_at: i64 = row.try_get("expires_at")?; + if expires_at <= chrono::Utc::now().timestamp() { + return Ok(None); + } + Ok(Some(row.try_get("user_id")?)) + } + + pub async fn tfa_challenge_consume(&self, nonce: &str) -> ResultType<()> { + sqlx::query("DELETE FROM pending_tfa_challenges WHERE secret = ?") + .bind(nonce) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + /// Replace any prior pending codes for this email with a fresh one. The + /// `code_hash` is sha256(code) — the plaintext code is mailed, never + /// persisted. + pub async fn email_code_create( + &self, + email: &str, + code_sha256: &[u8], + ttl_secs: i64, + ) -> ResultType<()> { + // Drop earlier pending codes for the same email so the latest one + // wins and we don't accumulate row clutter. + sqlx::query("DELETE FROM pending_email_codes WHERE email = ?") + .bind(email) + .execute(self.pool.get().await?.deref_mut()) + .await?; + let expires_at = chrono::Utc::now().timestamp() + ttl_secs; + sqlx::query( + "INSERT INTO pending_email_codes(email, code_hash, expires_at) \ + VALUES(?, ?, ?)", + ) + .bind(email) + .bind(code_sha256) + .bind(expires_at) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + /// Verify a code attempt for `email`. Returns: + /// - `Ok(true)` on success — the row is consumed. + /// - `Ok(false)` on bad code or no pending row — the attempts counter is + /// bumped; after 5 attempts the row is purged. + /// - `Err(_)` on DB error. + pub async fn email_code_verify( + &self, + email: &str, + code_sha256: &[u8], + ) -> ResultType { + let now = chrono::Utc::now().timestamp(); + let row = sqlx::query( + "SELECT id, code_hash, expires_at, attempts \ + FROM pending_email_codes WHERE email = ? \ + ORDER BY id DESC LIMIT 1", + ) + .bind(email) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + let Some(row) = row else { return Ok(false) }; + let id: i64 = row.try_get("id")?; + let stored: Vec = row.try_get("code_hash")?; + let expires_at: i64 = row.try_get("expires_at")?; + let attempts: i64 = row.try_get("attempts")?; + if expires_at <= now { + sqlx::query("DELETE FROM pending_email_codes WHERE id = ?") + .bind(id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + return Ok(false); + } + if stored.len() == code_sha256.len() + && constant_time_eq(&stored, code_sha256) + { + sqlx::query("DELETE FROM pending_email_codes WHERE id = ?") + .bind(id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + return Ok(true); + } + let new_attempts = attempts + 1; + if new_attempts >= 5 { + sqlx::query("DELETE FROM pending_email_codes WHERE id = ?") + .bind(id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } else { + sqlx::query("UPDATE pending_email_codes SET attempts = ? WHERE id = ?") + .bind(new_attempts) + .bind(id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + } + Ok(false) + } + + pub async fn strategy_find_by_name(&self, name: &str) -> ResultType> { + let row = sqlx::query("SELECT id FROM strategies WHERE name = ?") + .bind(name) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(|r| r.try_get::("id").unwrap_or(0))) + } + + /// Replace any existing peer-scoped assignment for this peer with the + /// new strategy. Keeps the resolver's "peer > group > user" priority + /// stable per peer. + pub async fn strategy_assign_peer( + &self, + strategy_id: i64, + peer_id: &str, + ) -> ResultType<()> { + sqlx::query("DELETE FROM strategy_assignments WHERE peer_id = ?") + .bind(peer_id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + sqlx::query( + "INSERT INTO strategy_assignments(strategy_id, peer_id, priority) \ + VALUES(?, ?, 100)", + ) + .bind(strategy_id) + .bind(peer_id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + /// Ensure `group_name` exists, create it if missing, and add `user_id` as + /// a member if not already present. + pub async fn device_group_ensure_member( + &self, + group_name: &str, + user_id: i64, + ) -> ResultType<()> { + // Upsert the group itself. + sqlx::query("INSERT OR IGNORE INTO device_groups(name) VALUES(?)") + .bind(group_name) + .execute(self.pool.get().await?.deref_mut()) + .await?; + let row = sqlx::query("SELECT id FROM device_groups WHERE name = ?") + .bind(group_name) + .fetch_one(self.pool.get().await?.deref_mut()) + .await?; + let gid: i64 = row.try_get("id")?; + sqlx::query( + "INSERT OR IGNORE INTO device_group_members(device_group_id, user_id) \ + VALUES(?, ?)", + ) + .bind(gid) + .bind(user_id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + // =================================================================== + // M4: OIDC provider config + session state + // =================================================================== + + pub async fn oidc_provider_upsert(&self, p: &OidcProviderRow) -> ResultType<()> { + sqlx::query( + "INSERT INTO oidc_providers(name, display_name, icon_url, issuer_url, \ + client_id, client_secret, scopes, redirect_url, enabled) \ + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?) \ + ON CONFLICT(name) DO UPDATE SET \ + display_name = excluded.display_name, \ + icon_url = excluded.icon_url, \ + issuer_url = excluded.issuer_url, \ + client_id = excluded.client_id, \ + client_secret = excluded.client_secret, \ + scopes = excluded.scopes, \ + redirect_url = excluded.redirect_url, \ + enabled = excluded.enabled", + ) + .bind(&p.name) + .bind(p.display_name.as_deref()) + .bind(p.icon_url.as_deref()) + .bind(&p.issuer_url) + .bind(&p.client_id) + .bind(&p.client_secret) + .bind(&p.scopes) + .bind(&p.redirect_url) + .bind(if p.enabled { 1 } else { 0 }) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn oidc_provider_get(&self, name: &str) -> ResultType> { + let row = sqlx::query( + "SELECT name, display_name, icon_url, issuer_url, client_id, client_secret, \ + scopes, redirect_url, enabled \ + FROM oidc_providers WHERE name = ? AND enabled = 1", + ) + .bind(name) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(row_to_oidc_provider)) + } + + pub async fn oidc_provider_list_enabled(&self) -> ResultType> { + let rows = sqlx::query( + "SELECT name, display_name, icon_url, issuer_url, client_id, client_secret, \ + scopes, redirect_url, enabled \ + FROM oidc_providers WHERE enabled = 1 ORDER BY name", + ) + .fetch_all(self.pool.get().await?.deref_mut()) + .await?; + Ok(rows.into_iter().map(row_to_oidc_provider).collect()) + } + + pub async fn oidc_session_create( + &self, + s: &OidcSessionInsert<'_>, + ) -> ResultType<()> { + sqlx::query( + "INSERT INTO oidc_sessions(code, provider, state, client_id_str, client_uuid, \ + device_info_json, created_at, expires_at) \ + VALUES(?, ?, ?, ?, ?, ?, strftime('%s','now'), ?)", + ) + .bind(s.code) + .bind(s.provider) + .bind(s.state) + .bind(s.client_id_str) + .bind(s.client_uuid) + .bind(s.device_info_json) + .bind(s.expires_at) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn oidc_session_get_by_code( + &self, + code: &str, + ) -> ResultType> { + let row = sqlx::query( + "SELECT code, provider, state, client_id_str, client_uuid, device_info_json, \ + expires_at, status, access_token, user_id, error \ + FROM oidc_sessions WHERE code = ?", + ) + .bind(code) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(row_to_oidc_session)) + } + + pub async fn oidc_session_get_by_state( + &self, + state: &str, + ) -> ResultType> { + let row = sqlx::query( + "SELECT code, provider, state, client_id_str, client_uuid, device_info_json, \ + expires_at, status, access_token, user_id, error \ + FROM oidc_sessions WHERE state = ?", + ) + .bind(state) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(row_to_oidc_session)) + } + + pub async fn oidc_session_complete( + &self, + code: &str, + access_token: &str, + user_id: i64, + ) -> ResultType<()> { + sqlx::query( + "UPDATE oidc_sessions SET status = 'success', access_token = ?, user_id = ? \ + WHERE code = ?", + ) + .bind(access_token) + .bind(user_id) + .bind(code) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + pub async fn oidc_session_fail(&self, code: &str, error: &str) -> ResultType<()> { + sqlx::query( + "UPDATE oidc_sessions SET status = 'error', error = ? WHERE code = ?", + ) + .bind(error) + .bind(code) + .execute(self.pool.get().await?.deref_mut()) + .await?; + Ok(()) + } + + /// Find an existing user by their OIDC `sub`, falling back to email + /// (case-insensitive). Returns `None` if neither matches; the caller + /// then decides whether to auto-provision a new user. + pub async fn user_find_by_oidc( + &self, + oidc_subject: &str, + email: Option<&str>, + ) -> ResultType> { + let row = sqlx::query( + "SELECT id, username, password_hash, display_name, email, note, avatar, status, is_admin \ + FROM users WHERE oidc_subject = ? LIMIT 1", + ) + .bind(oidc_subject) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + if let Some(r) = row { + return Ok(Some(row_to_user(r))); + } + if let Some(e) = email.filter(|s| !s.is_empty()) { + return self.user_find_by_email(e).await; + } + Ok(None) + } + + /// Create or update a user from an OIDC identity. The local username is + /// either the email (preferred) or the sub if no email. Subsequent + /// logins re-use the same row via oidc_subject. + pub async fn user_upsert_oidc( + &self, + oidc_subject: &str, + email: Option<&str>, + display_name: Option<&str>, + ) -> ResultType { + let username = email + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("oidc:{}", oidc_subject)); + if let Some(existing) = self.user_find_by_oidc(oidc_subject, email).await? { + // Make sure the oidc_subject is recorded on this row even if we + // matched by email — keeps subsequent lookups O(1). + sqlx::query( + "UPDATE users SET oidc_subject = ?, email = COALESCE(NULLIF(?, ''), email), \ + display_name = COALESCE(NULLIF(?, ''), display_name) WHERE id = ?", + ) + .bind(oidc_subject) + .bind(email.unwrap_or("")) + .bind(display_name.unwrap_or("")) + .bind(existing.id) + .execute(self.pool.get().await?.deref_mut()) + .await?; + return Ok(self + .user_find_by_id(existing.id) + .await? + .unwrap_or(existing)); + } + // New user. Empty password_hash blocks password login until the + // operator (or the user) sets one. + sqlx::query( + "INSERT INTO users(username, password_hash, display_name, email, status, is_admin, oidc_subject) \ + VALUES(?, '', ?, ?, 1, 0, ?)", + ) + .bind(&username) + .bind(display_name.unwrap_or("")) + .bind(email.unwrap_or("")) + .bind(oidc_subject) + .execute(self.pool.get().await?.deref_mut()) + .await?; + self.user_find_by_username(&username) + .await? + .ok_or_else(|| hbb_common::anyhow::anyhow!("post-insert lookup failed")) + } + + pub async fn user_find_by_email(&self, email: &str) -> ResultType> { + let row = sqlx::query( + "SELECT id, username, password_hash, display_name, email, note, avatar, status, is_admin \ + FROM users WHERE email = ? COLLATE NOCASE LIMIT 1", + ) + .bind(email) + .fetch_optional(self.pool.get().await?.deref_mut()) + .await?; + Ok(row.map(row_to_user)) + } + + /// Read all queued heartbeat commands for `peer_id` and delete them in + /// the same transaction. Each command is read at most once. + pub async fn heartbeat_pop_commands( + &self, + peer_id: &str, + ) -> ResultType> { + let mut guard = self.pool.get().await?; + let conn: &mut SqliteConnection = guard.deref_mut(); + let mut tx = conn.begin().await?; + let rows = sqlx::query( + "SELECT kind, payload FROM heartbeat_commands WHERE peer_id = ?", + ) + .bind(peer_id) + .fetch_all(&mut tx) + .await?; + if rows.is_empty() { + return Ok(vec![]); + } + sqlx::query("DELETE FROM heartbeat_commands WHERE peer_id = ?") + .bind(peer_id) + .execute(&mut tx) + .await?; + tx.commit().await?; + Ok(rows + .into_iter() + .map(|r| HeartbeatCommand { + kind: r.try_get("kind").unwrap_or_default(), + payload: r.try_get::, _>("payload").unwrap_or_default(), + }) + .collect()) + } + pub async fn get_peer(&self, id: &str) -> ResultType> { Ok(sqlx::query_as!( Peer, @@ -144,6 +2019,338 @@ impl Database { } } +/// Timing-safe equality for hash comparisons. Slightly paranoid given the +/// codes are short-lived, but cheap. +fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + let mut diff: u8 = 0; + for (x, y) in a.iter().zip(b.iter()) { + diff |= x ^ y; + } + diff == 0 +} + +fn row_to_user(row: sqlx::sqlite::SqliteRow) -> UserRow { + let is_admin: i64 = row.try_get("is_admin").unwrap_or(0); + UserRow { + id: row.try_get("id").unwrap_or(0), + username: row.try_get("username").unwrap_or_default(), + password_hash: row.try_get("password_hash").unwrap_or_default(), + display_name: row.try_get("display_name").unwrap_or_default(), + email: row.try_get("email").unwrap_or_default(), + note: row.try_get("note").unwrap_or_default(), + avatar: row.try_get("avatar").unwrap_or_default(), + status: row.try_get("status").unwrap_or(1), + is_admin: is_admin != 0, + } +} + +fn row_to_oidc_provider(row: sqlx::sqlite::SqliteRow) -> OidcProviderRow { + let enabled: i64 = row.try_get("enabled").unwrap_or(0); + OidcProviderRow { + name: row.try_get("name").unwrap_or_default(), + display_name: row + .try_get::, _>("display_name") + .ok() + .flatten(), + icon_url: row.try_get::, _>("icon_url").ok().flatten(), + issuer_url: row.try_get("issuer_url").unwrap_or_default(), + client_id: row.try_get("client_id").unwrap_or_default(), + client_secret: row.try_get("client_secret").unwrap_or_default(), + scopes: row + .try_get("scopes") + .unwrap_or_else(|_| "openid email profile".to_string()), + redirect_url: row.try_get("redirect_url").unwrap_or_default(), + enabled: enabled != 0, + } +} + +fn row_to_oidc_session(row: sqlx::sqlite::SqliteRow) -> OidcSessionRow { + OidcSessionRow { + code: row.try_get("code").unwrap_or_default(), + provider: row.try_get("provider").unwrap_or_default(), + state: row.try_get("state").unwrap_or_default(), + client_id_str: row.try_get("client_id_str").unwrap_or_default(), + client_uuid: row.try_get("client_uuid").unwrap_or_default(), + device_info_json: row + .try_get("device_info_json") + .unwrap_or_else(|_| "{}".to_string()), + expires_at: row.try_get("expires_at").unwrap_or(0), + status: row + .try_get("status") + .unwrap_or_else(|_| "pending".to_string()), + access_token: row.try_get::, _>("access_token").ok().flatten(), + user_id: row.try_get::, _>("user_id").ok().flatten(), + error: row.try_get::, _>("error").ok().flatten(), + } +} + +const M1_SCHEMA: &[&str] = &[ + "CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + display_name TEXT NOT NULL DEFAULT '', + email TEXT NOT NULL DEFAULT '', + note TEXT NOT NULL DEFAULT '', + avatar TEXT NOT NULL DEFAULT '', + status INTEGER NOT NULL DEFAULT 1, + is_admin INTEGER NOT NULL DEFAULT 0, + created_at DATETIME NOT NULL DEFAULT(current_timestamp), + updated_at DATETIME NOT NULL DEFAULT(current_timestamp) + )", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)", + "CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token_sha256 BLOB NOT NULL UNIQUE, + peer_id TEXT NOT NULL DEFAULT '', + peer_uuid TEXT NOT NULL DEFAULT '', + device_info TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL DEFAULT(current_timestamp), + last_used_at DATETIME NOT NULL DEFAULT(current_timestamp), + expires_at DATETIME NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE + )", + "CREATE INDEX IF NOT EXISTS idx_tokens_user ON tokens(user_id)", + "CREATE INDEX IF NOT EXISTS idx_tokens_expires ON tokens(expires_at)", + "CREATE TABLE IF NOT EXISTS device_sysinfo ( + id TEXT NOT NULL, + uuid TEXT NOT NULL, + version INTEGER NOT NULL DEFAULT 0, + last_seen_at DATETIME NOT NULL DEFAULT(current_timestamp), + last_heartbeat_at DATETIME NOT NULL DEFAULT(current_timestamp), + conns TEXT NOT NULL DEFAULT '[]', + payload TEXT NOT NULL DEFAULT '{}', + sysinfo_ver_seen TEXT NOT NULL DEFAULT '', + updated_at DATETIME NOT NULL DEFAULT(current_timestamp), + PRIMARY KEY (id, uuid) + )", + "CREATE INDEX IF NOT EXISTS idx_device_sysinfo_lastseen ON device_sysinfo(last_seen_at)", +]; + +const M2_SCHEMA: &[&str] = &[ + "CREATE TABLE IF NOT EXISTS address_books ( + guid TEXT PRIMARY KEY, + owner_user_id INTEGER NOT NULL, + name TEXT NOT NULL, + note TEXT, + kind INTEGER NOT NULL, + info_json TEXT, + created_at INTEGER NOT NULL DEFAULT (strftime('%s','now')) + )", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_ab_owner_kind_name \ + ON address_books(owner_user_id, kind, name)", + "CREATE INDEX IF NOT EXISTS idx_ab_owner ON address_books(owner_user_id)", + "CREATE TABLE IF NOT EXISTS device_groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + note TEXT, + created_at INTEGER NOT NULL DEFAULT (strftime('%s','now')) + )", + "CREATE TABLE IF NOT EXISTS device_group_members ( + device_group_id INTEGER NOT NULL, + user_id INTEGER, + PRIMARY KEY (device_group_id, user_id) + )", + "CREATE INDEX IF NOT EXISTS idx_dgm_user ON device_group_members(user_id)", + // SQLite forbids expressions in PRIMARY KEY constraints, so we use a + // unique index over the COALESCEd tuple to enforce one share per + // (ab, user) and one per (ab, group). NULLs collapse to 0 so two NULL + // user_ids on the same ab + group still conflict. + "CREATE TABLE IF NOT EXISTS address_book_shares ( + ab_guid TEXT NOT NULL, + user_id INTEGER, + group_id INTEGER, + rule INTEGER NOT NULL + )", + "CREATE UNIQUE INDEX IF NOT EXISTS uq_abshare \ + ON address_book_shares(ab_guid, COALESCE(user_id,0), COALESCE(group_id,0))", + "CREATE INDEX IF NOT EXISTS idx_abshare_user ON address_book_shares(user_id)", + "CREATE INDEX IF NOT EXISTS idx_abshare_group ON address_book_shares(group_id)", + "CREATE TABLE IF NOT EXISTS address_book_peers ( + ab_guid TEXT NOT NULL, + peer_id TEXT NOT NULL, + alias TEXT, + note TEXT, + password TEXT, + hash TEXT, + username TEXT, + hostname TEXT, + platform TEXT, + updated_at INTEGER NOT NULL DEFAULT (strftime('%s','now')), + PRIMARY KEY (ab_guid, peer_id) + )", + "CREATE TABLE IF NOT EXISTS address_book_tags ( + ab_guid TEXT NOT NULL, + name TEXT NOT NULL, + color INTEGER NOT NULL, + PRIMARY KEY (ab_guid, name) + )", + "CREATE TABLE IF NOT EXISTS address_book_peer_tags ( + ab_guid TEXT NOT NULL, + peer_id TEXT NOT NULL, + tag_name TEXT NOT NULL, + PRIMARY KEY (ab_guid, peer_id, tag_name) + )", + "CREATE INDEX IF NOT EXISTS idx_abpt_peer ON address_book_peer_tags(ab_guid, peer_id)", + "CREATE INDEX IF NOT EXISTS idx_abpt_tag ON address_book_peer_tags(ab_guid, tag_name)", +]; + +const M2_SOFT_ALTERS: &[&str] = &[ + // Bind a device to its enrolled user. Filled by the login handler when + // the client passes id+uuid in the body. + "ALTER TABLE device_sysinfo ADD COLUMN user_id INTEGER", + // OIDC `sub` claim, used to map an IdP identity to a local user across + // sessions. Nullable so password-only users keep working. + "ALTER TABLE users ADD COLUMN oidc_subject TEXT", +]; + +const M3_SCHEMA: &[&str] = &[ + // Audit conn rows are keyed by an opaque guid that we hand back to the + // client so the operator's end-of-session note dialog can attach a note + // to the right session. + "CREATE TABLE IF NOT EXISTS audit_conn ( + guid TEXT PRIMARY KEY, + peer_id TEXT NOT NULL, + remote_id TEXT, + conn_id INTEGER NOT NULL DEFAULT 0, + session_id INTEGER NOT NULL DEFAULT 0, + ip TEXT, + action TEXT NOT NULL, + note TEXT, + started_at INTEGER NOT NULL, + ended_at INTEGER + )", + "CREATE INDEX IF NOT EXISTS idx_audit_conn_peer ON audit_conn(peer_id, started_at)", + "CREATE TABLE IF NOT EXISTS audit_file ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + peer_id TEXT NOT NULL, + remote_peer TEXT, + direction INTEGER NOT NULL, + path TEXT NOT NULL, + is_file INTEGER NOT NULL, + info_json TEXT NOT NULL, + at INTEGER NOT NULL DEFAULT (strftime('%s','now')) + )", + "CREATE INDEX IF NOT EXISTS idx_audit_file_peer ON audit_file(peer_id, at)", + "CREATE TABLE IF NOT EXISTS audit_alarm ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + peer_id TEXT NOT NULL, + typ INTEGER NOT NULL, + info_json TEXT NOT NULL, + at INTEGER NOT NULL DEFAULT (strftime('%s','now')) + )", + "CREATE INDEX IF NOT EXISTS idx_audit_alarm_peer ON audit_alarm(peer_id, at)", + "CREATE TABLE IF NOT EXISTS recordings ( + filename TEXT PRIMARY KEY, + peer_id TEXT NOT NULL, + size INTEGER NOT NULL DEFAULT 0, + state TEXT NOT NULL, + started_at INTEGER NOT NULL DEFAULT (strftime('%s','now')), + finished_at INTEGER + )", + "CREATE INDEX IF NOT EXISTS idx_recordings_peer ON recordings(peer_id, started_at)", + "CREATE TABLE IF NOT EXISTS strategies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + modified_at INTEGER NOT NULL DEFAULT 0, + config_options_json TEXT NOT NULL DEFAULT '{}', + extra_json TEXT NOT NULL DEFAULT '{}' + )", + // strategy_assignments: exactly one of (user_id, device_group_id, peer_id) + // is non-null per row. Resolution priority is encoded by `priority` + // (higher wins on ties within the same scope). + "CREATE TABLE IF NOT EXISTS strategy_assignments ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + strategy_id INTEGER NOT NULL, + user_id INTEGER, + device_group_id INTEGER, + peer_id TEXT, + priority INTEGER NOT NULL DEFAULT 0 + )", + "CREATE INDEX IF NOT EXISTS idx_strategy_assign_peer ON strategy_assignments(peer_id)", + "CREATE INDEX IF NOT EXISTS idx_strategy_assign_user ON strategy_assignments(user_id)", + "CREATE INDEX IF NOT EXISTS idx_strategy_assign_group ON strategy_assignments(device_group_id)", + // heartbeat_commands: one-shot commands the next /api/heartbeat + // response delivers, then deletes. `kind` is one of 'disconnect' + // (payload = JSON array of conn_ids) or 'sysinfo' (payload null). + "CREATE TABLE IF NOT EXISTS heartbeat_commands ( + peer_id TEXT NOT NULL, + kind TEXT NOT NULL, + payload TEXT, + created_at INTEGER NOT NULL DEFAULT (strftime('%s','now')), + PRIMARY KEY (peer_id, kind) + )", +]; + +const M4_SCHEMA: &[&str] = &[ + // TOTP enrollment per user. The shared secret is stored base32 to match + // how authenticator apps encode it; the operator scans/enters it. + "CREATE TABLE IF NOT EXISTS user_totp_secrets ( + user_id INTEGER PRIMARY KEY, + secret_b32 TEXT NOT NULL, + enrolled_at INTEGER NOT NULL, + recovery_codes_json TEXT + )", + // Short-lived nonce echoed back by the client during the second login + // POST. Login flow: password verified -> tfa_check{secret=} -> + // client sends tfa_code{secret=, tfaCode=<6 digits>}. + "CREATE TABLE IF NOT EXISTS pending_tfa_challenges ( + secret TEXT PRIMARY KEY, + user_id INTEGER NOT NULL, + expires_at INTEGER NOT NULL + )", + "CREATE INDEX IF NOT EXISTS idx_pending_tfa_user ON pending_tfa_challenges(user_id)", + // Pending email-login codes. Hashed at rest so a DB leak doesn't + // immediately give an attacker a working code; bcrypt would be overkill + // for a 6-digit secret with a 10-minute TTL — sha256 is enough. + "CREATE TABLE IF NOT EXISTS pending_email_codes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT NOT NULL, + code_hash BLOB NOT NULL, + expires_at INTEGER NOT NULL, + attempts INTEGER NOT NULL DEFAULT 0, + created_at INTEGER NOT NULL DEFAULT (strftime('%s','now')) + )", + "CREATE INDEX IF NOT EXISTS idx_pending_email_email ON pending_email_codes(email)", + "CREATE INDEX IF NOT EXISTS idx_pending_email_expires ON pending_email_codes(expires_at)", + // OIDC providers and in-flight device-flow sessions. Providers are + // upserted at startup from the operator-supplied --oidc-config TOML + // (or hand-inserted via SQL). + "CREATE TABLE IF NOT EXISTS oidc_providers ( + name TEXT PRIMARY KEY, + display_name TEXT, + icon_url TEXT, + issuer_url TEXT NOT NULL, + client_id TEXT NOT NULL, + client_secret TEXT NOT NULL, + scopes TEXT NOT NULL DEFAULT 'openid email profile', + redirect_url TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1 + )", + // `code` is the opaque handle the client polls with; `state` is the + // CSRF token round-tripped through the IdP. Status transitions: + // pending -> success | error. + "CREATE TABLE IF NOT EXISTS oidc_sessions ( + code TEXT PRIMARY KEY, + provider TEXT NOT NULL, + state TEXT NOT NULL UNIQUE, + client_id_str TEXT NOT NULL DEFAULT '', + client_uuid TEXT NOT NULL DEFAULT '', + device_info_json TEXT NOT NULL DEFAULT '{}', + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + access_token TEXT, + user_id INTEGER, + error TEXT + )", + "CREATE INDEX IF NOT EXISTS idx_oidc_sessions_status ON oidc_sessions(status, expires_at)", +]; + #[cfg(test)] mod tests { use hbb_common::tokio; diff --git a/src/lib.rs b/src/lib.rs index 8da29a2..70ae58c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod rendezvous_server; pub use rendezvous_server::*; +pub mod api; pub mod common; -mod database; +pub mod database; mod peer; mod version; diff --git a/src/main.rs b/src/main.rs index 8a0ff24..0a945b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,22 @@ fn main() -> ResultType<()> { -u, --software-url=[URL] 'Sets download url of RustDesk software of newest version' -r, --relay-servers=[HOST] 'Sets the default relay servers, separated by comma' -M, --rmem=[NUMBER(default={RMEM})] 'Sets UDP recv buffer size, set system rmem_max first, e.g., sudo sysctl -w net.core.rmem_max=52428800. vi /etc/sysctl.conf, net.core.rmem_max=52428800, sudo sysctl –p' + --http-port=[NUMBER(default=21114)] 'HTTP management API port (0 disables)' + --bootstrap-admin-username=[USERNAME] 'Username to seed on first startup if users table is empty' + --bootstrap-admin-password=[PASSWORD] 'Password to seed on first startup if users table is empty' + --ab-legacy-mode=[on|off] 'When on, /api/ab/personal returns 404 to force legacy single-blob AB' + --ab-max-peers-per-book=[NUMBER(default=100)] 'Surfaced via /api/ab/settings.max_peer_one_ab' + --recording-dir=[PATH(default=./recordings)] 'Root directory for /api/record uploads' + --recording-max-size-mb=[NUMBER] 'Optional ceiling per recording file; 0 or unset = unlimited' + --audit-retention-days=[NUMBER] 'Hourly task deletes audit rows older than N days; 0 disables' + --smtp-host=[HOST] 'SMTP host for email-code login; if empty, codes are logged to stdout (dev mode)' + --smtp-port=[NUMBER(default=587)] 'SMTP port' + --smtp-user=[USER] 'SMTP username (omit for unauthenticated relays)' + --smtp-pass=[PASS] 'SMTP password' + --smtp-from=[ADDR] 'From: address for outbound login emails (default: noreply@)' + --smtp-tls=[on|off] 'STARTTLS on the SMTP connection (default: on)' + --public-base-url=[URL] 'Externally reachable HTTP base URL (e.g. https://rustdesk.example.com:21114) — required for OIDC redirect callbacks' + --oidc-config=[PATH] 'TOML file describing OIDC providers (upserted into oidc_providers at startup)' , --mask=[MASK] 'Determine if the connection comes from LAN, e.g. 192.168.0.0/16' -k, --key=[KEY] 'Only allow the client with the same key'", ); @@ -31,7 +47,16 @@ fn main() -> ResultType<()> { } let rmem = get_arg("rmem").parse::().unwrap_or(RMEM); let serial: i32 = get_arg("serial").parse().unwrap_or(0); + let http_port: i32 = get_arg_or("http-port", "21114".to_string()) + .parse() + .unwrap_or(21114); crate::common::check_software_update(); - RendezvousServer::start(port, serial, &get_arg_or("key", "-".to_owned()), rmem)?; + RendezvousServer::start( + port, + serial, + &get_arg_or("key", "-".to_owned()), + rmem, + http_port, + )?; Ok(()) } diff --git a/src/rendezvous_server.rs b/src/rendezvous_server.rs index ff68441..83cadc4 100644 --- a/src/rendezvous_server.rs +++ b/src/rendezvous_server.rs @@ -8,7 +8,7 @@ use hbb_common::{ futures::future::join_all, futures_util::{ sink::SinkExt, - stream::{SplitSink, StreamExt}, + stream::{SplitSink, SplitStream, StreamExt}, }, log, protobuf::{Message as _, MessageField}, @@ -16,7 +16,7 @@ use hbb_common::{ register_pk_response::Result::{TOO_FREQUENT, UUID_MISMATCH}, *, }, - tcp::{listen_any, FramedStream}, + tcp::{listen_any, Encrypt, FramedStream}, timeout, tokio::{ self, @@ -31,7 +31,7 @@ use hbb_common::{ AddrMangle, ResultType, }; use ipnetwork::Ipv4Network; -use sodiumoxide::crypto::sign; +use sodiumoxide::crypto::{box_, sign}; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -49,9 +49,14 @@ enum Data { const REG_TIMEOUT: i32 = 30_000; type TcpStreamSink = SplitSink, Bytes>; +type TcpStreamSrc = SplitStream>; type WsSink = SplitSink, tungstenite::Message>; enum Sink { - TcpStream(TcpStreamSink), + /// Plain or encrypted TCP. The optional `Encrypt` is only present after a + /// successful server-initiated `secure_tcp` handshake — see + /// `try_secure_tcp_handshake`. When `Some`, every outgoing message is + /// sealed with secretbox before being framed. + TcpStream(TcpStreamSink, Option), Ws(WsSink), } type Sender = mpsc::UnboundedSender; @@ -99,11 +104,56 @@ enum LoopFailure { impl RendezvousServer { #[tokio::main(flavor = "multi_thread")] - pub async fn start(port: i32, serial: i32, key: &str, rmem: usize) -> ResultType<()> { + pub async fn start( + port: i32, + serial: i32, + key: &str, + rmem: usize, + http_port: i32, + ) -> ResultType<()> { let (key, sk) = Self::get_server_sk(key); let nat_port = port - 1; let ws_port = port + 2; let pm = PeerMap::new().await?; + // M1: build the HTTP API state and seed the admin user if requested. + // Done here (right after PeerMap::new) so the API server, the seeding, + // and the rendezvous loop all share the same Database connection pool. + let api_state = crate::api::AppState::new(pm.db.clone()); + // M4: hand the same Ed25519 secret used for the rendezvous key + // exchange to the plugin-signing handler. Without this set, + // POST /lic/web/api/plugin-sign returns "plugin signing not configured". + if let Some(sk_ref) = sk.clone() { + crate::api::plugin_sign::set_signing_key(sk_ref); + } + // M4: load operator-supplied OIDC providers from --oidc-config (TOML). + // Errors are logged but don't kill the server — the operator can + // hand-insert into oidc_providers as a fallback. + let oidc_path = get_arg("oidc-config"); + if !oidc_path.is_empty() { + let public_base = api_state.cfg.public_base_url.clone(); + let db = pm.db.clone(); + match crate::api::oidc::providers::load_from_file( + &db, + std::path::Path::new(&oidc_path), + &public_base, + ) + .await + { + Ok(n) => log::info!("oidc: loaded {} providers from {}", n, oidc_path), + Err(e) => log::warn!("oidc: failed to load {}: {}", oidc_path, e), + } + } + { + let bn = get_arg("bootstrap-admin-username"); + let bp = get_arg("bootstrap-admin-password"); + if !bn.is_empty() && !bp.is_empty() { + if let Err(e) = pm.db.bootstrap_admin(&bn, &bp).await { + log::warn!("bootstrap admin failed: {}", e); + } + } else { + pm.db.warn_if_no_users().await; + } + } log::info!("serial={}", serial); let rendezvous_servers = get_servers(&get_arg("rendezvous-servers"), "rendezvous-servers"); log::info!("Listening on tcp/udp :{}", port); @@ -222,9 +272,23 @@ impl RendezvousServer { } }; let listen_signal = listen_signal(); + // The HTTP API task. `pending()` keeps the select! arm well-typed + // when the operator disabled it via `--http-port=0` — that branch + // never fires. + let api_task: std::pin::Pin< + Box> + Send>, + > = if http_port > 0 { + let addr: SocketAddr = format!("0.0.0.0:{http_port}").parse()?; + let st = api_state.clone(); + Box::pin(crate::api::serve(addr, st)) + } else { + log::info!("HTTP API disabled (http-port = 0)"); + Box::pin(std::future::pending::>()) + }; tokio::select!( res = main_task => res, res = listen_signal => res, + res = api_task => res, ) } @@ -831,7 +895,12 @@ impl RendezvousServer { if let Some(sink) = sink.as_mut() { if let Ok(bytes) = msg.write_to_bytes() { match sink { - Sink::TcpStream(s) => { + Sink::TcpStream(s, enc) => { + let bytes = if let Some(enc) = enc.as_mut() { + enc.enc(&bytes) + } else { + bytes + }; allow_err!(s.send(Bytes::from(bytes)).await); } Sink::Ws(ws) => { @@ -1185,9 +1254,70 @@ impl RendezvousServer { } } } else { - let (a, mut b) = Framed::new(stream, BytesCodec::new()).split(); - sink = Some(Sink::TcpStream(a)); - while let Ok(Some(Ok(bytes))) = timeout(30_000, b.next()).await { + let (mut a, mut b) = Framed::new(stream, BytesCodec::new()).split(); + // Server-initiated secure_tcp handshake. Only attempted when the + // server has a signing key (the default — `--key=-` auto-generates + // one). Signs an ephemeral box public key and sends it to the + // client; the client may either reply with a sealed symmetric key + // (the secure path used by logged-in clients, see + // src/client.rs:427-431 and src/common.rs:1939) or send a regular + // protobuf message (plain mode). Plain-mode clients filter out + // unsolicited KeyExchange via get_next_nonkeyexchange_msg, so the + // KeyExchange we just emitted is harmless to them. + let mut decrypter: Option = None; + let mut buffered_first: Option = None; + if let Some(sk) = self.inner.sk.clone() { + log::info!("secure_tcp: handshake starting for {}", addr); + match try_secure_tcp_handshake(&mut a, &mut b, &sk).await { + Ok(HandshakeOutcome::Secure(enc)) => { + let send_state = enc.clone(); + decrypter = Some(enc); + log::info!("secure_tcp: handshake completed (encrypted) for {}", addr); + sink = Some(Sink::TcpStream(a, Some(send_state))); + } + Ok(HandshakeOutcome::Plain(bytes)) => { + log::info!( + "secure_tcp: client sent plain first message ({} bytes) from {}", + bytes.len(), + addr + ); + buffered_first = Some(bytes); + sink = Some(Sink::TcpStream(a, None)); + } + Ok(HandshakeOutcome::Skip) => { + log::info!( + "secure_tcp: handshake window timed out (client never replied) for {}", + addr + ); + sink = Some(Sink::TcpStream(a, None)); + } + Err(e) => { + log::warn!("secure_tcp: handshake error for {}: {}", addr, e); + sink = Some(Sink::TcpStream(a, None)); + } + } + } else { + log::debug!("secure_tcp: no signing key configured; skipping handshake"); + sink = Some(Sink::TcpStream(a, None)); + } + // Replay the message we already consumed during the handshake + // window before entering the normal read loop. + if let Some(bytes) = buffered_first { + if !self.handle_tcp(&bytes, &mut sink, addr, key, ws).await { + if sink.is_none() { + self.tcp_punch.lock().await.remove(&try_into_v4(addr)); + } + log::debug!("Tcp connection from {:?} closed", addr); + return Ok(()); + } + } + while let Ok(Some(Ok(mut bytes))) = timeout(30_000, b.next()).await { + if let Some(dec) = decrypter.as_mut() { + if let Err(e) = dec.dec(&mut bytes) { + log::warn!("decryption error from {}: {}", addr, e); + break; + } + } if !self.handle_tcp(&bytes, &mut sink, addr, key, ws).await { break; } @@ -1369,3 +1499,85 @@ async fn create_tcp_listener(port: i32) -> ResultType { log::debug!("listen on tcp {:?}", s.local_addr()); Ok(s) } + +/// Outcome of the server-initiated `secure_tcp` handshake on a fresh TCP +/// rendezvous connection. The matching client code lives in +/// /Users/sn0/Desktop/rustdesk/src/common.rs:1939 (`secure_tcp_impl`). +enum HandshakeOutcome { + /// Client cooperated; the resulting `Encrypt` is shared between the + /// inbound decrypter and the outbound `Sink`. + Secure(Encrypt), + /// Client did not opt into encryption — first message we read is a + /// regular `RendezvousMessage`. We hand the bytes back to the caller so + /// they can be dispatched via `handle_tcp` before the read loop begins. + Plain(BytesMut), + /// No first message arrived within the handshake window. Fall through + /// to plain mode; the next `b.next()` in the main read loop will pick + /// up whatever the client eventually sends. + Skip, +} + +/// Server-side counterpart to the client's `secure_tcp_impl`. Sends a signed +/// ephemeral box public key, then reads the first message: +/// +/// 1. If it's a `KeyExchange` carrying `[client_box_pk, sealed_sym_key]`, +/// decrypt the sealed sym key with our box secret and return an `Encrypt` +/// initialised from that key — ready to use on both directions. +/// 2. If it's any other `RendezvousMessage`, return the bytes verbatim so +/// the caller can dispatch them as if no handshake had happened. +/// +/// Plain-mode clients (no API token configured) skip unsolicited +/// `KeyExchange` via `get_next_nonkeyexchange_msg` on their side, so the +/// `KeyExchange` we emit unconditionally is ignored when the client hasn't +/// opted into encryption. +async fn try_secure_tcp_handshake( + sink: &mut TcpStreamSink, + src: &mut TcpStreamSrc, + sk: &sign::SecretKey, +) -> ResultType { + // Ephemeral Curve25519 keypair for this connection only. + let (our_pk_b, our_sk_b) = box_::gen_keypair(); + // Sign the public key with our long-lived Ed25519 sign key. The client + // verifies this signature using the public key the user pasted into + // their RustDesk settings. + let signed = sign::sign(&our_pk_b.0, sk); + let mut msg_out = RendezvousMessage::new(); + msg_out.set_key_exchange(KeyExchange { + keys: vec![Bytes::from(signed)], + ..Default::default() + }); + let bytes = msg_out.write_to_bytes()?; + log::info!("secure_tcp: sending KeyExchange ({} bytes payload)", bytes.len()); + sink.send(Bytes::from(bytes)).await?; + + // Wait briefly for the client's reply. 5 s is comfortably below the + // client's READ_TIMEOUT and the server-loop 30 s timeout, so a slow + // plain-mode client just falls through to `Skip`. + match timeout(5_000, src.next()).await { + Ok(Some(Ok(bytes))) => { + log::info!("secure_tcp: received reply ({} bytes)", bytes.len()); + if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { + if let Some(rendezvous_message::Union::KeyExchange(ex)) = msg_in.union { + if ex.keys.len() != 2 { + bail!( + "invalid key exchange response: keys.len() = {}", + ex.keys.len() + ); + } + let key = Encrypt::decode(&ex.keys[1], &ex.keys[0], &our_sk_b)?; + return Ok(HandshakeOutcome::Secure(Encrypt::new(key))); + } else { + log::info!( + "secure_tcp: reply was a non-KeyExchange RendezvousMessage; treating as plain" + ); + } + } else { + log::info!("secure_tcp: reply did not parse as RendezvousMessage; treating as plain"); + } + Ok(HandshakeOutcome::Plain(bytes)) + } + Ok(Some(Err(e))) => bail!("read error during handshake: {}", e), + Ok(None) => bail!("connection closed during handshake"), + Err(_) => Ok(HandshakeOutcome::Skip), + } +}